diff --git a/.golangci.yaml b/.golangci.yaml index 89e0440e..99fc67a9 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -7,6 +7,7 @@ run: linters: default: all disable: + - gomodguard # Replaced by gomodguard_v2, enabled by default: all - nilerr - tagliatelle - bodyclose diff --git a/Makefile b/Makefile index a05add0c..e8b7b491 100644 --- a/Makefile +++ b/Makefile @@ -106,37 +106,41 @@ docker-dev-build: .PHONY: codegen codegen: go generate ./... - go run github.com/sqlc-dev/sqlc/cmd/sqlc@latest generate + go tool github.com/sqlc-dev/sqlc/cmd/sqlc generate .PHONY: clean clean: - rm -f cover.out cover.html session-manager - rm -rf cover/ + @rm -f cover.out cover.html session-manager + @rm -rf cover/ + +.PHONY: fix-lint +fix-lint: + golangci-lint run --fix --build-tags=integration ./... .PHONY: lint lint: - golangci-lint run ./... + golangci-lint run --build-tags=integration ./... .PHONY: build build: go build ./cmd/session-manager .PHONY: test -test: clean install-gotestsum +test: clean @mkdir -p cover/integration cover/unit @go clean -testcache - gotestsum --junitfile="${CURDIR}/junit-unit.xml" --format=testname -- -count=1 -race -cover ./... -args -test.gocoverdir="${CURDIR}/cover/unit" - GOCOVERDIR="${CURDIR}/cover/integration" gotestsum --junitfile="${CURDIR}/junit-integration.xml" --format=testname -- -v -count=1 -race -tags=integration ./integration + @go tool gotest.tools/gotestsum --junitfile="${CURDIR}/junit-unit.xml" --format=dots-v2 -- -count=1 -race -cover ./... -args -test.gocoverdir="${CURDIR}/cover/unit" + @GOCOVERDIR=${CURDIR}/cover/integration go tool gotest.tools/gotestsum --junitfile="${CURDIR}/junit-integration.xml" --format=dots-v2 -- -v -count=1 -race -tags=integration ./integration @go tool covdata textfmt -i=./cover/unit,./cover/integration -o cover.out @grep -v 'github.com/openkcm/session-manager/internal/openapi/' cover.out > cover.tmp && mv cover.tmp cover.out @grep -v 'github.com/openkcm/session-manager/internal/dbtest/' cover.out > cover.tmp && mv cover.tmp cover.out - @grep -v 'github.com/openkcm/session-manager/internal/trust/trustmock/' cover.out > cover.tmp && mv cover.tmp cover.out + @grep -v 'github.com/openkcm/session-manager/modules/oidctrust/mocks/' cover.out > cover.tmp && mv cover.tmp cover.out @grep -v 'github.com/openkcm/session-manager/internal/session/mock/' cover.out > cover.tmp && mv cover.tmp cover.out @go tool cover -func=cover.out - @echo "On a Mac, you can use the following command to open the coverage report in the browser\ngo tool cover -html=cover.out -o cover.html && open cover.html" + @echo "On a Mac, you can use the following command to open the coverage report in the browser\ngo tool cover -html=cover.out" .PHONY: helm-test helm-test: @@ -172,10 +176,6 @@ helm-integration-test-run: k3d-teardown: k3d cluster delete $(K3D_CLUSTER_NAME) -.PHONY: install-gotestsum -install-gotestsum: - (cd /tmp && go install gotest.tools/gotestsum@latest) - .PHONY: image image: docker build -t ${IMG} . diff --git a/Taskfile.yaml b/Taskfile.yaml index 5bafdc5f..bfaf9b3a 100644 --- a/Taskfile.yaml +++ b/Taskfile.yaml @@ -6,6 +6,6 @@ includes: flatten: true excludes: [] # put task names in here which are overwritten in this file vars: - CODE_DIRS: "{{.ROOT_DIR}}/cmd/... {{.ROOT_DIR}}/internal/... {{.ROOT_DIR}}/integration/... {{.ROOT_DIR}}/sql/..." + CODE_DIRS: "{{.ROOT_DIR}}/cmd/... {{.ROOT_DIR}}/internal/... {{.ROOT_DIR}}/integration/... {{.ROOT_DIR}} {{.ROOT_DIR}}/modules/..." COMPONENTS: session-manager REPO_URL: https://github.com/openkcm/session-manager diff --git a/charts/session-manager/templates/configmap.yaml b/charts/session-manager/templates/configmap.yaml index acdc66fd..2ed3ad58 100644 --- a/charts/session-manager/templates/configmap.yaml +++ b/charts/session-manager/templates/configmap.yaml @@ -43,6 +43,12 @@ data: database: {{- toYaml .database | nindent 6 }} + trust: + {{- toYaml .trust | nindent 6 }} + + credentials: + {{- toYaml .credentials | nindent 6 }} + valkey: {{- toYaml .valkey | nindent 6 }} @@ -54,4 +60,7 @@ data: housekeeper: {{- toYaml .housekeeper | nindent 6 }} + + apps: + {{- toYaml .apps | nindent 6 }} {{- end }} diff --git a/charts/session-manager/values-dev.yaml b/charts/session-manager/values-dev.yaml index d9d30d39..68c9b199 100644 --- a/charts/session-manager/values-dev.yaml +++ b/charts/session-manager/values-dev.yaml @@ -272,6 +272,7 @@ config: enabled: false database: + module: database.module.pgxpool name: session_manager port: 5432 host: @@ -284,7 +285,14 @@ config: source: embedded value: secret + trust: + module: trust.module.oidc + + credentials: + module: credentials.module.oauth2 + valkey: + module: sessionstore.module.valkey host: source: embedded value: valkey-headless.session-manager.svc.cluster.local:6379 @@ -296,7 +304,7 @@ config: value: secret prefix: session-manager secretRef: - type: insecure # Supports "mtls" or "insecure" + type: insecure # Supported values: "mtls", "insecure", "client_secret_post" # mtls: # cert: # source: embedded @@ -335,9 +343,18 @@ config: value: my-csrf-secret-at-least-thirty-two-bits-size migrate: + module: trust.migration.module.oidc source: file:///sql housekeeper: triggerInterval: 10m concurrencyLimit: 10 tokenRefreshTriggerInterval: 15m + + apps: + grpc: + module: app.module.grpcserver + services: + - module: service.module.grpc.session + - module: service.module.grpc.trustmapping + - module: service.module.grpc.oidcmapping diff --git a/charts/session-manager/values.yaml b/charts/session-manager/values.yaml index 94666994..2683a740 100644 --- a/charts/session-manager/values.yaml +++ b/charts/session-manager/values.yaml @@ -280,6 +280,7 @@ config: enabled: false database: + module: database.module.pgxpool name: session_manager port: 5432 host: @@ -292,7 +293,14 @@ config: source: embedded value: secret + trust: + module: trust.module.oidc + + credentials: + module: credentials.module.oauth2 + valkey: + module: sessionstore.module.valkey host: source: embedded value: host.ns.svc.cluster.local @@ -304,7 +312,7 @@ config: value: secret prefix: session-manager secretRef: - type: insecure # Supports "mtls" or "insecure" + type: insecure # Supported values: "mtls", "insecure", "client_secret_post" # mtls: # cert: # source: embedded @@ -348,9 +356,18 @@ config: value: my-csrf-secret-at-least-thirty-two-bits-size migrate: + module: trust.migration.module.oidc source: file:///sql housekeeper: triggerInterval: 10m concurrencyLimit: 10 tokenRefreshTriggerInterval: 15m + + apps: + grpc: + module: app.module.grpcserver + services: + - module: service.module.grpc.session + - module: service.module.grpc.trustmapping + - module: service.module.grpc.oidcmapping diff --git a/cmd/session-manager/main.go b/cmd/session-manager/main.go index e23cc5be..045cf631 100644 --- a/cmd/session-manager/main.go +++ b/cmd/session-manager/main.go @@ -1,87 +1,14 @@ package main import ( - "context" - "log/slog" - "os" - "os/signal" - "time" - - "github.com/openkcm/common-sdk/pkg/utils" - "github.com/spf13/cobra" - - slogctx "github.com/veqryn/slog-context" - - "github.com/openkcm/session-manager/cmd/session-manager/apiserver" - "github.com/openkcm/session-manager/cmd/session-manager/housekeeper" - "github.com/openkcm/session-manager/cmd/session-manager/migrate" -) - -var ( - // BuildInfo will be set by the build system - BuildInfo = "{}" - - isVersionCmd bool - gracefulShutdown time.Duration + "github.com/openkcm/session-manager/cmd/session-manager/maincmd" + _ "github.com/openkcm/session-manager/modules/standard" ) -var versionCmd = &cobra.Command{ - Use: "version", - Short: "Session Manager Version", - RunE: func(cmd *cobra.Command, _ []string) error { - isVersionCmd = true - - value, err := utils.ExtractFromComplexValue(BuildInfo) - if err != nil { - return err - } - - slog.InfoContext(cmd.Context(), value) - - return nil - }, -} - -func rootCmd() *cobra.Command { - cmd := &cobra.Command{ - Use: "session-manager", - Short: "Session Manager", - Long: "KCM Session Manager, implementing the OIDC authorization code flow.", - } - - cmd.PersistentFlags().DurationVar(&gracefulShutdown, "graceful-shutdown", 1*time.Second, "graceful shutdown") - - cmd.AddCommand( - versionCmd, - apiserver.Cmd(BuildInfo), - housekeeper.Cmd(BuildInfo), - migrate.Cmd(BuildInfo), - ) - - return cmd -} - -func execute() error { - ctx, cancelOnSignal := signal.NotifyContext(context.Background(), os.Interrupt, os.Kill) - defer cancelOnSignal() - - err := rootCmd().ExecuteContext(ctx) - if err != nil { - slogctx.Error(ctx, "failed to start the application", "error", err) - return err - } - - if !isVersionCmd { - slogctx.Info(ctx, "Graceful shutdown", "duration", gracefulShutdown) - time.Sleep(gracefulShutdown) - } - - return nil -} +// BuildInfo will be set by the build system +var BuildInfo = "{}" func main() { - err := execute() - if err != nil { - os.Exit(1) - } + maincmd.BuildInfo = BuildInfo + maincmd.Main() } diff --git a/cmd/session-manager/maincmd/main.go b/cmd/session-manager/maincmd/main.go new file mode 100644 index 00000000..4b7d197b --- /dev/null +++ b/cmd/session-manager/maincmd/main.go @@ -0,0 +1,87 @@ +package maincmd + +import ( + "context" + "log/slog" + "os" + "os/signal" + "time" + + "github.com/openkcm/common-sdk/pkg/utils" + "github.com/spf13/cobra" + + slogctx "github.com/veqryn/slog-context" + + "github.com/openkcm/session-manager/cmd/session-manager/apiserver" + "github.com/openkcm/session-manager/cmd/session-manager/housekeeper" + "github.com/openkcm/session-manager/cmd/session-manager/migrate" +) + +var ( + // BuildInfo will be set by the build system + BuildInfo = "{}" + + isVersionCmd bool + gracefulShutdown time.Duration +) + +var versionCmd = &cobra.Command{ + Use: "version", + Short: "Session Manager Version", + RunE: func(cmd *cobra.Command, _ []string) error { + isVersionCmd = true + + value, err := utils.ExtractFromComplexValue(BuildInfo) + if err != nil { + return err + } + + slog.InfoContext(cmd.Context(), value) + + return nil + }, +} + +func rootCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "session-manager", + Short: "Session Manager", + Long: "KCM Session Manager, implementing the OIDC authorization code flow.", + } + + cmd.PersistentFlags().DurationVar(&gracefulShutdown, "graceful-shutdown", 1*time.Second, "graceful shutdown") + + cmd.AddCommand( + versionCmd, + apiserver.Cmd(BuildInfo), + housekeeper.Cmd(BuildInfo), + migrate.Cmd(BuildInfo), + ) + + return cmd +} + +func execute() error { + ctx, cancelOnSignal := signal.NotifyContext(context.Background(), os.Interrupt, os.Kill) + defer cancelOnSignal() + + err := rootCmd().ExecuteContext(ctx) + if err != nil { + slogctx.Error(ctx, "failed to start the application", "error", err) + return err + } + + if !isVersionCmd { + slogctx.Info(ctx, "Graceful shutdown", "duration", gracefulShutdown) + time.Sleep(gracefulShutdown) + } + + return nil +} + +func Main() { + err := execute() + if err != nil { + os.Exit(1) + } +} diff --git a/config.yaml b/config.yaml index b289f853..4b03f8b7 100644 --- a/config.yaml +++ b/config.yaml @@ -181,6 +181,7 @@ telemetry: jsonPath: "$.server-ca" database: + module: database.module.pgxpool name: session_manager port: 5432 host: @@ -193,7 +194,18 @@ database: source: embedded value: secret +trust: + module: trust.module.oidc + +migrate: + module: trust.migration.module.oidc + source: file://./sql + +credentials: + module: credentials.module.oauth2 + valkey: + module: sessionstore.module.valkey host: source: embedded value: localhost:6379 @@ -205,7 +217,7 @@ valkey: value: secret prefix: session-manager secretRef: - type: insecure # Supports "mtls" or "insecure" + type: insecure # Supported values: "mtls", "insecure", "client_secret_post" # mtls: # cert: # source: embedded @@ -248,10 +260,15 @@ sessionManager: source: embedded value: my-csrf-secret-at-least-thirty-two-bits-size -migrate: - source: file://./sql - housekeeper: triggerInterval: 10m concurrencyLimit: 10 tokenRefreshTriggerInterval: 15m + +apps: + grpc: + module: app.module.grpcserver + services: + - module: service.module.grpc.session + - module: service.module.grpc.trustmapping + - module: service.module.grpc.oidcmapping diff --git a/context.go b/context.go new file mode 100644 index 00000000..825ec175 --- /dev/null +++ b/context.go @@ -0,0 +1,187 @@ +package sessionmanager + +import ( + "context" + "errors" + "fmt" + "io" + "reflect" + "slices" + + slogctx "github.com/veqryn/slog-context" +) + +type Context struct { + //nolint:containedctx + context.Context + + mods map[string]Module + modOrder []string + apps map[string]App +} + +func (c *Context) cloneWithParent(parent context.Context) *Context { + return &Context{ + Context: parent, + mods: c.mods, + modOrder: c.modOrder, + apps: c.apps, + } +} + +func (c *Context) WithValue(key, val any) *Context { + return c.cloneWithParent(context.WithValue(c.Context, key, val)) +} + +func NewContext(ctx context.Context) (*Context, context.CancelCauseFunc) { + ctx, cancelCause := context.WithCancelCause(ctx) + c := &Context{ + Context: ctx, + mods: make(map[string]Module), + modOrder: nil, + apps: make(map[string]App), + } + return c, func(cause error) { + cancelCause(cause) + for name, app := range c.apps { + if closer, ok := app.(io.Closer); ok { + if err := closer.Close(); err != nil { + slogctx.Error(c, "failed to close an app", "app", name, "error", err) + } + } + } + for _, v := range slices.Backward(c.modOrder) { + id := v + mod, ok := c.mods[id] + if !ok { + continue + } + if closer, ok := mod.(io.Closer); ok { + if err := closer.Close(); err != nil { + slogctx.Error(c, "failed to close a module", "module", id, "error", err) + } + } + } + } +} + +type ExtensionConfig interface { + Module() string + UnmarshalExtension(into Module) error +} + +func (c *Context) GetModule(id string) (Module, error) { + if mod, ok := c.mods[id]; ok { + return mod, nil + } + + return nil, errors.New("module is not loaded") +} + +func (c *Context) GetApp(id string) (App, error) { + if app, ok := c.apps[id]; ok { + return app, nil + } + + return nil, errors.New("app is not loaded") +} + +func (c *Context) LoadModule(cfg ExtensionConfig) (Module, error) { + before := len(c.modOrder) + mod, modInfo, err := c.instantiate(cfg) + if err != nil { + c.unloadModulesAfter(before) + return nil, err + } + + if _, ok := c.mods[modInfo.ID]; ok { + c.unloadModulesAfter(before) + return nil, errors.New("module has already been loaded") + } + + c.mods[modInfo.ID] = mod + c.modOrder = append(c.modOrder, modInfo.ID) + + return mod, nil +} + +func (c *Context) LoadApp(cfg ExtensionConfig) (App, error) { + before := len(c.modOrder) + mod, modInfo, err := c.instantiate(cfg) + if err != nil { + c.unloadModulesAfter(before) + return nil, err + } + + app, ok := mod.(App) + if !ok { + c.unloadModulesAfter(before) + return nil, fmt.Errorf("module %q does not implement the App interface", modInfo.ID) + } + + if _, ok := c.apps[modInfo.ID]; ok { + c.unloadModulesAfter(before) + return nil, errors.New("app has already been loaded") + } + + c.apps[modInfo.ID] = app + + return app, nil +} + +// unloadModulesAfter rolls back any modules appended to modOrder at or after +// the snapshot index. Modules are closed in reverse load order. It is the +// recovery path for a failed LoadModule or LoadApp call: every successfully +// loaded child module is closed and removed from the registry before the +// error surfaces to the caller. +func (c *Context) unloadModulesAfter(snapshot int) { + if snapshot >= len(c.modOrder) { + return + } + for i := len(c.modOrder) - 1; i >= snapshot; i-- { + id := c.modOrder[i] + mod, ok := c.mods[id] + if !ok { + continue + } + if closer, ok := mod.(io.Closer); ok { + if err := closer.Close(); err != nil { + slogctx.Error(c, "failed to close a module during rollback", "module", id, "error", err) + } + } + delete(c.mods, id) + } + c.modOrder = c.modOrder[:snapshot] +} + +// instantiate resolves cfg.Module(), calls New(), unmarshals the extension, and +// runs Provision if the resulting instance is a Provisioner. It is shared by +// LoadModule and LoadApp. +func (c *Context) instantiate(cfg ExtensionConfig) (Module, ModuleInfo, error) { + modInfo, err := GetModule(cfg.Module()) + if err != nil { + return nil, ModuleInfo{}, fmt.Errorf("getting module %q: %w", reflect.TypeOf(cfg), err) + } + + slogctx.Debug(c, "loading module", "module", modInfo.ID) + + mod := modInfo.New() + rv := reflect.ValueOf(mod) + if rv.Kind() == reflect.Pointer && rv.Elem().Kind() == reflect.Struct { + if err := cfg.UnmarshalExtension(mod); err != nil { + return nil, ModuleInfo{}, fmt.Errorf("unmarshaling extension %s: %w", modInfo.ID, err) + } + } + + slogctx.Debug(c, "instantiated a module", "module", modInfo.ID) + + if provisioner, ok := mod.(Provisioner); ok { + if err := provisioner.Provision(c); err != nil { + return nil, ModuleInfo{}, fmt.Errorf("provisioning module: %w", err) + } + + slogctx.Debug(c, "provisioned module", "module", modInfo.ID) + } + + return mod, modInfo, nil +} diff --git a/context_test.go b/context_test.go new file mode 100644 index 00000000..4c263335 --- /dev/null +++ b/context_test.go @@ -0,0 +1,602 @@ +package sessionmanager_test + +import ( + "context" + "errors" + "io" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + sessionmanager "github.com/openkcm/session-manager" +) + +// provisionableModule records whether Provision was called. +type provisionableModule struct { + stubModule + + provisioned bool +} + +func (m *provisionableModule) Provision(_ *sessionmanager.Context) error { + m.provisioned = true + return nil +} + +// failingProvisionerModule always returns an error from Provision. +type failingProvisionerModule struct{ stubModule } + +func (m *failingProvisionerModule) Provision(_ *sessionmanager.Context) error { + return errors.New("provision failed") +} + +// closableModule records whether Close was called. +type closableModule struct { + stubModule + + closed bool +} + +func (m *closableModule) Close() error { + m.closed = true + return nil +} + +// closeErrModule returns an error from Close (exercises the error-log path). +type closeErrModule struct{ stubModule } + +func (m *closeErrModule) Close() error { return errors.New("close error") } + +// simpleExtensionConfig is a minimal ExtensionConfig that references a registered module. +type simpleExtensionConfig struct{ moduleID string } + +func (c *simpleExtensionConfig) Module() string { return c.moduleID } +func (c *simpleExtensionConfig) UnmarshalExtension(_ sessionmanager.Module) error { return nil } + +// failingUnmarshalConfig returns an error from UnmarshalExtension. +type failingUnmarshalConfig struct{ moduleID string } + +func (c *failingUnmarshalConfig) Module() string { return c.moduleID } +func (c *failingUnmarshalConfig) UnmarshalExtension(_ sessionmanager.Module) error { + return errors.New("unmarshal failed") +} + +// customNewModule registers a module whose New() function delegates to newFn. +type customNewModule struct { + id string + newFn func() sessionmanager.Module +} + +func (m *customNewModule) Module() sessionmanager.ModuleInfo { + return sessionmanager.ModuleInfo{ID: m.id, New: m.newFn} +} + +func TestNewContext_CancelCloseModules(t *testing.T) { + id := uniqueID(t, "closable") + cm := &closableModule{stubModule: stubModule{id: id}} + + sessionmanager.RegisterModule(&customNewModule{ + id: id, + newFn: func() sessionmanager.Module { return cm }, + }) + + ctx, cancel := sessionmanager.NewContext(t.Context()) + + _, err := ctx.LoadModule(&simpleExtensionConfig{moduleID: id}) + require.NoError(t, err) + + cancel(nil) + assert.True(t, cm.closed, "Close() should be called when context is cancelled") +} + +func TestNewContext_CancelWithCause(t *testing.T) { + ctx, cancel := sessionmanager.NewContext(t.Context()) + + cause := errors.New("test cause") + cancel(cause) + + assert.ErrorIs(t, context.Cause(ctx), cause) +} + +func TestNewContext_CloseErrorIsHandled(t *testing.T) { + id := uniqueID(t, "closeerr") + sessionmanager.RegisterModule(&customNewModule{ + id: id, + newFn: func() sessionmanager.Module { return &closeErrModule{stubModule: stubModule{id: id}} }, + }) + + ctx, cancel := sessionmanager.NewContext(t.Context()) + _, err := ctx.LoadModule(&simpleExtensionConfig{moduleID: id}) + require.NoError(t, err) + + // Should not panic even though Close() returns an error. + assert.NotPanics(t, func() { cancel(nil) }) +} + +func TestContext_WithValue(t *testing.T) { + type ctxKey struct{} + ctx, cancel := sessionmanager.NewContext(t.Context()) + defer cancel(nil) + + ctx2 := ctx.WithValue(ctxKey{}, "hello") + assert.Equal(t, "hello", ctx2.Value(ctxKey{})) + // Original context should not carry the value. + assert.Nil(t, ctx.Value(ctxKey{})) +} + +func TestLoadModule_Success(t *testing.T) { + id := uniqueID(t, "prov") + pm := &provisionableModule{stubModule: stubModule{id: id}} + + sessionmanager.RegisterModule(&customNewModule{ + id: id, + newFn: func() sessionmanager.Module { return pm }, + }) + + ctx, cancel := sessionmanager.NewContext(t.Context()) + defer cancel(nil) + + mod, err := ctx.LoadModule(&simpleExtensionConfig{moduleID: id}) + require.NoError(t, err) + require.NotNil(t, mod) + assert.True(t, pm.provisioned) +} + +func TestLoadModule_UnknownModule(t *testing.T) { + ctx, cancel := sessionmanager.NewContext(t.Context()) + defer cancel(nil) + + _, err := ctx.LoadModule(&simpleExtensionConfig{moduleID: "no-such-module"}) + require.Error(t, err) +} + +func TestLoadModule_DuplicateReturnsError(t *testing.T) { + id := uniqueID(t, "dup") + sessionmanager.RegisterModule(&customNewModule{ + id: id, + newFn: func() sessionmanager.Module { return &stubModule{id: id} }, + }) + + ctx, cancel := sessionmanager.NewContext(t.Context()) + defer cancel(nil) + + _, err := ctx.LoadModule(&simpleExtensionConfig{moduleID: id}) + require.NoError(t, err) + + _, err = ctx.LoadModule(&simpleExtensionConfig{moduleID: id}) + require.Error(t, err) + assert.Contains(t, err.Error(), "already been loaded") +} + +func TestLoadModule_ProvisionError(t *testing.T) { + id := uniqueID(t, "failprov") + sessionmanager.RegisterModule(&customNewModule{ + id: id, + newFn: func() sessionmanager.Module { return &failingProvisionerModule{stubModule: stubModule{id: id}} }, + }) + + ctx, cancel := sessionmanager.NewContext(t.Context()) + defer cancel(nil) + + _, err := ctx.LoadModule(&simpleExtensionConfig{moduleID: id}) + require.Error(t, err) + assert.Contains(t, err.Error(), "provision failed") +} + +func TestLoadModule_UnmarshalError(t *testing.T) { + id := uniqueID(t, "unmarshalerr") + // Use a pointer-to-struct module so the unmarshal branch is reached. + sessionmanager.RegisterModule(&customNewModule{ + id: id, + newFn: func() sessionmanager.Module { return &stubModule{id: id} }, + }) + + ctx, cancel := sessionmanager.NewContext(t.Context()) + defer cancel(nil) + + _, err := ctx.LoadModule(&failingUnmarshalConfig{moduleID: id}) + require.Error(t, err) + assert.Contains(t, err.Error(), "unmarshal failed") +} + +func TestGetModule_AfterLoad(t *testing.T) { + id := uniqueID(t, "get") + sessionmanager.RegisterModule(&customNewModule{ + id: id, + newFn: func() sessionmanager.Module { return &stubModule{id: id} }, + }) + + ctx, cancel := sessionmanager.NewContext(t.Context()) + defer cancel(nil) + + _, err := ctx.LoadModule(&simpleExtensionConfig{moduleID: id}) + require.NoError(t, err) + + mod, err := ctx.GetModule(id) + require.NoError(t, err) + assert.NotNil(t, mod) +} + +func TestGetModule_NotLoaded(t *testing.T) { + id := uniqueID(t, "notloaded") + sessionmanager.RegisterModule(&customNewModule{ + id: id, + newFn: func() sessionmanager.Module { return &stubModule{id: id} }, + }) + + ctx, cancel := sessionmanager.NewContext(t.Context()) + defer cancel(nil) + + // Never call LoadModule — GetModule should return an error. + _, err := ctx.GetModule(id) + require.Error(t, err) + assert.Contains(t, err.Error(), "not loaded") +} + +// appModule is a Module that also satisfies the App interface. +type appModule struct { + stubModule + + started bool + stopped bool +} + +func (a *appModule) Start() error { + a.started = true + return nil +} + +func (a *appModule) Stop() error { + a.stopped = true + return nil +} + +// closableAppModule is an App that also satisfies io.Closer. +type closableAppModule struct { + appModule + + closed bool +} + +func (a *closableAppModule) Close() error { + a.closed = true + return nil +} + +func TestLoadApp_Success(t *testing.T) { + id := uniqueID(t, "app") + am := &appModule{stubModule: stubModule{id: id}} + + sessionmanager.RegisterModule(&customNewModule{ + id: id, + newFn: func() sessionmanager.Module { return am }, + }) + + ctx, cancel := sessionmanager.NewContext(t.Context()) + defer cancel(nil) + + app, err := ctx.LoadApp(&simpleExtensionConfig{moduleID: id}) + require.NoError(t, err) + require.NotNil(t, app) + + got, err := ctx.GetApp(id) + require.NoError(t, err) + assert.Same(t, app, got) +} + +func TestLoadApp_MissingAppInterface(t *testing.T) { + id := uniqueID(t, "notapp") + sessionmanager.RegisterModule(&customNewModule{ + id: id, + newFn: func() sessionmanager.Module { return &stubModule{id: id} }, + }) + + ctx, cancel := sessionmanager.NewContext(t.Context()) + defer cancel(nil) + + _, err := ctx.LoadApp(&simpleExtensionConfig{moduleID: id}) + require.Error(t, err) + assert.Contains(t, err.Error(), "App interface") +} + +func TestLoadApp_UnknownModule(t *testing.T) { + ctx, cancel := sessionmanager.NewContext(t.Context()) + defer cancel(nil) + + _, err := ctx.LoadApp(&simpleExtensionConfig{moduleID: "no-such-app-module"}) + require.Error(t, err) +} + +func TestLoadApp_DuplicateReturnsError(t *testing.T) { + id := uniqueID(t, "dupapp") + sessionmanager.RegisterModule(&customNewModule{ + id: id, + newFn: func() sessionmanager.Module { return &appModule{stubModule: stubModule{id: id}} }, + }) + + ctx, cancel := sessionmanager.NewContext(t.Context()) + defer cancel(nil) + + _, err := ctx.LoadApp(&simpleExtensionConfig{moduleID: id}) + require.NoError(t, err) + + _, err = ctx.LoadApp(&simpleExtensionConfig{moduleID: id}) + require.Error(t, err) + assert.Contains(t, err.Error(), "already been loaded") +} + +func TestGetApp_NotLoaded(t *testing.T) { + ctx, cancel := sessionmanager.NewContext(t.Context()) + defer cancel(nil) + + _, err := ctx.GetApp("never-loaded") + require.Error(t, err) + assert.Contains(t, err.Error(), "not loaded") +} + +func TestNewContext_CancelClosesApps(t *testing.T) { + id := uniqueID(t, "closableapp") + cam := &closableAppModule{appModule: appModule{stubModule: stubModule{id: id}}} + + sessionmanager.RegisterModule(&customNewModule{ + id: id, + newFn: func() sessionmanager.Module { return cam }, + }) + + ctx, cancel := sessionmanager.NewContext(t.Context()) + _, err := ctx.LoadApp(&simpleExtensionConfig{moduleID: id}) + require.NoError(t, err) + + cancel(nil) + assert.True(t, cam.closed, "Close() should be called on apps when context is cancelled") +} + +// childLoadingProvisioner is a Module whose Provision loads the configured +// child module IDs, in order, via ctx.LoadModule. If failAfter is non-negative +// it returns an error immediately after loading that many children. +type childLoadingProvisioner struct { + stubModule + + childIDs []string + failAfter int // -1 = never fail + failReason string // error text used when failAfter triggers +} + +func (m *childLoadingProvisioner) Provision(ctx *sessionmanager.Context) error { + for i, id := range m.childIDs { + if m.failAfter >= 0 && i == m.failAfter { + return errors.New(m.failReason) + } + if _, err := ctx.LoadModule(&simpleExtensionConfig{moduleID: id}); err != nil { + return err + } + } + if m.failAfter >= 0 && m.failAfter >= len(m.childIDs) { + return errors.New(m.failReason) + } + return nil +} + +// childLoadingApp is an App whose Provision loads the given child modules +// before the framework registers it as an app. +type childLoadingApp struct { + appModule + childLoadingProvisioner +} + +func (a *childLoadingApp) Module() sessionmanager.ModuleInfo { + return a.appModule.Module() +} + +func (a *childLoadingApp) Provision(ctx *sessionmanager.Context) error { + return a.childLoadingProvisioner.Provision(ctx) +} + +// orderRecorder is shared across closableOrderModule instances so tests can +// observe Close ordering across the framework's reverse-load-order shutdown. +type orderRecorder struct { + closes []string +} + +// closableOrderModule appends its ID to the recorder when Close is called. +type closableOrderModule struct { + stubModule + + rec *orderRecorder +} + +func (m *closableOrderModule) Close() error { + m.rec.closes = append(m.rec.closes, m.id) + return nil +} + +func TestLoadModule_ChildLoadFailureRollsBackEarlierSiblings(t *testing.T) { + parentID := uniqueID(t, "parent") + child1ID := uniqueID(t, "child1") + // child2ID is intentionally unregistered so its load fails. + child2ID := uniqueID(t, "child2-missing") + + c1 := &closableModule{stubModule: stubModule{id: child1ID}} + sessionmanager.RegisterModule(&customNewModule{ + id: child1ID, + newFn: func() sessionmanager.Module { return c1 }, + }) + + parent := &childLoadingProvisioner{ + stubModule: stubModule{id: parentID}, + childIDs: []string{child1ID, child2ID}, + failAfter: -1, + } + sessionmanager.RegisterModule(&customNewModule{ + id: parentID, + newFn: func() sessionmanager.Module { return parent }, + }) + + ctx, cancel := sessionmanager.NewContext(t.Context()) + defer cancel(nil) + + _, err := ctx.LoadModule(&simpleExtensionConfig{moduleID: parentID}) + require.Error(t, err) + assert.True(t, c1.closed, "earlier sibling must be closed during rollback") + + // child1 must be removed from the registry. + _, err = ctx.GetModule(child1ID) + require.Error(t, err) + + // parent itself was never registered (its Provision failed). + _, err = ctx.GetModule(parentID) + require.Error(t, err) +} + +func TestLoadApp_ProvisionErrorRollsBackChildren(t *testing.T) { + appID := uniqueID(t, "app") + child1ID := uniqueID(t, "ch1") + child2ID := uniqueID(t, "ch2") + + c1 := &closableModule{stubModule: stubModule{id: child1ID}} + c2 := &closableModule{stubModule: stubModule{id: child2ID}} + + sessionmanager.RegisterModule(&customNewModule{ + id: child1ID, + newFn: func() sessionmanager.Module { return c1 }, + }) + sessionmanager.RegisterModule(&customNewModule{ + id: child2ID, + newFn: func() sessionmanager.Module { return c2 }, + }) + + app := &childLoadingApp{ + appModule: appModule{stubModule: stubModule{id: appID}}, + childLoadingProvisioner: childLoadingProvisioner{ + childIDs: []string{child1ID, child2ID}, + failAfter: 2, // fail after both children loaded + failReason: "app provision boom", + }, + } + sessionmanager.RegisterModule(&customNewModule{ + id: appID, + newFn: func() sessionmanager.Module { return app }, + }) + + ctx, cancel := sessionmanager.NewContext(t.Context()) + defer cancel(nil) + + _, err := ctx.LoadApp(&simpleExtensionConfig{moduleID: appID}) + require.Error(t, err) + assert.Contains(t, err.Error(), "app provision boom") + + assert.True(t, c1.closed, "child1 must be closed during rollback") + assert.True(t, c2.closed, "child2 must be closed during rollback") + + // Neither child is in the registry. + _, err = ctx.GetModule(child1ID) + require.Error(t, err) + _, err = ctx.GetModule(child2ID) + require.Error(t, err) + + // The app itself was never registered. + _, err = ctx.GetApp(appID) + require.Error(t, err) +} + +func TestLoadModule_NonCloserChildIsRemovedOnRollback(t *testing.T) { + parentID := uniqueID(t, "parent-noncloser") + plainID := uniqueID(t, "plain-child") + missingID := uniqueID(t, "missing-child") + + sessionmanager.RegisterModule(&customNewModule{ + id: plainID, + newFn: func() sessionmanager.Module { return &stubModule{id: plainID} }, + }) + + parent := &childLoadingProvisioner{ + stubModule: stubModule{id: parentID}, + childIDs: []string{plainID, missingID}, + failAfter: -1, + } + sessionmanager.RegisterModule(&customNewModule{ + id: parentID, + newFn: func() sessionmanager.Module { return parent }, + }) + + ctx, cancel := sessionmanager.NewContext(t.Context()) + defer cancel(nil) + + _, err := ctx.LoadModule(&simpleExtensionConfig{moduleID: parentID}) + require.Error(t, err) + + // Non-closer child must still be removed from the registry. + _, err = ctx.GetModule(plainID) + require.Error(t, err) +} + +func TestNewContext_CloseInReverseLoadOrder(t *testing.T) { + rec := &orderRecorder{} + + idA := uniqueID(t, "ord-A") + idB := uniqueID(t, "ord-B") + idC := uniqueID(t, "ord-C") + + for _, id := range []string{idA, idB, idC} { + sessionmanager.RegisterModule(&customNewModule{ + id: id, + newFn: func() sessionmanager.Module { return &closableOrderModule{stubModule: stubModule{id: id}, rec: rec} }, + }) + } + + ctx, cancel := sessionmanager.NewContext(t.Context()) + + for _, id := range []string{idA, idB, idC} { + _, err := ctx.LoadModule(&simpleExtensionConfig{moduleID: id}) + require.NoError(t, err) + } + + cancel(nil) + + require.Equal(t, []string{idC, idB, idA}, rec.closes, + "modules must be closed in reverse load order") +} + +func TestNewContext_CloseSkipsNonClosersInReverseOrder(t *testing.T) { + rec := &orderRecorder{} + + idA := uniqueID(t, "mix-A") // closer + idB := uniqueID(t, "mix-B") // non-closer + idC := uniqueID(t, "mix-C") // closer + + sessionmanager.RegisterModule(&customNewModule{ + id: idA, + newFn: func() sessionmanager.Module { return &closableOrderModule{stubModule: stubModule{id: idA}, rec: rec} }, + }) + sessionmanager.RegisterModule(&customNewModule{ + id: idB, + newFn: func() sessionmanager.Module { return &stubModule{id: idB} }, + }) + sessionmanager.RegisterModule(&customNewModule{ + id: idC, + newFn: func() sessionmanager.Module { return &closableOrderModule{stubModule: stubModule{id: idC}, rec: rec} }, + }) + + ctx, cancel := sessionmanager.NewContext(t.Context()) + for _, id := range []string{idA, idB, idC} { + _, err := ctx.LoadModule(&simpleExtensionConfig{moduleID: id}) + require.NoError(t, err) + } + + cancel(nil) + require.Equal(t, []string{idC, idA}, rec.closes, + "only Closer modules must be closed, in reverse load order") +} + +// Ensure stubModule satisfies the Module interface at compile time. +var _ sessionmanager.Module = (*stubModule)(nil) + +// Ensure provisionableModule satisfies Provisioner at compile time. +var _ sessionmanager.Provisioner = (*provisionableModule)(nil) + +// Ensure closableModule satisfies io.Closer at compile time. +var _ io.Closer = (*closableModule)(nil) + +// Ensure appModule satisfies App at compile time. +var _ sessionmanager.App = (*appModule)(nil) diff --git a/database.go b/database.go new file mode 100644 index 00000000..6ce95787 --- /dev/null +++ b/database.go @@ -0,0 +1,20 @@ +package sessionmanager + +import ( + "context" + "database/sql" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" +) + +type Database interface { + Exec(ctx context.Context, sql string, args ...any) (pgconn.CommandTag, error) + Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) + QueryRow(ctx context.Context, sql string, args ...any) pgx.Row + STDAdapter() *sql.DB +} + +type Migrate interface { + Migrate(ctx context.Context) error +} diff --git a/go.mod b/go.mod index 3cf7c127..0264d926 100644 --- a/go.mod +++ b/go.mod @@ -1,16 +1,15 @@ module github.com/openkcm/session-manager -go 1.26.0 - -toolchain go1.26.2 +go 1.26.3 tool ( github.com/oapi-codegen/oapi-codegen/v2/cmd/oapi-codegen github.com/sqlc-dev/sqlc/cmd/sqlc + gotest.tools/gotestsum ) require ( - github.com/XSAM/otelsql v0.42.0 + github.com/creasty/defaults v1.8.0 github.com/exaring/otelpgx v0.10.0 github.com/go-jose/go-jose/v4 v4.1.4 github.com/go-viper/mapstructure/v2 v2.5.0 @@ -19,9 +18,11 @@ require ( github.com/google/go-cmp v0.7.0 github.com/jackc/pgx/v5 v5.9.2 github.com/jellydator/ttlcache/v3 v3.4.0 + github.com/knadh/koanf/providers/file v1.2.1 + github.com/knadh/koanf/v2 v2.3.4 github.com/moby/moby/api v1.54.2 github.com/oapi-codegen/runtime v1.4.0 - github.com/openkcm/api-sdk v0.17.0 + github.com/openkcm/api-sdk v0.17.1-0.20260526065520-4c441e4daecf github.com/openkcm/common-sdk v1.16.0 github.com/pressly/goose/v3 v3.27.1 github.com/samber/oops v1.21.0 @@ -35,6 +36,7 @@ require ( go.opentelemetry.io/otel/metric v1.43.0 go.opentelemetry.io/otel/trace v1.43.0 google.golang.org/grpc v1.81.1 + google.golang.org/protobuf v1.36.11 ) require ( @@ -44,9 +46,11 @@ require ( github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c // indirect github.com/Dynatrace/OneAgent-SDK-for-Go v1.1.0 // indirect github.com/Microsoft/go-winio v0.6.2 // indirect + github.com/XSAM/otelsql v0.42.0 // indirect github.com/antlr4-go/antlr/v4 v4.13.1 // indirect github.com/apapsch/go-jsonmerge/v2 v2.0.0 // indirect github.com/beorn7/perks v1.0.1 // indirect + github.com/bitfield/gotestdox v0.2.2 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect github.com/cenkalti/backoff/v5 v5.0.3 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect @@ -56,10 +60,10 @@ require ( github.com/containerd/log v0.1.0 // indirect github.com/containerd/platforms v0.2.1 // indirect github.com/cpuguy83/dockercfg v0.3.2 // indirect - github.com/creasty/defaults v1.8.0 // indirect github.com/cubicdaiya/gonp v1.0.4 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/distribution/reference v0.6.0 // indirect + github.com/dnephin/pflag v1.0.7 // indirect github.com/docker/go-connections v0.7.0 // indirect github.com/docker/go-units v0.5.0 // indirect github.com/dprotaso/go-yit v0.0.0-20220510233725-9ba8df137936 // indirect @@ -67,6 +71,7 @@ require ( github.com/ebitengine/purego v0.10.0 // indirect github.com/envoyproxy/go-control-plane/envoy v1.37.0 // indirect github.com/envoyproxy/protoc-gen-validate v1.3.3 // indirect + github.com/fatih/color v1.18.0 // indirect github.com/fatih/structtag v1.2.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/fsnotify/fsnotify v1.10.1 // indirect @@ -79,8 +84,9 @@ require ( github.com/go-sql-driver/mysql v1.9.3 // indirect github.com/golang-jwt/jwt/v5 v5.3.1 // indirect github.com/google/cel-go v0.26.1 // indirect + github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect github.com/google/uuid v1.6.0 // indirect - github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 // indirect + github.com/grpc-ecosystem/grpc-gateway/v2 v2.29.0 // indirect github.com/hashicorp/go-version v1.9.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect @@ -89,13 +95,17 @@ require ( github.com/jinzhu/inflection v1.0.0 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/json-iterator/go v1.1.12 // indirect - github.com/klauspost/compress v1.18.5 // indirect - github.com/lufia/plan9stats v0.0.0-20251013123823-9fd1530e3ec3 // indirect + github.com/klauspost/compress v1.18.6 // indirect + github.com/knadh/koanf/maps v0.1.2 // indirect + github.com/lufia/plan9stats v0.0.0-20260330125221-c963978e514e // indirect github.com/magiconair/properties v1.8.10 // indirect github.com/mailru/easyjson v0.7.7 // indirect + github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.21 // indirect github.com/mdelapenya/tlscert v0.2.0 // indirect github.com/mfridman/interpolate v0.0.2 // indirect + github.com/mitchellh/copystructure v1.2.0 // indirect + github.com/mitchellh/reflectwalk v1.0.2 // indirect github.com/moby/docker-image-spec v1.3.1 // indirect github.com/moby/go-archive v0.2.0 // indirect github.com/moby/moby/client v0.4.1 // indirect @@ -116,7 +126,7 @@ require ( github.com/oliveagle/jsonpath v0.1.4 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.1.1 // indirect - github.com/pelletier/go-toml/v2 v2.2.4 // indirect + github.com/pelletier/go-toml/v2 v2.3.1 // indirect github.com/perimeterx/marshmallow v1.1.5 // indirect github.com/pganalyze/pg_query_go/v6 v6.1.0 // indirect github.com/pingcap/errors v0.11.5-0.20240311024730-e056997136bb // indirect @@ -135,11 +145,11 @@ require ( github.com/riza-io/grpc-go v0.2.0 // indirect github.com/sagikazarmark/locafero v0.12.0 // indirect github.com/samber/lo v1.53.0 // indirect - github.com/samber/slog-common v0.21.0 // indirect + github.com/samber/slog-common v0.22.0 // indirect github.com/samber/slog-formatter v1.3.0 // indirect github.com/samber/slog-multi v1.8.0 // indirect github.com/sethvargo/go-retry v0.3.0 // indirect - github.com/shirou/gopsutil/v4 v4.26.3 // indirect + github.com/shirou/gopsutil/v4 v4.26.4 // indirect github.com/sirupsen/logrus v1.9.4 // indirect github.com/speakeasy-api/jsonpath v0.6.0 // indirect github.com/speakeasy-api/openapi-overlay v0.10.2 // indirect @@ -152,16 +162,16 @@ require ( github.com/subosito/gotenv v1.6.0 // indirect github.com/testcontainers/testcontainers-go v0.42.0 // indirect github.com/tetratelabs/wazero v1.9.0 // indirect - github.com/tklauser/go-sysconf v0.3.16 // indirect - github.com/tklauser/numcpus v0.11.0 // indirect + github.com/tklauser/go-sysconf v0.4.0 // indirect + github.com/tklauser/numcpus v0.12.0 // indirect github.com/veqryn/slog-context/otel v0.9.0 // indirect github.com/vmware-labs/yaml-jsonpath v0.3.2 // indirect github.com/wasilibs/go-pgquery v0.0.0-20250409022910-10ac41983c07 // indirect github.com/wasilibs/wazero-helpers v0.0.0-20240620070341-3dff1577cd52 // indirect github.com/yusufpapurcu/wmi v1.2.4 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect - go.opentelemetry.io/collector/featuregate v1.57.0 // indirect - go.opentelemetry.io/collector/pdata v1.57.0 // indirect + go.opentelemetry.io/collector/featuregate v1.58.0 // indirect + go.opentelemetry.io/collector/pdata v1.58.0 // indirect go.opentelemetry.io/contrib/bridges/otelslog v0.18.0 // indirect go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.68.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.68.0 // indirect @@ -184,21 +194,22 @@ require ( go.uber.org/zap v1.27.0 // indirect go.yaml.in/yaml/v2 v2.4.4 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect - golang.org/x/crypto v0.50.0 // indirect + golang.org/x/crypto v0.51.0 // indirect golang.org/x/exp v0.0.0-20260410095643-746e56fc9e2f // indirect golang.org/x/mod v0.35.0 // indirect - golang.org/x/net v0.53.0 // indirect + golang.org/x/net v0.54.0 // indirect golang.org/x/sync v0.20.0 // indirect - golang.org/x/sys v0.43.0 // indirect - golang.org/x/text v0.36.0 // indirect + golang.org/x/sys v0.44.0 // indirect + golang.org/x/term v0.43.0 // indirect + golang.org/x/text v0.37.0 // indirect golang.org/x/time v0.15.0 // indirect golang.org/x/tools v0.44.0 // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20260420184626-e10c466a9529 // indirect - google.golang.org/protobuf v1.36.11 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20260511170946-3700d4141b60 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20260511170946-3700d4141b60 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect + gotest.tools/gotestsum v1.13.0 // indirect modernc.org/libc v1.72.1 // indirect modernc.org/mathutil v1.7.1 // indirect modernc.org/memory v1.11.0 // indirect diff --git a/go.sum b/go.sum index 61eac509..1148fbbd 100644 --- a/go.sum +++ b/go.sum @@ -23,6 +23,8 @@ github.com/apapsch/go-jsonmerge/v2 v2.0.0/go.mod h1:lvDnEdqiQrp0O42VQGgmlKpxL1AP github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/bitfield/gotestdox v0.2.2 h1:x6RcPAbBbErKLnapz1QeAlf3ospg8efBsedU93CDsnE= +github.com/bitfield/gotestdox v0.2.2/go.mod h1:D+gwtS0urjBrzguAkTM2wodsTQYFHdpx8eqRJ3N+9pY= github.com/bmatcuk/doublestar v1.1.1/go.mod h1:UD6OnuiIn0yFxxA2le/rnRU1G4RaI4UvFv1sNto9p6w= github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= @@ -57,6 +59,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk= github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E= +github.com/dnephin/pflag v1.0.7 h1:oxONGlWxhmUct0YzKTgrpQv9AUA1wtPBn7zuSjJqptk= +github.com/dnephin/pflag v1.0.7/go.mod h1:uxE91IoWURlOiTUIA8Mq5ZZkAv3dPUfZNaT80Zm7OQE= github.com/docker/go-connections v0.7.0 h1:6SsRfJddP22WMrCkj19x9WKjEDTB+ahsdiGYf0mN39c= github.com/docker/go-connections v0.7.0/go.mod h1:no1qkHdjq7kLMGUXYAduOhYPSJxxvgWBh7ogVvptn3Q= github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4= @@ -74,6 +78,8 @@ github.com/envoyproxy/protoc-gen-validate v1.3.3 h1:MVQghNeW+LZcmXe7SY1V36Z+WFMD github.com/envoyproxy/protoc-gen-validate v1.3.3/go.mod h1:TsndJ/ngyIdQRhMcVVGDDHINPLWB7C82oDArY51KfB0= github.com/exaring/otelpgx v0.10.0 h1:NGGegdoBQM3jNZDKG8ENhigUcgBN7d7943L0YlcIpZc= github.com/exaring/otelpgx v0.10.0/go.mod h1:R5/M5LWsPPBZc1SrRE5e0DiU48bI78C1/GPTWs6I66U= +github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= +github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= github.com/fatih/structtag v1.2.0 h1:/OdNE99OxoI/PqaW/SuSK9uxxT3f/tcSZgon/ssNSx4= github.com/fatih/structtag v1.2.0/go.mod h1:mBJUNpUnHmRKrKlQQlmCrh5PuhftFbNv8Ys4/aAZl94= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= @@ -136,10 +142,12 @@ github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/ github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/pprof v0.0.0-20260402051712-545e8a4df936 h1:EwtI+Al+DeppwYX2oXJCETMO23COyaKGP6fHVpkpWpg= github.com/google/pprof v0.0.0-20260402051712-545e8a4df936/go.mod h1:MxpfABSjhmINe3F1It9d+8exIHFvUqtLIRCdOGNXqiI= +github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4= +github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 h1:HWRh5R2+9EifMyIHV7ZV+MIZqgz+PMpZ14Jynv3O2Zs= -github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0/go.mod h1:JfhWUomR1baixubs02l85lZYYOm7LV6om4ceouMv45c= +github.com/grpc-ecosystem/grpc-gateway/v2 v2.29.0 h1:5VipnvEpbqr2gA2VbM+nYVbkIF28c5ZQfqCBQ5g2xfk= +github.com/grpc-ecosystem/grpc-gateway/v2 v2.29.0/go.mod h1:Hyl3n6Twe1hvtd9XUXDec4pTvgMSEixRuQKPTMH2bNs= github.com/hashicorp/go-version v1.9.0 h1:CeOIz6k+LoN3qX9Z0tyQrPtiB1DFYRPfCIBtaXPSCnA= github.com/hashicorp/go-version v1.9.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= @@ -165,8 +173,14 @@ github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFF github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/juju/gnuflag v0.0.0-20171113085948-2ce1bb71843d/go.mod h1:2PavIy+JPciBPrBUjwbNvtwB6RQlve+hkpll6QSNmOE= -github.com/klauspost/compress v1.18.5 h1:/h1gH5Ce+VWNLSWqPzOVn6XBO+vJbCNGvjoaGBFW2IE= -github.com/klauspost/compress v1.18.5/go.mod h1:cwPg85FWrGar70rWktvGQj8/hthj3wpl0PGDogxkrSQ= +github.com/klauspost/compress v1.18.6 h1:2jupLlAwFm95+YDR+NwD2MEfFO9d4z4Prjl1XXDjuao= +github.com/klauspost/compress v1.18.6/go.mod h1:cwPg85FWrGar70rWktvGQj8/hthj3wpl0PGDogxkrSQ= +github.com/knadh/koanf/maps v0.1.2 h1:RBfmAW5CnZT+PJ1CVc1QSJKf4Xu9kxfQgYVQSu8hpbo= +github.com/knadh/koanf/maps v0.1.2/go.mod h1:npD/QZY3V6ghQDdcQzl1W4ICNVTkohC8E73eI2xW4yI= +github.com/knadh/koanf/providers/file v1.2.1 h1:bEWbtQwYrA+W2DtdBrQWyXqJaJSG3KrP3AESOJYp9wM= +github.com/knadh/koanf/providers/file v1.2.1/go.mod h1:bp1PM5f83Q+TOUu10J/0ApLBd9uIzg+n9UgthfY+nRA= +github.com/knadh/koanf/v2 v2.3.4 h1:fnynNSDlujWE+v83hAp8wKr/cdoxHLO0629SN+U8Urc= +github.com/knadh/koanf/v2 v2.3.4/go.mod h1:gRb40VRAbd4iJMYYD5IxZ6hfuopFcXBpc9bbQpZwo28= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= @@ -178,18 +192,24 @@ github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0 github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= -github.com/lufia/plan9stats v0.0.0-20251013123823-9fd1530e3ec3 h1:PwQumkgq4/acIiZhtifTV5OUqqiP82UAl0h87xj/l9k= -github.com/lufia/plan9stats v0.0.0-20251013123823-9fd1530e3ec3/go.mod h1:autxFIvghDt3jPTLoqZ9OZ7s9qTGNAWmYCjVFWPX/zg= +github.com/lufia/plan9stats v0.0.0-20260330125221-c963978e514e h1:Q6MvJtQK/iRcRtzAscm/zF23XxJlbECiGPyRicsX+Ak= +github.com/lufia/plan9stats v0.0.0-20260330125221-c963978e514e/go.mod h1:autxFIvghDt3jPTLoqZ9OZ7s9qTGNAWmYCjVFWPX/zg= github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE= github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= +github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= +github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= github.com/mattn/go-isatty v0.0.21 h1:xYae+lCNBP7QuW4PUnNG61ffM4hVIfm+zUzDuSzYLGs= github.com/mattn/go-isatty v0.0.21/go.mod h1:ZXfXG4SQHsB/w3ZeOYbR0PrPwLy+n6xiMrJlRFqopa4= github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI= github.com/mdelapenya/tlscert v0.2.0/go.mod h1:O4njj3ELLnJjGdkN7M/vIVCpZ+Cf0L6muqOG4tLSl8o= github.com/mfridman/interpolate v0.0.2 h1:pnuTK7MQIxxFz1Gr+rjSIx9u7qVjf5VOoM/u6BbAxPY= github.com/mfridman/interpolate v0.0.2/go.mod h1:p+7uk6oE07mpE/Ik1b8EckO0O4ZXiGAfshKBWLUM9Xg= +github.com/mitchellh/copystructure v1.2.0 h1:vpKXTN4ewci03Vljg/q9QvCGUDttBOGBIa15WveJJGw= +github.com/mitchellh/copystructure v1.2.0/go.mod h1:qLl+cE2AmVv+CoeAwDPye/v+N2HKCj9FbZEVFJRxO9s= +github.com/mitchellh/reflectwalk v1.0.2 h1:G2LzWKi524PWgd3mLHV8Y5k7s6XUvT0Gef6zxSIeXaQ= +github.com/mitchellh/reflectwalk v1.0.2/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw= github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0= github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo= github.com/moby/go-archive v0.2.0 h1:zg5QDUM2mi0JIM9fdQZWC7U8+2ZfixfTYoHL7rWUcP8= @@ -252,13 +272,13 @@ github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8 github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040= github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M= -github.com/openkcm/api-sdk v0.17.0 h1:AwDrJvaEj6f35JAohwzILU3nMTeD19J3ycebq0weOMw= -github.com/openkcm/api-sdk v0.17.0/go.mod h1:ffjao8Qr0k9FbtYWmPWWHf52tfBec2CRN4IBgj+ncqo= +github.com/openkcm/api-sdk v0.17.1-0.20260526065520-4c441e4daecf h1:2aGrIcRODxQQ/6/E9e61IZ9SV4YlpHwPmk1hPAIhd70= +github.com/openkcm/api-sdk v0.17.1-0.20260526065520-4c441e4daecf/go.mod h1:DeG8HQLN6QjzCpluI3B0xZCXqXEHv+0eSFg1+R5BQPo= github.com/openkcm/common-sdk v1.16.0 h1:pmLXRHvjqg+8ATEyzXarCRiRghw/8pXGn2OtoYuMEIU= github.com/openkcm/common-sdk v1.16.0/go.mod h1:4umveCyatAaTi6dSQgwaBg1O/wqHr4sjzuMIQhEuX1o= github.com/pborman/getopt v0.0.0-20170112200414-7148bc3a4c30/go.mod h1:85jBQOZwpVEaDAr341tbn15RS4fCAsIst0qp7i8ex1o= -github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= -github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= +github.com/pelletier/go-toml/v2 v2.3.1 h1:MYEvvGnQjeNkRF1qUuGolNtNExTDwct51yp7olPtrEc= +github.com/pelletier/go-toml/v2 v2.3.1/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= github.com/perimeterx/marshmallow v1.1.5 h1:a2LALqQ1BlHM8PZblsDdidgv1mWi1DgC2UmX50IvK2s= github.com/perimeterx/marshmallow v1.1.5/go.mod h1:dsXbUu8CRzfYP5a87xpp0xq9S3u0Vchtcl8we9tYaXw= github.com/pganalyze/pg_query_go/v6 v6.1.0 h1:jG5ZLhcVgL1FAw4C/0VNQaVmX1SUJx71wBGdtTtBvls= @@ -304,8 +324,8 @@ github.com/samber/lo v1.53.0 h1:t975lj2py4kJPQ6haz1QMgtId2gtmfktACxIXArw3HM= github.com/samber/lo v1.53.0/go.mod h1:4+MXEGsJzbKGaUEQFKBq2xtfuznW9oz/WrgyzMzRoM0= github.com/samber/oops v1.21.0 h1:18atcO4oEigNFuGXqr3NZWZ6P0XOSEXyBSAMXdQRxTc= github.com/samber/oops v1.21.0/go.mod h1:Hsm/sKPxtCfPh0w/cE3xVoRfSiE1joDRiStPAsmG9bo= -github.com/samber/slog-common v0.21.0 h1:Wo2hTly1Br5RjYqX/BTWJJeDnTE85oWk/7vqlpZuAUc= -github.com/samber/slog-common v0.21.0/go.mod h1:d/6OaSlzdkl9PFpfRLgn8FwY1OW6EFmPtBpsHX4MrU0= +github.com/samber/slog-common v0.22.0 h1:WyPxYRg/c5xUmxZJbtd0QgysHlLBhRA+MngKdJieHxE= +github.com/samber/slog-common v0.22.0/go.mod h1:d/6OaSlzdkl9PFpfRLgn8FwY1OW6EFmPtBpsHX4MrU0= github.com/samber/slog-formatter v1.3.0 h1:dpvLVSX883WSI222gUtEMLRd5ZOlKkmQlkrovdVZ9uA= github.com/samber/slog-formatter v1.3.0/go.mod h1:9y2j6qgrCpa7B5Kbv/sKp1ak7wJ91tsswp1BHOUSukc= github.com/samber/slog-multi v1.8.0 h1:E05c1wnQ+8M58oQDBABlJ4TEIJWssNgtckso3zlaLlI= @@ -314,8 +334,8 @@ github.com/sergi/go-diff v1.1.0 h1:we8PVUC3FE2uYfodKH/nBHMSetSfHDR6scGdBi+erh0= github.com/sergi/go-diff v1.1.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= github.com/sethvargo/go-retry v0.3.0 h1:EEt31A35QhrcRZtrYFDTBg91cqZVnFL2navjDrah2SE= github.com/sethvargo/go-retry v0.3.0/go.mod h1:mNX17F0C/HguQMyMyJxcnU471gOZGxCLyYaFyAZraas= -github.com/shirou/gopsutil/v4 v4.26.3 h1:2ESdQt90yU3oXF/CdOlRCJxrP+Am1aBYubTMTfxJ1qc= -github.com/shirou/gopsutil/v4 v4.26.3/go.mod h1:LZ6ewCSkBqUpvSOf+LsTGnRinC6iaNUNMGBtDkJBaLQ= +github.com/shirou/gopsutil/v4 v4.26.4 h1:B4SXVbcwTyrocPHEmWBC4uCYr4Xcu3MK1TXqbprAOWY= +github.com/shirou/gopsutil/v4 v4.26.4/go.mod h1:LZ6ewCSkBqUpvSOf+LsTGnRinC6iaNUNMGBtDkJBaLQ= github.com/sirupsen/logrus v1.9.4 h1:TsZE7l11zFCLZnZ+teH4Umoq5BhEIfIzfRDZ1Uzql2w= github.com/sirupsen/logrus v1.9.4/go.mod h1:ftWc9WdOfJ0a92nsE2jF5u5ZwH8Bv2zdeOC42RjbV2g= github.com/speakeasy-api/jsonpath v0.6.0 h1:IhtFOV9EbXplhyRqsVhHoBmmYjblIRh5D1/g8DHMXJ8= @@ -357,10 +377,10 @@ github.com/testcontainers/testcontainers-go/modules/valkey v0.42.0 h1:SL15Mh/Jmo github.com/testcontainers/testcontainers-go/modules/valkey v0.42.0/go.mod h1:rKMKPmE5065l6Jk/HWu4D27cDysitXO8MQoWfwKMPg8= github.com/tetratelabs/wazero v1.9.0 h1:IcZ56OuxrtaEz8UYNRHBrUa9bYeX9oVY93KspZZBf/I= github.com/tetratelabs/wazero v1.9.0/go.mod h1:TSbcXCfFP0L2FGkRPxHphadXPjo1T6W+CseNNY7EkjM= -github.com/tklauser/go-sysconf v0.3.16 h1:frioLaCQSsF5Cy1jgRBrzr6t502KIIwQ0MArYICU0nA= -github.com/tklauser/go-sysconf v0.3.16/go.mod h1:/qNL9xxDhc7tx3HSRsLWNnuzbVfh3e7gh/BmM179nYI= -github.com/tklauser/numcpus v0.11.0 h1:nSTwhKH5e1dMNsCdVBukSZrURJRoHbSEQjdEbY+9RXw= -github.com/tklauser/numcpus v0.11.0/go.mod h1:z+LwcLq54uWZTX0u/bGobaV34u6V7KNlTZejzM6/3MQ= +github.com/tklauser/go-sysconf v0.4.0 h1:7H0uAN+7RkwWRaxhYXDLqa5V3LPrJeV8wmD9dRUgPQU= +github.com/tklauser/go-sysconf v0.4.0/go.mod h1:8mTNWyog7H+MpKijp4VmKJAd2bbYQ2zuUwkYRbUArPI= +github.com/tklauser/numcpus v0.12.0 h1:NR85qdvHA9pFse3x3weVZ0r0ST8R6l5RHbZrlRaqob4= +github.com/tklauser/numcpus v0.12.0/go.mod h1:ABHeXzJnr/qqwguhClkZKT1/8VABcYrsyUiUGobwWJg= github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= github.com/valkey-io/valkey-go v1.0.75 h1:cfq9DODW2ntuUgyHJmFWb4/p+xpLpQB1t5SQyWM9uJ4= @@ -380,12 +400,12 @@ github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= -go.opentelemetry.io/collector/featuregate v1.57.0 h1:KPDSUKYn6MHwgyGRSGPPcW/G96HH93pxuvvPwM+R8nY= -go.opentelemetry.io/collector/featuregate v1.57.0/go.mod h1:4ga1QBMPEejXXmpyJS8lmaRpknJ3Lb9Bvk6e420bUFU= -go.opentelemetry.io/collector/internal/testutil v0.151.0 h1:CFjDItLuqzblItOsnK6IPSdrsOaZCaDjYpB8qWG+XHI= -go.opentelemetry.io/collector/internal/testutil v0.151.0/go.mod h1:Jkjs6rkqs973LqgZ0Fe3zrokQRKULYXPIf4HuqStiEE= -go.opentelemetry.io/collector/pdata v1.57.0 h1:oDWBMjEIqyJO3GJEB+iwqxj47rxDK19OKzwaFEaE4sg= -go.opentelemetry.io/collector/pdata v1.57.0/go.mod h1:wZojinP6mNhLXudH8QXx/bjWzOsKMxi/FXwnk+12G/w= +go.opentelemetry.io/collector/featuregate v1.58.0 h1:Kh6Dpgbxywv/Q3D6qPehaSxNCxvr/U/ki7CL4y3udCo= +go.opentelemetry.io/collector/featuregate v1.58.0/go.mod h1:4ga1QBMPEejXXmpyJS8lmaRpknJ3Lb9Bvk6e420bUFU= +go.opentelemetry.io/collector/internal/testutil v0.152.0 h1:8LGwekR7mLcUDhT1ofLmdnrHRFuUa3U7PBd95ZvJEjQ= +go.opentelemetry.io/collector/internal/testutil v0.152.0/go.mod h1:Jkjs6rkqs973LqgZ0Fe3zrokQRKULYXPIf4HuqStiEE= +go.opentelemetry.io/collector/pdata v1.58.0 h1:5Lxut3NxKp87066Pzt+3q7+JUuFI5B3teCyLZIF8wIs= +go.opentelemetry.io/collector/pdata v1.58.0/go.mod h1:4vZtODINbC/JF3eGocnatdImzbRHseOywIcr+aULjCg= go.opentelemetry.io/contrib/bridges/otelslog v0.18.0 h1:hhPGP3zvvy1xWT9RTy970wlniSxFttBIsAK1gvMguJM= go.opentelemetry.io/contrib/bridges/otelslog v0.18.0/go.mod h1:twJF7inoMza6kxMcF8JOdL3mPmtOZu7GEr34CUNE6Dg= go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.68.0 h1:0Qx7VGBacMm9ZENQ7TnNObTYI4ShC+lHI16seduaxZo= @@ -456,8 +476,8 @@ go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI= -golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q= +golang.org/x/crypto v0.51.0 h1:IBPXwPfKxY7cWQZ38ZCIRPI50YLeevDLlLnyC5wRGTI= +golang.org/x/crypto v0.51.0/go.mod h1:8AdwkbraGNABw2kOX6YFPs3WM22XqI4EXEd8g+x7Oc8= golang.org/x/exp v0.0.0-20260410095643-746e56fc9e2f h1:W3F4c+6OLc6H2lb//N1q4WpJkhzJCK5J6kUi1NTVXfM= golang.org/x/exp v0.0.0-20260410095643-746e56fc9e2f/go.mod h1:J1xhfL/vlindoeF/aINzNzt2Bket5bjo9sdOYzOsU80= golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= @@ -472,8 +492,8 @@ golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/ golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= -golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA= -golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs= +golang.org/x/net v0.54.0 h1:2zJIZAxAHV/OHCDTCOHAYehQzLfSXuf/5SoL/Dv6w/w= +golang.org/x/net v0.54.0/go.mod h1:Sj4oj8jK6XmHpBZU/zWHw3BV3abl4Kvi+Ut7cQcY+cQ= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -497,18 +517,18 @@ golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI= -golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/sys v0.44.0 h1:ildZl3J4uzeKP07r2F++Op7E9B29JRUy+a27EibtBTQ= +golang.org/x/sys v0.44.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.42.0 h1:UiKe+zDFmJobeJ5ggPwOshJIVt6/Ft0rcfrXZDLWAWY= -golang.org/x/term v0.42.0/go.mod h1:Dq/D+snpsbazcBG5+F9Q1n2rXV8Ma+71xEjTRufARgY= +golang.org/x/term v0.43.0 h1:S4RLU2sB31O/NCl+zFN9Aru9A/Cq2aqKpTZJ6B+DwT4= +golang.org/x/term v0.43.0/go.mod h1:lrhlHNdQJHO+1qVYiHfFKVuVioJIheAc3fBSMFYEIsk= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg= -golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164= +golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc= +golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38= golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U= golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -525,10 +545,10 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4= gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E= -google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9 h1:VPWxll4HlMw1Vs/qXtN7BvhZqsS9cdAittCNvVENElA= -google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:7QBABkRtR8z+TEnmXTqIqwJLlzrZKVfAUm7tY3yGv0M= -google.golang.org/genproto/googleapis/rpc v0.0.0-20260420184626-e10c466a9529 h1:XF8+t6QQiS0o9ArVan/HW8Q7cycNPGsJf6GA2nXxYAg= -google.golang.org/genproto/googleapis/rpc v0.0.0-20260420184626-e10c466a9529/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8= +google.golang.org/genproto/googleapis/api v0.0.0-20260511170946-3700d4141b60 h1:3WsB1FAbiRIf2tOxscWKs3pQBD9he1NsrnbhMuWfekc= +google.golang.org/genproto/googleapis/api v0.0.0-20260511170946-3700d4141b60/go.mod h1:7yoXV7RIh5gblj/xVYoogxAWvA9wUeVbpsK/M694l00= +google.golang.org/genproto/googleapis/rpc v0.0.0-20260511170946-3700d4141b60 h1:seT2EwLWM78plQ7wcDfuWBc/4FAEAXDDiaSol4ku4qo= +google.golang.org/genproto/googleapis/rpc v0.0.0-20260511170946-3700d4141b60/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8= google.golang.org/grpc v1.81.1 h1:VnnIIZ88UzOOKLukQi+ImGz8O1Wdp8nAGGnvOfEIWQQ= google.golang.org/grpc v1.81.1/go.mod h1:xGH9GfzOyMTGIOXBJmXt+BX/V0kcdQbdcuwQ/zNw42I= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= @@ -565,6 +585,8 @@ gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gotest.tools/gotestsum v1.13.0 h1:+Lh454O9mu9AMG1APV4o0y7oDYKyik/3kBOiCqiEpRo= +gotest.tools/gotestsum v1.13.0/go.mod h1:7f0NS5hFb0dWr4NtcsAsF0y1kzjEFfAil0HiBQJE03Q= gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q= gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA= modernc.org/cc/v4 v4.28.1 h1:XpLbkYVQ24E8tX5u8+yWGvaxerxkR/S4zqxI8ZoSBuc= diff --git a/integration/api-server-status_test.go b/integration/api-server-status_test.go index b6389395..c3398735 100644 --- a/integration/api-server-status_test.go +++ b/integration/api-server-status_test.go @@ -55,8 +55,8 @@ func TestStatusServer(t *testing.T) { } // defer the graceful stop of the service so that coverprofiles are written defer func() { - cmd.Process.Signal(os.Interrupt) - cmd.Wait() + _ = cmd.Process.Signal(os.Interrupt) + _ = cmd.Wait() }() // create the test cases @@ -85,7 +85,12 @@ func TestStatusServer(t *testing.T) { if i < 1 { t.Fatalf("could not connect to server: %s", err) } - if _, err := http.Get("http://localhost:8888"); err == nil { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8888", nil) + if err != nil { + t.Fatal(err) + } + + if _, err := http.DefaultClient.Do(req); err == nil { break } time.Sleep(100 * time.Millisecond) @@ -100,7 +105,12 @@ func TestStatusServer(t *testing.T) { t.Fatalf("could not construct a request url: %s", err) } - resp, err := http.Get(u) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil) + if err != nil { + t.Fatal(err) + } + + resp, err := http.DefaultClient.Do(req) if err != nil { t.Fatalf("could not send request: %s", err) } @@ -112,21 +122,14 @@ func TestStatusServer(t *testing.T) { // Assert if tc.wantError { - if err == nil { - t.Error("expected error, but got nil") - } if got != nil { t.Errorf("expected nil response, but got: %+v", got) } } else { - if err != nil { - t.Errorf("unexpected error: %s", err) - } else { - t.Logf("response: %s", got) - var js json.RawMessage - if json.Unmarshal([]byte(got), &js) != nil { - t.Errorf("response is not valid json: %s", got) - } + t.Logf("response: %s", got) + var js json.RawMessage + if json.Unmarshal(got, &js) != nil { + t.Errorf("response is not valid json: %s", got) } } }) diff --git a/integration/grpc_test.go b/integration/grpc_test.go index d80d275a..d40e7550 100644 --- a/integration/grpc_test.go +++ b/integration/grpc_test.go @@ -4,227 +4,297 @@ package integration_test import ( "context" + "errors" "fmt" - "net" + "os" + "os/exec" + "path/filepath" "testing" "time" "github.com/gofrs/uuid/v5" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" - oidcmappingv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/sessionmanager/oidcmapping/v1" - sessionv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/sessionmanager/session/v1" - slogctx "github.com/veqryn/slog-context" - stdgrpc "google.golang.org/grpc" - - "github.com/openkcm/session-manager/internal/dbtest/postgrestest" - "github.com/openkcm/session-manager/internal/grpc" - "github.com/openkcm/session-manager/internal/trust" - "github.com/openkcm/session-manager/internal/trust/trustsql" + trustmappingv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/sessionmanager/trustmapping/v1" + healthpb "google.golang.org/grpc/health/grpc_health_v1" ) func TestGRPCServer(t *testing.T) { + const cmdName = "api-server" + const port = 9091 + // given ctx := t.Context() - port := 9091 - // create grpc server - srv, _, terminateFn, err := startServer(t, port) - require.NoError(t, err) - defer srv.Stop() - defer terminateFn(ctx) + istat := initInfra(t) + defer istat.Close(ctx) + + istat.Cfg.GRPC.Address = fmt.Sprintf(":%d", port) + + istat.PreparePostgres(t) + istat.PrepareValKey(t) + istat.PrepareConfig(t) + + currdir, err := os.Getwd() + require.NoError(t, err, "failed to get wd") + + t.Chdir(istat.Procdir) + + commandCtx, cancelCommand := context.WithCancel(ctx) + defer cancelCommand() + + cmd := exec.CommandContext(commandCtx, filepath.Join(currdir, "./session-manager"), cmdName) + cmd.WaitDelay = 5 * time.Second + cmd.Cancel = func() error { return cmd.Process.Signal(os.Interrupt) } + + cmdOutPath := filepath.Join(currdir, "grpc.log") + cmdOut, err := os.Create(cmdOutPath) + if err != nil { + t.Fatalf("failed to create an log file") + } + defer cmdOut.Close() + + cmd.Stdout = cmdOut + cmd.Stderr = cmdOut + + t.Logf("starting an app process. Logs will be saved into %s", cmdOutPath) + if err := cmd.Start(); err != nil { + t.Fatalf("failed to start the server: %s", err) + } + + errCh := make(chan error) + go func() { + if err := cmd.Wait(); err != nil && !errors.Is(err, context.Canceled) { + errCh <- fmt.Errorf("executing command: %w", err) + } + close(errCh) + }() // grpc client connection - conn, err := createClientConn(t, port) + cc, err := createClientConn(t, port) require.NoError(t, err) - defer conn.Close() + defer cc.Close() + + waitCtx, cancelWait := context.WithTimeout(commandCtx, 10*time.Second) + defer cancelWait() + if err := waitGRPCServerReady(waitCtx, cc); err != nil { + t.Fatalf("waiting for the server readiness: %s", err) + } - mappingClient := oidcmappingv1.NewServiceClient(conn) + trust := trustmappingv1.NewServiceClient(cc) - t.Run("ApplyOIDCMapping", func(t *testing.T) { + t.Run("ApplyTrustMapping", func(t *testing.T) { expJwks := "jks" expTenantID := uuid.Must(uuid.NewV4()).String() expIssuer := uuid.Must(uuid.NewV4()).String() - applyResp, err := mappingClient.ApplyOIDCMapping(ctx, &oidcmappingv1.ApplyOIDCMappingRequest{ - TenantId: expTenantID, - Issuer: expIssuer, - JwksUri: &expJwks, - Audiences: []string{"aud"}, - }) + applyResp, err := trust.ApplyTrustMapping(ctx, trustmappingv1.ApplyTrustMappingRequest_builder{ + TenantId: &expTenantID, + Oidc: trustmappingv1.ApplyTrustMappingRequest_ApplyOIDCTrust_builder{ + Issuer: &expIssuer, + JwksUri: &expJwks, + Audiences: []string{"aud"}, + }.Build(), + }.Build()) assert.NoError(t, err) assert.True(t, applyResp.GetSuccess()) }) - t.Run("BlockOIDCMapping", func(t *testing.T) { + t.Run("BlockTrustMapping", func(t *testing.T) { expJwks := "jks" expTenantID := uuid.Must(uuid.NewV4()).String() expIssuer := uuid.Must(uuid.NewV4()).String() - applyResp, err := mappingClient.ApplyOIDCMapping(ctx, &oidcmappingv1.ApplyOIDCMappingRequest{ - TenantId: expTenantID, - Issuer: expIssuer, - JwksUri: &expJwks, - Audiences: []string{"aud"}, - }) + applyResp, err := trust.ApplyTrustMapping(ctx, trustmappingv1.ApplyTrustMappingRequest_builder{ + TenantId: &expTenantID, + Oidc: trustmappingv1.ApplyTrustMappingRequest_ApplyOIDCTrust_builder{ + Issuer: &expIssuer, + JwksUri: &expJwks, + Audiences: []string{"aud"}, + }.Build(), + }.Build()) assert.NoError(t, err) assert.True(t, applyResp.GetSuccess()) - blockResp, err := mappingClient.BlockOIDCMapping(ctx, &oidcmappingv1.BlockOIDCMappingRequest{ - TenantId: expTenantID, - }) + blockResp, err := trust.BlockTrustMapping(ctx, trustmappingv1.BlockTrustMappingRequest_builder{ + TenantId: &expTenantID, + }.Build()) assert.NoError(t, err) assert.True(t, blockResp.GetSuccess()) }) - t.Run("UnblockOIDCMapping", func(t *testing.T) { + t.Run("UnblockTrustMapping", func(t *testing.T) { expJwks := "jks" expTenantID := uuid.Must(uuid.NewV4()).String() expIssuer1 := uuid.Must(uuid.NewV4()).String() - applyRes, err := mappingClient.ApplyOIDCMapping(ctx, &oidcmappingv1.ApplyOIDCMappingRequest{ - TenantId: expTenantID, - Issuer: expIssuer1, - JwksUri: &expJwks, - Audiences: []string{"audience"}, - }) + applyRes, err := trust.ApplyTrustMapping(ctx, trustmappingv1.ApplyTrustMappingRequest_builder{ + TenantId: &expTenantID, + Oidc: trustmappingv1.ApplyTrustMappingRequest_ApplyOIDCTrust_builder{ + Issuer: &expIssuer1, + JwksUri: &expJwks, + Audiences: []string{"audience"}, + }.Build(), + }.Build()) assert.NoError(t, err) assert.True(t, applyRes.GetSuccess()) - blockRes, err := mappingClient.BlockOIDCMapping(ctx, &oidcmappingv1.BlockOIDCMappingRequest{ - TenantId: expTenantID, - }) + blockRes, err := trust.BlockTrustMapping(ctx, trustmappingv1.BlockTrustMappingRequest_builder{ + TenantId: &expTenantID, + }.Build()) assert.NoError(t, err) assert.True(t, blockRes.GetSuccess()) - unblockRes, err := mappingClient.UnblockOIDCMapping(ctx, &oidcmappingv1.UnblockOIDCMappingRequest{ - TenantId: expTenantID, - }) + unblockRes, err := trust.UnblockTrustMapping(ctx, trustmappingv1.UnblockTrustMappingRequest_builder{ + TenantId: &expTenantID, + }.Build()) assert.NoError(t, err) assert.True(t, unblockRes.GetSuccess()) }) - t.Run("RemoveOIDCMapping", func(t *testing.T) { + t.Run("RemoveTrustMapping", func(t *testing.T) { expJwks := "jks" expTenantID := uuid.Must(uuid.NewV4()).String() expIssuer := uuid.Must(uuid.NewV4()).String() - applyRes, err := mappingClient.ApplyOIDCMapping(ctx, &oidcmappingv1.ApplyOIDCMappingRequest{ - TenantId: expTenantID, - Issuer: expIssuer, - JwksUri: &expJwks, - Audiences: []string{"audience"}, - }) + applyRes, err := trust.ApplyTrustMapping(ctx, trustmappingv1.ApplyTrustMappingRequest_builder{ + TenantId: &expTenantID, + Oidc: trustmappingv1.ApplyTrustMappingRequest_ApplyOIDCTrust_builder{ + Issuer: &expIssuer, + JwksUri: &expJwks, + Audiences: []string{"audience"}, + }.Build(), + }.Build()) assert.NoError(t, err) assert.True(t, applyRes.GetSuccess()) - removeRes, err := mappingClient.RemoveOIDCMapping(ctx, &oidcmappingv1.RemoveOIDCMappingRequest{ - TenantId: expTenantID, - }) + removeRes, err := trust.RemoveTrustMapping(ctx, trustmappingv1.RemoveTrustMappingRequest_builder{ + TenantId: &expTenantID, + }.Build()) assert.NoError(t, err) assert.True(t, removeRes.GetSuccess()) }) - t.Run("ApplyOIDCMapping with multiple audiences", func(t *testing.T) { + t.Run("ApplyTrustMapping with multiple audiences", func(t *testing.T) { expJwks := "jks-multi" expTenantID := uuid.Must(uuid.NewV4()).String() expIssuer := uuid.Must(uuid.NewV4()).String() audiences := []string{"aud1", "aud2", "aud3"} - applyResp, err := mappingClient.ApplyOIDCMapping(ctx, &oidcmappingv1.ApplyOIDCMappingRequest{ - TenantId: expTenantID, - Issuer: expIssuer, - JwksUri: &expJwks, - Audiences: audiences, - }) + applyResp, err := trust.ApplyTrustMapping(ctx, trustmappingv1.ApplyTrustMappingRequest_builder{ + TenantId: &expTenantID, + Oidc: trustmappingv1.ApplyTrustMappingRequest_ApplyOIDCTrust_builder{ + Issuer: &expIssuer, + JwksUri: &expJwks, + Audiences: audiences, + }.Build(), + }.Build()) assert.NoError(t, err) assert.True(t, applyResp.GetSuccess()) }) - t.Run("ApplyOIDCMapping idempotent - applying same mapping twice", func(t *testing.T) { + t.Run("ApplyTrustMapping idempotent - applying same trust twice", func(t *testing.T) { expJwks := "jks-idempotent" expTenantID := uuid.Must(uuid.NewV4()).String() expIssuer := uuid.Must(uuid.NewV4()).String() // First application - applyResp1, err := mappingClient.ApplyOIDCMapping(ctx, &oidcmappingv1.ApplyOIDCMappingRequest{ - TenantId: expTenantID, - Issuer: expIssuer, - JwksUri: &expJwks, - Audiences: []string{"aud"}, - }) + applyResp1, err := trust.ApplyTrustMapping(ctx, trustmappingv1.ApplyTrustMappingRequest_builder{ + TenantId: &expTenantID, + Oidc: trustmappingv1.ApplyTrustMappingRequest_ApplyOIDCTrust_builder{ + Issuer: &expIssuer, + JwksUri: &expJwks, + Audiences: []string{"aud"}, + }.Build(), + }.Build()) + assert.NoError(t, err) assert.True(t, applyResp1.GetSuccess()) // Second application (should be idempotent) - applyResp2, err := mappingClient.ApplyOIDCMapping(ctx, &oidcmappingv1.ApplyOIDCMappingRequest{ - TenantId: expTenantID, - Issuer: expIssuer, - JwksUri: &expJwks, - Audiences: []string{"aud"}, - }) + applyResp2, err := trust.ApplyTrustMapping(ctx, trustmappingv1.ApplyTrustMappingRequest_builder{ + TenantId: &expTenantID, + Oidc: trustmappingv1.ApplyTrustMappingRequest_ApplyOIDCTrust_builder{ + Issuer: &expIssuer, + JwksUri: &expJwks, + Audiences: []string{"aud"}, + }.Build(), + }.Build()) + assert.NoError(t, err) assert.True(t, applyResp2.GetSuccess()) }) - t.Run("BlockOIDCMapping idempotent - blocking twice", func(t *testing.T) { + t.Run("BlockTrustMapping idempotent - blocking twice", func(t *testing.T) { expJwks := "jks-block-twice" expTenantID := uuid.Must(uuid.NewV4()).String() expIssuer := uuid.Must(uuid.NewV4()).String() - applyResp, err := mappingClient.ApplyOIDCMapping(ctx, &oidcmappingv1.ApplyOIDCMappingRequest{ - TenantId: expTenantID, - Issuer: expIssuer, - JwksUri: &expJwks, - Audiences: []string{"aud"}, - }) + applyResp, err := trust.ApplyTrustMapping(ctx, trustmappingv1.ApplyTrustMappingRequest_builder{ + TenantId: &expTenantID, + Oidc: trustmappingv1.ApplyTrustMappingRequest_ApplyOIDCTrust_builder{ + Issuer: &expIssuer, + JwksUri: &expJwks, + Audiences: []string{"aud"}, + }.Build(), + }.Build()) + assert.NoError(t, err) assert.True(t, applyResp.GetSuccess()) // First block - blockResp1, err := mappingClient.BlockOIDCMapping(ctx, &oidcmappingv1.BlockOIDCMappingRequest{ - TenantId: expTenantID, - }) + blockResp1, err := trust.BlockTrustMapping(ctx, trustmappingv1.BlockTrustMappingRequest_builder{ + TenantId: &expTenantID, + }.Build()) + assert.NoError(t, err) assert.True(t, blockResp1.GetSuccess()) // Second block (should be idempotent) - blockResp2, err := mappingClient.BlockOIDCMapping(ctx, &oidcmappingv1.BlockOIDCMappingRequest{ - TenantId: expTenantID, - }) + blockResp2, err := trust.BlockTrustMapping(ctx, trustmappingv1.BlockTrustMappingRequest_builder{ + TenantId: &expTenantID, + }.Build()) + assert.NoError(t, err) assert.True(t, blockResp2.GetSuccess()) }) - t.Run("UnblockOIDCMapping idempotent - unblocking twice", func(t *testing.T) { + t.Run("UnblockTrustMapping idempotent - unblocking twice", func(t *testing.T) { expJwks := "jks-unblock-twice" expTenantID := uuid.Must(uuid.NewV4()).String() expIssuer := uuid.Must(uuid.NewV4()).String() - applyRes, err := mappingClient.ApplyOIDCMapping(ctx, &oidcmappingv1.ApplyOIDCMappingRequest{ - TenantId: expTenantID, - Issuer: expIssuer, - JwksUri: &expJwks, - Audiences: []string{"audience"}, - }) + applyRes, err := trust.ApplyTrustMapping(ctx, trustmappingv1.ApplyTrustMappingRequest_builder{ + TenantId: &expTenantID, + Oidc: trustmappingv1.ApplyTrustMappingRequest_ApplyOIDCTrust_builder{ + Issuer: &expIssuer, + JwksUri: &expJwks, + Audiences: []string{"audience"}, + }.Build(), + }.Build()) + assert.NoError(t, err) assert.True(t, applyRes.GetSuccess()) - blockRes, err := mappingClient.BlockOIDCMapping(ctx, &oidcmappingv1.BlockOIDCMappingRequest{ - TenantId: expTenantID, - }) + blockRes, err := trust.BlockTrustMapping(ctx, trustmappingv1.BlockTrustMappingRequest_builder{ + TenantId: &expTenantID, + }.Build()) + assert.NoError(t, err) assert.True(t, blockRes.GetSuccess()) // First unblock - unblockRes1, err := mappingClient.UnblockOIDCMapping(ctx, &oidcmappingv1.UnblockOIDCMappingRequest{ - TenantId: expTenantID, - }) + unblockRes1, err := trust.UnblockTrustMapping(ctx, trustmappingv1.UnblockTrustMappingRequest_builder{ + TenantId: &expTenantID, + }.Build()) + assert.NoError(t, err) assert.True(t, unblockRes1.GetSuccess()) // Second unblock (should be idempotent) - unblockRes2, err := mappingClient.UnblockOIDCMapping(ctx, &oidcmappingv1.UnblockOIDCMappingRequest{ - TenantId: expTenantID, - }) + unblockRes2, err := trust.UnblockTrustMapping(ctx, trustmappingv1.UnblockTrustMappingRequest_builder{ + TenantId: &expTenantID, + }.Build()) + assert.NoError(t, err) assert.True(t, unblockRes2.GetSuccess()) }) @@ -234,108 +304,119 @@ func TestGRPCServer(t *testing.T) { expTenantID := uuid.Must(uuid.NewV4()).String() expIssuer := uuid.Must(uuid.NewV4()).String() - // Apply mapping - applyRes, err := mappingClient.ApplyOIDCMapping(ctx, &oidcmappingv1.ApplyOIDCMappingRequest{ - TenantId: expTenantID, - Issuer: expIssuer, - JwksUri: &expJwks, - Audiences: []string{"audience"}, - }) + // Apply trust + applyRes, err := trust.ApplyTrustMapping(ctx, trustmappingv1.ApplyTrustMappingRequest_builder{ + TenantId: &expTenantID, + Oidc: trustmappingv1.ApplyTrustMappingRequest_ApplyOIDCTrust_builder{ + Issuer: &expIssuer, + JwksUri: &expJwks, + Audiences: []string{"audience"}, + }.Build(), + }.Build()) + assert.NoError(t, err) assert.True(t, applyRes.GetSuccess()) // Block it - blockRes, err := mappingClient.BlockOIDCMapping(ctx, &oidcmappingv1.BlockOIDCMappingRequest{ - TenantId: expTenantID, - }) + blockRes, err := trust.BlockTrustMapping(ctx, trustmappingv1.BlockTrustMappingRequest_builder{ + TenantId: &expTenantID, + }.Build()) + assert.NoError(t, err) assert.True(t, blockRes.GetSuccess()) // Unblock it - unblockRes, err := mappingClient.UnblockOIDCMapping(ctx, &oidcmappingv1.UnblockOIDCMappingRequest{ - TenantId: expTenantID, - }) + unblockRes, err := trust.UnblockTrustMapping(ctx, trustmappingv1.UnblockTrustMappingRequest_builder{ + TenantId: &expTenantID, + }.Build()) + assert.NoError(t, err) assert.True(t, unblockRes.GetSuccess()) // Block again - blockRes2, err := mappingClient.BlockOIDCMapping(ctx, &oidcmappingv1.BlockOIDCMappingRequest{ - TenantId: expTenantID, - }) + blockRes2, err := trust.BlockTrustMapping(ctx, trustmappingv1.BlockTrustMappingRequest_builder{ + TenantId: &expTenantID, + }.Build()) + assert.NoError(t, err) assert.True(t, blockRes2.GetSuccess()) // Unblock again - unblockRes2, err := mappingClient.UnblockOIDCMapping(ctx, &oidcmappingv1.UnblockOIDCMappingRequest{ - TenantId: expTenantID, - }) + unblockRes2, err := trust.UnblockTrustMapping(ctx, trustmappingv1.UnblockTrustMappingRequest_builder{ + TenantId: &expTenantID, + }.Build()) + assert.NoError(t, err) assert.True(t, unblockRes2.GetSuccess()) }) - t.Run("RemoveOIDCMapping idempotent - removing twice", func(t *testing.T) { + t.Run("RemoveTrustMapping idempotent - removing twice", func(t *testing.T) { expJwks := "jks-remove-twice" expTenantID := uuid.Must(uuid.NewV4()).String() expIssuer := uuid.Must(uuid.NewV4()).String() - applyRes, err := mappingClient.ApplyOIDCMapping(ctx, &oidcmappingv1.ApplyOIDCMappingRequest{ - TenantId: expTenantID, - Issuer: expIssuer, - JwksUri: &expJwks, - Audiences: []string{"audience"}, - }) + applyRes, err := trust.ApplyTrustMapping(ctx, trustmappingv1.ApplyTrustMappingRequest_builder{ + TenantId: &expTenantID, + Oidc: trustmappingv1.ApplyTrustMappingRequest_ApplyOIDCTrust_builder{ + Issuer: &expIssuer, + JwksUri: &expJwks, + Audiences: []string{"audience"}, + }.Build(), + }.Build()) + assert.NoError(t, err) assert.True(t, applyRes.GetSuccess()) // First remove - removeRes1, err := mappingClient.RemoveOIDCMapping(ctx, &oidcmappingv1.RemoveOIDCMappingRequest{ - TenantId: expTenantID, - }) + removeRes1, err := trust.RemoveTrustMapping(ctx, trustmappingv1.RemoveTrustMappingRequest_builder{ + TenantId: &expTenantID, + }.Build()) + assert.NoError(t, err) assert.True(t, removeRes1.GetSuccess()) // Second remove - idempotence should not cause an error - removeRes2, err := mappingClient.RemoveOIDCMapping(ctx, &oidcmappingv1.RemoveOIDCMappingRequest{ - TenantId: expTenantID, - }) + removeRes2, err := trust.RemoveTrustMapping(ctx, trustmappingv1.RemoveTrustMappingRequest_builder{ + TenantId: &expTenantID, + }.Build()) + assert.NoError(t, err) assert.True(t, removeRes2.GetSuccess()) }) + + cancelCommand() + select { + case err := <-errCh: + if err != nil { + t.Fatalf("error executing command: %s", err) + } + case <-time.After(10 * time.Second): + t.Fatalf("timeout exceeded") + } } -func createClientConn(t *testing.T, port int) (*stdgrpc.ClientConn, error) { +func createClientConn(t *testing.T, port int) (*grpc.ClientConn, error) { t.Helper() - conn, err := stdgrpc.NewClient(fmt.Sprintf("localhost:%d", port), - stdgrpc.WithTransportCredentials(insecure.NewCredentials()), + conn, err := grpc.NewClient(fmt.Sprintf("localhost:%d", port), + grpc.WithTransportCredentials(insecure.NewCredentials()), ) return conn, err } -func startServer(t *testing.T, port int) (*stdgrpc.Server, *trust.Service, func(context.Context), error) { - t.Helper() - ctx := t.Context() - // start postgres - db, _, terminateFn := postgrestest.Start(ctx) - trustRepo := trustsql.NewRepository(db) - service := trust.NewService(trustRepo) - - lstConf := net.ListenConfig{} - lis, err := lstConf.Listen(ctx, "tcp", fmt.Sprintf("localhost:%d", port)) - if err != nil { - return nil, nil, nil, err - } +func waitGRPCServerReady(ctx context.Context, cc *grpc.ClientConn) error { + healthClient := healthpb.NewHealthClient(cc) - srv := stdgrpc.NewServer() - oidcmappingv1.RegisterServiceServer(srv, grpc.NewOIDCMappingServer(service)) - sessionv1.RegisterServiceServer(srv, grpc.NewSessionServer(ctx, nil, trustRepo, time.Hour, "")) - - // start - go func() { - err = srv.Serve(lis) + const maxAttempts = 100 + for range maxAttempts { + out, err := healthClient.Check(ctx, new(healthpb.HealthCheckRequest), grpc.WaitForReady(true)) if err != nil { - slogctx.Error(ctx, "error while starting server", "error", err) + return fmt.Errorf("checking health status: %w", err) } - }() - return srv, service, terminateFn, nil + if out.GetStatus() == healthpb.HealthCheckResponse_SERVING { + return nil + } + } + + return errors.New("exceeded max attempts number") } diff --git a/integration/infra_test.go b/integration/infra_test.go index fbb6806a..858e9d5c 100644 --- a/integration/infra_test.go +++ b/integration/infra_test.go @@ -9,25 +9,43 @@ import ( "os" "path/filepath" "testing" + "time" "github.com/goccy/go-yaml" "github.com/moby/moby/api/types/network" "github.com/openkcm/common-sdk/pkg/commoncfg" "github.com/stretchr/testify/require" + "github.com/valkey-io/valkey-go" "github.com/openkcm/session-manager/internal/config" "github.com/openkcm/session-manager/internal/dbtest/postgrestest" "github.com/openkcm/session-manager/internal/dbtest/valkeytest" + "github.com/openkcm/session-manager/modules/database/pgxpool" ) type closeFunc func(ctx context.Context) +type Config struct { + config.Config `yaml:",inline"` + + Database pgxpool.PostgresModule `yaml:"database"` +} + +func loadExtendedConfig(t *testing.T, dir string) *Config { + t.Helper() + + cfg, err := config.Load("", dir) + require.NoError(t, err, "failed to load config") + + return &Config{Config: *cfg} +} + type infraStat struct { PostgresPort network.Port ValKeyPort network.Port ConfigFilePath string Procdir string - Cfg config.Config + Cfg *Config closeFuncs []closeFunc } @@ -38,21 +56,22 @@ func initInfra(t *testing.T) (istat infraStat) { // Since the config is read from the file $PWD/config.yaml, // we're running a process in a temporary subdirectory // so that we aren't interfering with the other tests. - procDir, err := os.MkdirTemp("", "*") - require.NoError(t, err, "failed to create a temp dir") + procDir := t.TempDir() istat.Procdir = procDir istat.ConfigFilePath = filepath.Join(procDir, "config.yaml") - err = os.WriteFile(istat.ConfigFilePath, []byte(validConfig), fs.ModePerm) + err := os.WriteFile(istat.ConfigFilePath, []byte(validConfig), fs.ModePerm) require.NoError(t, err, "failed to write config file") - err = commoncfg.LoadConfig(&istat.Cfg, nil, istat.Procdir) - require.NoError(t, err, "failed to load config") + istat.Cfg = loadExtendedConfig(t, istat.Procdir) // Let OS choose a free port istat.Cfg.HTTP.Address = "unix://" + filepath.Join(procDir, "unix.sock") istat.Cfg.GRPC.Address = ":0" istat.Cfg.Logger.Format = commoncfg.TextLoggerFormat + istat.Cfg.Logger.Level = "debug" + istat.Cfg.Logger.Formatter.Time.Type = commoncfg.PatternTimeLogger + istat.Cfg.Logger.Formatter.Time.Pattern = time.Stamp // There's a hard limit of 108 symbols on a unix socket filepath on Linux/macOS. if len(istat.Cfg.HTTP.Address) > 108 { @@ -69,35 +88,33 @@ func (istat *infraStat) PreparePostgres(t *testing.T) { const dbpass = "secret" const dbname = "session_manager" - wd, err := os.Getwd() - require.NoError(t, err, "getting wd") - pgClient, pgPort, pgTerminate := postgrestest.Start(t.Context()) pgClient.Close() istat.PostgresPort = pgPort istat.closeFuncs = append(istat.closeFuncs, pgTerminate) + istat.Cfg.Database.Mod = "database.module.pgxpool" istat.Cfg.Database.Name = dbname istat.Cfg.Database.User = commoncfg.SourceRef{Source: "embedded", Value: dbuser} istat.Cfg.Database.Password = commoncfg.SourceRef{Source: "embedded", Value: dbpass} istat.Cfg.Database.Host = commoncfg.SourceRef{Source: "embedded", Value: "localhost"} istat.Cfg.Database.Port = pgPort.Port() - istat.Cfg.Migrate.Source = "file://" + filepath.Join(wd, "../sql") } -func (istat *infraStat) PrepareValKey(t *testing.T) { +func (istat *infraStat) PrepareValKey(t *testing.T) valkey.Client { t.Helper() vkClient, vkPort, vkTerminate := valkeytest.Start(t.Context()) - vkClient.Close() istat.ValKeyPort = vkPort - istat.closeFuncs = append(istat.closeFuncs, vkTerminate) + istat.closeFuncs = append(istat.closeFuncs, vkTerminate, func(_ context.Context) { vkClient.Close() }) istat.Cfg.ValKey.Host = commoncfg.SourceRef{Source: "embedded", Value: net.JoinHostPort("localhost", vkPort.Port())} istat.Cfg.ValKey.User = commoncfg.SourceRef{Source: "embedded", Value: ""} istat.Cfg.ValKey.Password = commoncfg.SourceRef{Source: "embedded", Value: ""} + + return vkClient } // PrepareConfig writes a config file for running the test into the ConfigFilePath. diff --git a/integration/migrate_test.go b/integration/migrate_test.go index b21490dd..4761470f 100644 --- a/integration/migrate_test.go +++ b/integration/migrate_test.go @@ -8,13 +8,12 @@ import ( "os/exec" "path/filepath" "testing" + "time" "github.com/go-viper/mapstructure/v2" "github.com/goccy/go-yaml" "github.com/openkcm/common-sdk/pkg/commoncfg" "github.com/testcontainers/testcontainers-go/modules/postgres" - - "github.com/openkcm/session-manager/internal/config" ) func TestMigrate(t *testing.T) { @@ -39,6 +38,7 @@ func TestMigrate(t *testing.T) { if err != nil { t.Fatalf("failed to start PostgreSQL: %s", err) } + defer func() { _ = pgContainer.Terminate(ctx) }() port, err := pgContainer.MappedPort(ctx, "5432") if err != nil { @@ -46,40 +46,51 @@ func TestMigrate(t *testing.T) { } // Prepare config - _ = os.MkdirAll(testdir, fs.ModePerm) - defer os.RemoveAll(testdir) - - err = os.WriteFile(configFilePath, []byte(validConfig), fs.ModePerm) + currdir, err := os.Getwd() if err != nil { - t.Fatalf("failed to write config file: %s", err) + t.Fatalf("failed to get wd: %s", err) } - defer os.Remove(configFilePath) - var cfg config.Config - err = commoncfg.LoadConfig(&cfg, nil, testdir) - if err != nil { - t.Fatalf("failed to load config: %s", err) - } + abstestdir := filepath.Join(currdir, testdir) + _ = os.MkdirAll(abstestdir, fs.ModePerm) + defer os.RemoveAll(abstestdir) - currdir, err := os.Getwd() + absConfigFilePath := filepath.Join(currdir, configFilePath) + err = os.WriteFile(absConfigFilePath, []byte(validConfig), fs.ModePerm) if err != nil { - t.Fatalf("failed to get wd: %s", err) + t.Fatalf("failed to write config file: %s", err) } + cfg := loadExtendedConfig(t, abstestdir) + cfg.Logger.Level = "debug" + cfg.Logger.Format = commoncfg.TextLoggerFormat + cfg.Logger.Formatter.Time.Type = commoncfg.PatternTimeLogger + cfg.Logger.Formatter.Time.Pattern = time.Stamp + cfg.Database.Name = dbname cfg.Database.User = commoncfg.SourceRef{Source: "embedded", Value: dbuser} cfg.Database.Password = commoncfg.SourceRef{Source: "embedded", Value: dbpass} cfg.Database.Host = commoncfg.SourceRef{Source: "embedded", Value: "localhost"} cfg.Database.Port = port.Port() - cfg.Migrate.Source = "file://" + filepath.Join(currdir, "../sql") cfgMap := make(map[any]any) - err = mapstructure.Decode(cfg, &cfgMap) - if err != nil { + decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ + Result: &cfgMap, + DecodeHook: mapstructure.ComposeDecodeHookFunc( + mapstructure.StringToTimeDurationHookFunc(), + mapstructure.TextUnmarshallerHookFunc()), + WeaklyTypedInput: true, + TagName: "yaml", + SquashTagOption: "inline", + }) + if err != nil { + t.Fatalf("failed to create mapstructure decoder: %s", err) + } + if err := decoder.Decode(cfg); err != nil { t.Fatalf("failed to decode mapstructure: %s", err) } - f, err := os.Create(configFilePath) + f, err := os.Create(absConfigFilePath) if err != nil { t.Fatalf("failed to create config file: %s", err) } @@ -90,9 +101,7 @@ func TestMigrate(t *testing.T) { t.Fatalf("failed to write config: %s", err) } - wd, _ := os.Getwd() - t.Chdir(testdir) - defer os.Chdir(wd) + t.Chdir(abstestdir) // Run the migrations cmd := exec.CommandContext(ctx, filepath.Join(currdir, "./session-manager"), cmdName) @@ -145,40 +154,51 @@ func TestMigrateIdempotent(t *testing.T) { } // Prepare config - _ = os.MkdirAll(testdir, fs.ModePerm) - defer os.RemoveAll(testdir) - - err = os.WriteFile(configFilePath, []byte(validConfig), fs.ModePerm) + currdir, err := os.Getwd() if err != nil { - t.Fatalf("failed to write config file: %s", err) + t.Fatalf("failed to get wd: %s", err) } - defer os.Remove(configFilePath) - var cfg config.Config - err = commoncfg.LoadConfig(&cfg, nil, testdir) - if err != nil { - t.Fatalf("failed to load config: %s", err) - } + abstestdir := filepath.Join(currdir, testdir) + _ = os.MkdirAll(abstestdir, fs.ModePerm) + defer os.RemoveAll(abstestdir) - currdir, err := os.Getwd() + absConfigFilePath := filepath.Join(currdir, configFilePath) + err = os.WriteFile(absConfigFilePath, []byte(validConfig), fs.ModePerm) if err != nil { - t.Fatalf("failed to get wd: %s", err) + t.Fatalf("failed to write config file: %s", err) } + cfg := loadExtendedConfig(t, abstestdir) + cfg.Logger.Level = "debug" + cfg.Logger.Format = commoncfg.TextLoggerFormat + cfg.Logger.Formatter.Time.Type = commoncfg.PatternTimeLogger + cfg.Logger.Formatter.Time.Pattern = time.Stamp + cfg.Database.Name = dbname cfg.Database.User = commoncfg.SourceRef{Source: "embedded", Value: dbuser} cfg.Database.Password = commoncfg.SourceRef{Source: "embedded", Value: dbpass} cfg.Database.Host = commoncfg.SourceRef{Source: "embedded", Value: "localhost"} cfg.Database.Port = port.Port() - cfg.Migrate.Source = "file://" + filepath.Join(currdir, "../sql") cfgMap := make(map[any]any) - err = mapstructure.Decode(cfg, &cfgMap) - if err != nil { + decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ + Result: &cfgMap, + DecodeHook: mapstructure.ComposeDecodeHookFunc( + mapstructure.StringToTimeDurationHookFunc(), + mapstructure.TextUnmarshallerHookFunc()), + WeaklyTypedInput: true, + TagName: "yaml", + SquashTagOption: "inline", + }) + if err != nil { + t.Fatalf("failed to create mapstructure decoder: %s", err) + } + if err := decoder.Decode(cfg); err != nil { t.Fatalf("failed to decode mapstructure: %s", err) } - f, err := os.Create(configFilePath) + f, err := os.Create(absConfigFilePath) if err != nil { t.Fatalf("failed to create config file: %s", err) } @@ -189,9 +209,7 @@ func TestMigrateIdempotent(t *testing.T) { t.Fatalf("failed to write config: %s", err) } - wd, _ := os.Getwd() - t.Chdir(testdir) - defer os.Chdir(wd) + t.Chdir(abstestdir) // Run migrations the first time cmd := exec.CommandContext(ctx, filepath.Join(currdir, "./session-manager"), cmdName) diff --git a/integration/session_grpc_test.go b/integration/session_grpc_test.go index 7a235e4e..f9f92258 100644 --- a/integration/session_grpc_test.go +++ b/integration/session_grpc_test.go @@ -4,8 +4,11 @@ package integration_test import ( "context" + "errors" "fmt" - "net" + "os" + "os/exec" + "path/filepath" "testing" "time" @@ -14,33 +17,76 @@ import ( "github.com/stretchr/testify/require" sessionv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/sessionmanager/session/v1" - slogctx "github.com/veqryn/slog-context" - stdgrpc "google.golang.org/grpc" - "github.com/openkcm/session-manager/internal/dbtest/postgrestest" - "github.com/openkcm/session-manager/internal/dbtest/valkeytest" - "github.com/openkcm/session-manager/internal/grpc" "github.com/openkcm/session-manager/internal/session" sessionvalkey "github.com/openkcm/session-manager/internal/session/valkey" - "github.com/openkcm/session-manager/internal/trust/trustsql" ) func TestSessionGRPC(t *testing.T) { + const cmdName = "api-server" + const port = 9092 + // given ctx := t.Context() - port := 9092 - // create grpc server with session support - srv, sessionRepo, terminateFn, err := startSessionServer(t, port) - require.NoError(t, err) - defer srv.Stop() - defer terminateFn(ctx) + istat := initInfra(t) + defer istat.Close(ctx) + + istat.Cfg.GRPC.Address = fmt.Sprintf(":%d", port) + + istat.PreparePostgres(t) + valkeyClient := istat.PrepareValKey(t) + istat.PrepareConfig(t) + + sessionRepo := sessionvalkey.NewRepository(valkeyClient, "session") + + currdir, err := os.Getwd() + require.NoError(t, err, "failed to get wd") + + t.Chdir(istat.Procdir) + + commandCtx, cancelCommand := context.WithCancel(ctx) + defer cancelCommand() + + cmd := exec.CommandContext(commandCtx, filepath.Join(currdir, "./session-manager"), cmdName) + cmd.WaitDelay = 5 * time.Second + cmd.Cancel = func() error { return cmd.Process.Signal(os.Interrupt) } + + cmdOutPath := filepath.Join(currdir, "session-grpc.log") + cmdOut, err := os.Create(cmdOutPath) + if err != nil { + t.Fatalf("failed to create an log file") + } + defer cmdOut.Close() + + cmd.Stdout = cmdOut + cmd.Stderr = cmdOut + + t.Logf("starting an app process. Logs will be saved into %s", cmdOutPath) + if err := cmd.Start(); err != nil { + t.Fatalf("failed to start the server: %s", err) + } + + errCh := make(chan error) + go func() { + if err := cmd.Wait(); err != nil && !errors.Is(err, context.Canceled) { + errCh <- fmt.Errorf("executing command: %w", err) + } + close(errCh) + }() // grpc client connection - conn, err := createClientConn(t, port) + cc, err := createClientConn(t, port) require.NoError(t, err) - defer conn.Close() - sessionClient := sessionv1.NewServiceClient(conn) + defer cc.Close() + + waitCtx, cancelWait := context.WithTimeout(commandCtx, 10*time.Second) + defer cancelWait() + if err := waitGRPCServerReady(waitCtx, cc); err != nil { + t.Fatalf("waiting for the server readiness: %s", err) + } + + sessionClient := sessionv1.NewServiceClient(cc) t.Run("GetSession - session not found", func(t *testing.T) { resp, err := sessionClient.GetSession(ctx, &sessionv1.GetSessionRequest{ @@ -106,7 +152,7 @@ func TestSessionGRPC(t *testing.T) { err = sessionRepo.BumpActive(ctx, sess.ID, 1*time.Hour) require.NoError(t, err) - // Note: This test will fail validation because there's no trust mapping configured + // Note: This test will fail validation because there's no trust configured // but it tests the session retrieval path resp, err := sessionClient.GetSession(ctx, &sessionv1.GetSessionRequest{ SessionId: sess.ID, @@ -115,7 +161,7 @@ func TestSessionGRPC(t *testing.T) { }) assert.NoError(t, err) assert.NotNil(t, resp) - // Will be false because trust mapping is not configured, but tests the flow + // Will be false because trust is not configured, but tests the flow assert.False(t, resp.GetValid()) }) @@ -171,43 +217,3 @@ func TestSessionGRPC(t *testing.T) { assert.False(t, resp.GetValid()) }) } - -func startSessionServer(t *testing.T, port int) (*stdgrpc.Server, session.Repository, func(context.Context), error) { - t.Helper() - ctx := t.Context() - - // start postgres - db, _, terminatePG := postgrestest.Start(ctx) - - // start valkey - valkeyClient, _, terminateValkey := valkeytest.Start(ctx) - - terminateFn := func(ctx context.Context) { - terminatePG(ctx) - terminateValkey(ctx) - db.Close() - valkeyClient.Close() - } - - trustRepo := trustsql.NewRepository(db) - sessionRepo := sessionvalkey.NewRepository(valkeyClient, "session") - - lstConf := net.ListenConfig{} - lis, err := lstConf.Listen(ctx, "tcp", fmt.Sprintf("localhost:%d", port)) - if err != nil { - return nil, nil, nil, err - } - - srv := stdgrpc.NewServer() - sessionv1.RegisterServiceServer(srv, grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute, "")) - - // start - go func() { - err = srv.Serve(lis) - if err != nil { - slogctx.Error(ctx, "error while starting session server", "error", err) - } - }() - - return srv, sessionRepo, terminateFn, nil -} diff --git a/internal/business/apps.go b/internal/business/apps.go new file mode 100644 index 00000000..45313a2a --- /dev/null +++ b/internal/business/apps.go @@ -0,0 +1,103 @@ +package business + +import ( + "errors" + "fmt" + "slices" + + slogctx "github.com/veqryn/slog-context" + + sessionmanager "github.com/openkcm/session-manager" + "github.com/openkcm/session-manager/internal/config" +) + +// startedApp pairs a configured name with its loaded handle so we can stop +// apps in reverse start order and emit meaningful logs. +type startedApp struct { + name string + app sessionmanager.App +} + +// startApps loads, provisions and starts every app declared under the +// top-level apps: section in cfg, in the order given by cfg.AppsOrder (with +// any apps not listed there appended in cfg.Apps map iteration order). +// +// If any Start() returns a non-nil error, every previously-started app is +// stopped in reverse order before the original error is returned. On +// success, the returned stopAll closure stops every started app in reverse +// order and joins any Stop() errors via errors.Join. +func startApps(ctx *sessionmanager.Context, cfg *config.Config) (stopAll func() error, _ error) { + order, err := appsStartOrder(cfg) + if err != nil { + return nil, err + } + + started := make([]startedApp, 0, len(order)) + + rollback := func() { + for _, sa := range slices.Backward(started) { + if stopErr := sa.app.Stop(); stopErr != nil { + slogctx.Error(ctx, "stopping app during rollback", "app", sa.name, "error", stopErr) + } + } + } + + for _, name := range order { + appCfg := cfg.Apps[name] + + slogctx.Info(ctx, "loading app", "app", name, "module", appCfg.Module()) + app, err := ctx.LoadApp(appCfg) + if err != nil { + rollback() + return nil, fmt.Errorf("loading app %q: %w", name, err) + } + + slogctx.Info(ctx, "starting app", "app", name) + if err := app.Start(); err != nil { + rollback() + return nil, fmt.Errorf("starting app %q: %w", name, err) + } + + started = append(started, startedApp{name: name, app: app}) + } + + return func() error { + var errs []error + for _, sa := range slices.Backward(started) { + slogctx.Info(ctx, "stopping app", "app", sa.name) + if err := sa.app.Stop(); err != nil { + slogctx.Error(ctx, "stopping app", "app", sa.name, "error", err) + errs = append(errs, fmt.Errorf("stopping app %q: %w", sa.name, err)) + } + } + return errors.Join(errs...) + }, nil +} + +// appsStartOrder returns the names of configured apps in start order. Names +// listed in cfg.AppsOrder come first, in the given order; the remainder are +// appended in cfg.Apps map iteration order. Names in AppsOrder that are not +// present in cfg.Apps surface as a configuration error. +func appsStartOrder(cfg *config.Config) ([]string, error) { + seen := make(map[string]bool, len(cfg.Apps)) + order := make([]string, 0, len(cfg.Apps)) + + for _, name := range cfg.AppsOrder { + if _, ok := cfg.Apps[name]; !ok { + return nil, fmt.Errorf("appsOrder references unknown app %q", name) + } + if !seen[name] { + seen[name] = true + order = append(order, name) + } + } + + for name := range cfg.Apps { + if !seen[name] { + seen[name] = true + order = append(order, name) + } + } + + return order, nil +} diff --git a/internal/business/apps_test.go b/internal/business/apps_test.go new file mode 100644 index 00000000..bd46afba --- /dev/null +++ b/internal/business/apps_test.go @@ -0,0 +1,198 @@ +package business + +import ( + "errors" + "fmt" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + sessionmanager "github.com/openkcm/session-manager" + "github.com/openkcm/session-manager/internal/config" +) + +// fakeApp is an App whose Start/Stop record their relative order across all +// fakeApp instances sharing a counter, so tests can assert ordering. +type fakeApp struct { + id string + counter *atomic.Int64 + startErr error + stopErr error + startOrder int64 + stopOrder int64 + startCalled bool + stopCalled bool +} + +func (a *fakeApp) Module() sessionmanager.ModuleInfo { + return sessionmanager.ModuleInfo{ + ID: a.id, + New: func() sessionmanager.Module { return a }, + } +} + +func (a *fakeApp) Start() error { + a.startCalled = true + a.startOrder = a.counter.Add(1) + return a.startErr +} + +func (a *fakeApp) Stop() error { + a.stopCalled = true + a.stopOrder = a.counter.Add(1) + return a.stopErr +} + +// registerFakeApps registers the given apps under unique module IDs scoped to +// the test name and returns a config.Config wired to load them in the order +// supplied. The returned cfg.AppsOrder pins the start order so tests don't +// rely on Go's map iteration order. +func registerFakeApps(t *testing.T, apps ...*fakeApp) *config.Config { + t.Helper() + + cfg := &config.Config{ + Apps: make(map[string]*config.App, len(apps)), + AppsOrder: make([]string, 0, len(apps)), + } + + for i, app := range apps { + modID := fmt.Sprintf("test.app.%s.%d", t.Name(), i) + app.id = modID + + // Capture for closure — RegisterModule's New() must hand back this + // instance so the test can assert against it. + captured := app + sessionmanager.RegisterModule(&moduleInfoStub{ + id: modID, + new: func() sessionmanager.Module { return captured }, + }) + + name := fmt.Sprintf("app-%d", i) + cfg.Apps[name] = &config.App{Mod: modID} + cfg.AppsOrder = append(cfg.AppsOrder, name) + } + + return cfg +} + +// moduleInfoStub adapts a (id, new) pair into the Module + ModuleInfo +// registration surface required by sessionmanager.RegisterModule. +type moduleInfoStub struct { + id string + new func() sessionmanager.Module +} + +func (s *moduleInfoStub) Module() sessionmanager.ModuleInfo { + return sessionmanager.ModuleInfo{ID: s.id, New: s.new} +} + +func TestStartApps_OrderAndReverseStop(t *testing.T) { + var counter atomic.Int64 + a := &fakeApp{counter: &counter} + b := &fakeApp{counter: &counter} + + cfg := registerFakeApps(t, a, b) + + ctx, cancel := sessionmanager.NewContext(t.Context()) + defer cancel(nil) + + stop, err := startApps(ctx, cfg) + require.NoError(t, err) + require.True(t, a.startCalled) + require.True(t, b.startCalled) + assert.Less(t, a.startOrder, b.startOrder, "A.Start must be called before B.Start") + + require.NoError(t, stop()) + assert.Less(t, b.stopOrder, a.stopOrder, "B.Stop must be called before A.Stop") +} + +func TestStartApps_StartFailureRollsBack(t *testing.T) { + var counter atomic.Int64 + wantErr := errors.New("boom") + a := &fakeApp{counter: &counter} + b := &fakeApp{counter: &counter, startErr: wantErr} + + cfg := registerFakeApps(t, a, b) + + ctx, cancel := sessionmanager.NewContext(t.Context()) + defer cancel(nil) + + stop, err := startApps(ctx, cfg) + require.Error(t, err) + require.ErrorIs(t, err, wantErr) + assert.Nil(t, stop) + + assert.True(t, a.startCalled) + assert.True(t, a.stopCalled, "successfully-started app must be rolled back on Start failure") + assert.True(t, b.startCalled) + assert.False(t, b.stopCalled, "failed-to-start app must not have Stop called") +} + +func TestStartApps_FirstAppFailsNoStop(t *testing.T) { + var counter atomic.Int64 + wantErr := errors.New("boom") + a := &fakeApp{counter: &counter, startErr: wantErr} + b := &fakeApp{counter: &counter} + + cfg := registerFakeApps(t, a, b) + + ctx, cancel := sessionmanager.NewContext(t.Context()) + defer cancel(nil) + + _, err := startApps(ctx, cfg) + require.ErrorIs(t, err, wantErr) + assert.False(t, a.stopCalled) + assert.False(t, b.startCalled) + assert.False(t, b.stopCalled) +} + +func TestStartApps_StopErrorsAggregated(t *testing.T) { + var counter atomic.Int64 + stopErrA := errors.New("stop-a") + stopErrB := errors.New("stop-b") + a := &fakeApp{counter: &counter, stopErr: stopErrA} + b := &fakeApp{counter: &counter, stopErr: stopErrB} + + cfg := registerFakeApps(t, a, b) + + ctx, cancel := sessionmanager.NewContext(t.Context()) + defer cancel(nil) + + stop, err := startApps(ctx, cfg) + require.NoError(t, err) + + err = stop() + require.Error(t, err) + assert.ErrorIs(t, err, stopErrA) + assert.ErrorIs(t, err, stopErrB) +} + +func TestAppsStartOrder_AppsOrderTakesPrecedence(t *testing.T) { + cfg := &config.Config{ + Apps: map[string]*config.App{ + "a": {}, "b": {}, "c": {}, + }, + AppsOrder: []string{"c", "a"}, + } + + got, err := appsStartOrder(cfg) + require.NoError(t, err) + require.Len(t, got, 3) + assert.Equal(t, "c", got[0]) + assert.Equal(t, "a", got[1]) + // "b" wasn't listed, so it appears last in map iteration order. + assert.Equal(t, "b", got[2]) +} + +func TestAppsStartOrder_RejectsUnknownAppName(t *testing.T) { + cfg := &config.Config{ + Apps: map[string]*config.App{"a": {}}, + AppsOrder: []string{"missing"}, + } + + _, err := appsStartOrder(cfg) + require.Error(t, err) + assert.Contains(t, err.Error(), "missing") +} diff --git a/internal/business/business.go b/internal/business/business.go index 1a9780d2..75ad2665 100644 --- a/internal/business/business.go +++ b/internal/business/business.go @@ -4,65 +4,73 @@ import ( "context" "errors" "fmt" - "log/slog" "sync" - "github.com/exaring/otelpgx" - "github.com/jackc/pgx/v5/pgxpool" "github.com/openkcm/common-sdk/pkg/commoncfg" - "github.com/valkey-io/valkey-go" - otlpaudit "github.com/openkcm/common-sdk/pkg/otlp/audit" slogctx "github.com/veqryn/slog-context" + sessionmanager "github.com/openkcm/session-manager" "github.com/openkcm/session-manager/internal/business/server" "github.com/openkcm/session-manager/internal/config" - "github.com/openkcm/session-manager/internal/credentials" - "github.com/openkcm/session-manager/internal/grpc" - "github.com/openkcm/session-manager/internal/session" - sessionvalkey "github.com/openkcm/session-manager/internal/session/valkey" - "github.com/openkcm/session-manager/internal/trust" - "github.com/openkcm/session-manager/internal/trust/trustsql" + "github.com/openkcm/session-manager/internal/sessionwiring" ) -const clientAuthTypeInsecure = "insecure" - -// Main starts both API servers +// Main starts the public HTTP API server and the configured apps. func Main(ctx context.Context, cfg *config.Config) error { - ctx, cancel := context.WithCancel(ctx) - defer cancel() + c, cancelCause := sessionmanager.NewContext(ctx) + defer cancelCause(nil) + + c = config.WithContext(c, cfg) + + if _, err := c.LoadModule(&cfg.Database); err != nil { + return fmt.Errorf("loading database module: %w", err) + } + + if _, err := c.LoadModule(&cfg.Trust); err != nil { + return fmt.Errorf("loading trust module: %w", err) + } - // errChan is used to capture the first error and shutdown the servers. + if _, err := c.LoadModule(&cfg.ValKey); err != nil { + return fmt.Errorf("loading session-store module: %w", err) + } + + if _, err := c.LoadModule(&cfg.Credentials); err != nil { + return fmt.Errorf("loading credentials module: %w", err) + } + + stopApps, err := startApps(c, cfg) + if err != nil { + return fmt.Errorf("starting apps: %w", err) + } + + // errChan captures the first error and triggers shutdown. errChan := make(chan error, 1) - // wg is used to wait for all servers to shutdown. + // wg is used to wait for all goroutines to shutdown. var wg sync.WaitGroup // start public HTTP REST API server wg.Go(func() { - errChan <- publicMain(ctx, cfg) - }) - - // start internal gRPC API server - wg.Go(func() { - errChan <- internalMain(ctx, cfg) + errChan <- publicMain(c, cfg) }) // wait for any error to initiate the shutdown - err := <-errChan + err = <-errChan if err != nil { slogctx.Error(ctx, "Shutting down servers", "error", err) } - cancel() - // wait for all servers to shutdown + stopErr := stopApps() + cancelCause(err) + wg.Wait() - return err + return errors.Join(err, stopErr) } // publicMain starts the HTTP REST public API server. -func publicMain(ctx context.Context, cfg *config.Config) error { +func publicMain(ctx *sessionmanager.Context, cfg *config.Config) error { csrfSecret, err := commoncfg.LoadValueFromSourceRef(cfg.SessionManager.CSRFSecret) if err != nil { return fmt.Errorf("loading csrf token from source ref: %w", err) @@ -73,174 +81,20 @@ func publicMain(ctx context.Context, cfg *config.Config) error { cfg.SessionManager.CSRFSecretParsed = csrfSecret - sessionManager, closeFn, err := initSessionManager(ctx, cfg) - if err != nil { - return fmt.Errorf("failed to initialise the session manager: %w", err) - } - - defer closeFn() - - return server.StartHTTPServer(ctx, cfg, sessionManager) -} - -// internalMain starts the gRPC private API server. -func internalMain(ctx context.Context, cfg *config.Config) error { - // Create trust service - trustRepo, err := trustRepoFromConfig(ctx, cfg) - if err != nil { - return fmt.Errorf("failed to create trust service: %w", err) - } - trustService := trust.NewService(trustRepo) - - // Create session repository - valkeyClient, err := valkeyClientFromConfig(cfg) - if err != nil { - return fmt.Errorf("failed to create valkey client: %w", err) - } - defer valkeyClient.Close() - sessionRepo := sessionvalkey.NewRepository(valkeyClient, cfg.ValKey.Prefix) - - credsBuilder, err := newCredsBuilder(cfg) - if err != nil { - return fmt.Errorf("failed to create a credentials builder: %w", err) - } - - // Initialize the gRPC servers. - oidcmappingsrv := grpc.NewOIDCMappingServer(trustService) - sessionsrv := grpc.NewSessionServer(ctx, - sessionRepo, - trustRepo, - cfg.SessionManager.IdleSessionTimeout, - cfg.SessionManager.ClientAuth.ClientID, - grpc.WithQueryParametersIntrospect(cfg.SessionManager.AdditionalQueryParametersIntrospect), - grpc.WithTransportCredentials(credsBuilder), - ) - return server.StartGRPCServer(ctx, cfg, oidcmappingsrv, sessionsrv) -} - -func initSessionManager(ctx context.Context, cfg *config.Config) (_ *session.Manager, closeFn func(), _ error) { - // Create trust repository - trustRepo, err := trustRepoFromConfig(ctx, cfg) - if err != nil { - return nil, nil, fmt.Errorf("failed to create trust repository: %w", err) - } - - // Create session repository - valkeyClient, err := valkeyClientFromConfig(cfg) - if err != nil { - return nil, nil, fmt.Errorf("failed to create valkey client: %w", err) - } - sessionRepo := sessionvalkey.NewRepository(valkeyClient, cfg.ValKey.Prefix) - - credsBuilder, err := newCredsBuilder(cfg) + trustMod, err := ctx.GetModule(cfg.Trust.Module()) if err != nil { - return nil, nil, fmt.Errorf("failed to load http client: %w", err) + return fmt.Errorf("getting trust module: %w", err) } - auditLogger, err := otlpaudit.NewLogger(&cfg.Audit) - if err != nil { - return nil, nil, fmt.Errorf("failed to create audit logger: %w", err) - } + //nolint:forcetypeassert + trust := trustMod.(sessionmanager.Trust) - sessManager, err := session.NewManager(ctx, - &cfg.SessionManager, - trustRepo, - sessionRepo, - auditLogger, - session.WithTransportCredentials(credsBuilder), - ) + sessionManager, closeFn, err := sessionwiring.InitSessionManager(ctx, cfg, trust) if err != nil { - return nil, nil, fmt.Errorf("failed to create session manager: %w", err) - } - - return sessManager, valkeyClient.Close, nil -} - -func trustRepoFromConfig(ctx context.Context, cfg *config.Config) (*trustsql.Repository, error) { - connStr, err := config.MakeConnStr(cfg.Database) - if err != nil { - return nil, fmt.Errorf("failed to make dsn from config: %w", err) - } - - pgxpoolCfg, err := pgxpool.ParseConfig(connStr) - if err != nil { - return nil, fmt.Errorf("parsing pgxpool config: %w", err) - } - - pgxpoolCfg.ConnConfig.Tracer = otelpgx.NewTracer() - - db, err := pgxpool.NewWithConfig(ctx, pgxpoolCfg) - if err != nil { - return nil, fmt.Errorf("failed to initialise pgxpool connection: %w", err) - } - - if err := otelpgx.RecordStats(db); err != nil { - return nil, fmt.Errorf("recording database stat: %w", err) - } - - return trustsql.NewRepository(db), nil -} - -func valkeyClientFromConfig(cfg *config.Config) (valkey.Client, error) { - valkeyHost, err := commoncfg.LoadValueFromSourceRef(cfg.ValKey.Host) - if err != nil { - return nil, fmt.Errorf("failed to load valkey host: %w", err) - } - - valkeyUsername, err := commoncfg.LoadValueFromSourceRef(cfg.ValKey.User) - if err != nil { - return nil, fmt.Errorf("failed to load valkey username: %w", err) - } - - valkeyPassword, err := commoncfg.LoadValueFromSourceRef(cfg.ValKey.Password) - if err != nil { - return nil, fmt.Errorf("failed to load valkey password: %w", err) - } - - valkeyOpts := valkey.ClientOption{ - InitAddress: []string{string(valkeyHost)}, - Username: string(valkeyUsername), - Password: string(valkeyPassword), - } - - if cfg.ValKey.SecretRef.Type == commoncfg.MTLSSecretType { - tlsConfig, err := commoncfg.LoadMTLSConfig(&cfg.ValKey.SecretRef.MTLS) - if err != nil { - return nil, fmt.Errorf("failed to load valkey mTLS config from secret ref: %w", err) - } - - valkeyOpts.TLSConfig = tlsConfig + return fmt.Errorf("failed to initialise the session manager: %w", err) } - valkeyClient, err := valkey.NewClient(valkeyOpts) - if err != nil { - return nil, fmt.Errorf("failed to create a new valkey client: %w", err) - } - return valkeyClient, nil -} + defer closeFn() -func newCredsBuilder(cfg *config.Config) (credentials.Builder, error) { - switch cfg.SessionManager.ClientAuth.Type { - case "mtls": - tlsConfig, err := commoncfg.LoadMTLSConfig(cfg.SessionManager.ClientAuth.MTLS) - if err != nil { - return nil, fmt.Errorf("failed to load mTLS config: %w", err) - } - - return func(clientID string) credentials.TransportCredentials { return credentials.NewTLS(clientID, tlsConfig) }, nil - case "client_secret", "client_secret_post": - secret, err := commoncfg.LoadValueFromSourceRef(cfg.SessionManager.ClientAuth.ClientSecret) - if err != nil { - return nil, fmt.Errorf("failed to load client secret: %w", err) - } - - return func(clientID string) credentials.TransportCredentials { - return credentials.NewClientSecretPost(clientID, string(secret)) - }, nil - case clientAuthTypeInsecure: - slog.Warn("insecure credentials are used. Do not use this in production") - return func(clientID string) credentials.TransportCredentials { return credentials.NewInsecure(clientID) }, nil - default: - return nil, errors.New("unknown Client Auth type") - } + return server.StartHTTPServer(ctx, cfg, sessionManager) } diff --git a/internal/business/business_test.go b/internal/business/business_test.go index fa9ecb1d..c2519fcc 100644 --- a/internal/business/business_test.go +++ b/internal/business/business_test.go @@ -1,366 +1,15 @@ package business import ( - "io" - "net/http" - "net/http/httptest" - "net/url" - "strings" "testing" "github.com/openkcm/common-sdk/pkg/commoncfg" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + sessionmanager "github.com/openkcm/session-manager" "github.com/openkcm/session-manager/internal/config" - "github.com/openkcm/session-manager/internal/credentials" ) -func TestLoadHTTPClient_MTLS(t *testing.T) { - cfg := &config.Config{ - SessionManager: config.SessionManager{ - ClientAuth: config.ClientAuth{ - Type: "mtls", - ClientID: "test-client", - MTLS: &commoncfg.MTLS{ - Cert: commoncfg.SourceRef{File: commoncfg.CredentialFile{Path: "/nonexistent/cert.pem"}}, - CertKey: commoncfg.SourceRef{File: commoncfg.CredentialFile{Path: "/nonexistent/key.pem"}}, - }, - }, - }, - } - - // This will fail without actual cert files, but tests the logic path - _, err := newCredsBuilder(cfg) - // We expect an error since we don't have real cert files - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to load mTLS config") -} - -func TestLoadHTTPClient_ClientSecret(t *testing.T) { - cfg := &config.Config{ - SessionManager: config.SessionManager{ - ClientAuth: config.ClientAuth{ - Type: "client_secret", - ClientID: "test-client", - ClientSecret: commoncfg.SourceRef{Source: "embedded", Value: "test-secret"}, - }, - }, - } - - builder, err := newCredsBuilder(cfg) - require.NoError(t, err) - require.NotNil(t, builder) - - // Verify it's using our custom transport - creds := builder(cfg.SessionManager.ClientAuth.ClientID) - clientSecretCreds, ok := creds.(*credentials.ClientSecretPost) - require.True(t, ok) - - assert.Equal(t, "test-client", clientSecretCreds.ClientID) - assert.Equal(t, "test-secret", clientSecretCreds.ClientSecret) -} - -func TestLoadHTTPClient_Insecure(t *testing.T) { - cfg := &config.Config{ - SessionManager: config.SessionManager{ - ClientAuth: config.ClientAuth{ - Type: clientAuthTypeInsecure, - ClientID: "test-client", - }, - }, - } - - builder, err := newCredsBuilder(cfg) - require.NoError(t, err) - assert.IsType(t, &credentials.Insecure{}, builder("")) -} - -func TestLoadHTTPClient_UnknownType(t *testing.T) { - cfg := &config.Config{ - SessionManager: config.SessionManager{ - ClientAuth: config.ClientAuth{ - Type: "unknown", - ClientID: "test-client", - }, - }, - } - - _, err := newCredsBuilder(cfg) - assert.Error(t, err) - assert.Contains(t, err.Error(), "unknown Client Auth type") -} - -func TestClientAuthRoundTripper_RoundTrip(t *testing.T) { - tests := []struct { - name string - clientID string - clientSecret string - requestURL string - expectedClientID string - expectedHasSecret bool - expectedSecretVal string - body io.Reader - }{ - { - name: "With client secret", - clientID: "my-client", - clientSecret: "my-secret", - requestURL: "https://example.com/token", - expectedClientID: "my-client", - expectedHasSecret: true, - expectedSecretVal: "my-secret", - }, - { - name: "Without client secret", - clientID: "my-client", - clientSecret: "", - requestURL: "https://example.com/token", - expectedClientID: "my-client", - expectedHasSecret: false, - }, - { - name: "With existing query params", - clientID: "my-client", - clientSecret: "my-secret", - requestURL: "https://example.com/token", - expectedClientID: "my-client", - expectedHasSecret: true, - expectedSecretVal: "my-secret", - body: strings.NewReader("foo=bar"), - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Helper() - - // Create a test server that captures the request - var capturedReq *http.Request - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _ = r.ParseForm() - capturedReq = r - w.WriteHeader(http.StatusOK) - })) - defer server.Close() - - // Create the round tripper - creds := credentials.NewClientSecretPost(tt.clientID, tt.clientSecret) - - // Parse the test URL - reqURL, err := url.Parse(tt.requestURL) - require.NoError(t, err) - - // Update URL to point to test server - reqURL.Scheme = "http" - reqURL.Host = server.Listener.Addr().String() - - // Create and execute request - req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, reqURL.String(), tt.body) - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - require.NoError(t, err) - - resp, err := creds.Transport().RoundTrip(req) - require.NoError(t, err) - require.NotNil(t, resp) - defer resp.Body.Close() - - // Verify the captured request has correct query params - require.NotNil(t, capturedReq) - - assert.Equal(t, tt.expectedClientID, capturedReq.FormValue("client_id")) - - if tt.expectedHasSecret { - assert.Equal(t, tt.expectedSecretVal, capturedReq.FormValue("client_secret")) - } else { - assert.Empty(t, capturedReq.FormValue("client_secret")) - } - - // Verify original query params are preserved - if tt.body != nil { - b, _ := io.ReadAll(tt.body) - q, _ := url.ParseQuery(string(b)) - - for k, v := range q { - assert.Equal(t, v, capturedReq.FormValue(k)) - } - } - }) - } -} - -func TestClientAuthRoundTripper_RoundTrip_PreservesExistingParams(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, "my-client", r.FormValue("client_id")) - assert.Equal(t, "my-secret", r.FormValue("client_secret")) - assert.Equal(t, "bar", r.FormValue("foo")) - assert.Equal(t, "baz", r.FormValue("param2")) - w.WriteHeader(http.StatusOK) - })) - defer server.Close() - - creds := credentials.NewClientSecretPost("my-client", "my-secret") - - reqURL := server.URL - req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, reqURL, strings.NewReader("foo=bar¶m2=baz")) - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - require.NoError(t, err) - - resp, err := creds.Transport().RoundTrip(req) - require.NoError(t, err) - require.NotNil(t, resp) - defer resp.Body.Close() - - assert.Equal(t, http.StatusOK, resp.StatusCode) -} - -func TestValkeyClientFromConfig_InvalidHostRef(t *testing.T) { - cfg := &config.Config{ - ValKey: config.ValKey{ - Host: commoncfg.SourceRef{Source: "file", File: commoncfg.CredentialFile{Path: "/nonexistent/file"}}, - User: commoncfg.SourceRef{Source: "embedded", Value: "user"}, - Password: commoncfg.SourceRef{Source: "embedded", Value: "pass"}, - }, - } - - _, err := valkeyClientFromConfig(cfg) - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to load valkey host") -} - -func TestValkeyClientFromConfig_InvalidUserRef(t *testing.T) { - cfg := &config.Config{ - ValKey: config.ValKey{ - Host: commoncfg.SourceRef{Source: "embedded", Value: "localhost:6379"}, - User: commoncfg.SourceRef{Source: "file", File: commoncfg.CredentialFile{Path: "/nonexistent/file"}}, - Password: commoncfg.SourceRef{Source: "embedded", Value: "pass"}, - }, - } - - _, err := valkeyClientFromConfig(cfg) - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to load valkey username") -} - -func TestValkeyClientFromConfig_InvalidPasswordRef(t *testing.T) { - cfg := &config.Config{ - ValKey: config.ValKey{ - Host: commoncfg.SourceRef{Source: "embedded", Value: "localhost:6379"}, - User: commoncfg.SourceRef{Source: "embedded", Value: "user"}, - Password: commoncfg.SourceRef{Source: "file", File: commoncfg.CredentialFile{Path: "/nonexistent/file"}}, - }, - } - - _, err := valkeyClientFromConfig(cfg) - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to load valkey password") -} - -func TestValkeyClientFromConfig_WithMTLS(t *testing.T) { - cfg := &config.Config{ - ValKey: config.ValKey{ - Host: commoncfg.SourceRef{Source: "embedded", Value: "localhost:6379"}, - User: commoncfg.SourceRef{Source: "embedded", Value: "user"}, - Password: commoncfg.SourceRef{Source: "embedded", Value: "pass"}, - SecretRef: commoncfg.SecretRef{ - Type: commoncfg.MTLSSecretType, - MTLS: commoncfg.MTLS{ - Cert: commoncfg.SourceRef{Source: "file", File: commoncfg.CredentialFile{Path: "/nonexistent/cert.pem"}}, - CertKey: commoncfg.SourceRef{Source: "file", File: commoncfg.CredentialFile{Path: "/nonexistent/key.pem"}}, - }, - }, - }, - } - - _, err := valkeyClientFromConfig(cfg) - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to load valkey mTLS config from secret ref") -} - -func TestTrustRepoFromConfig_InvalidDatabaseConfig(t *testing.T) { - cfg := &config.Config{ - Database: config.Database{ - Host: commoncfg.SourceRef{Source: "file", File: commoncfg.CredentialFile{Path: "/nonexistent/file"}}, - Port: "5432", - Name: "testdb", - User: commoncfg.SourceRef{Source: "embedded", Value: "user"}, - Password: commoncfg.SourceRef{Source: "embedded", Value: "pass"}, - }, - } - - _, err := trustRepoFromConfig(t.Context(), cfg) - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to make dsn from config") -} - -func TestInitSessionManager_InvalidOIDCConfig(t *testing.T) { - cfg := &config.Config{ - Database: config.Database{ - Host: commoncfg.SourceRef{Source: "file", File: commoncfg.CredentialFile{Path: "/nonexistent/file"}}, - Port: "5432", - Name: "testdb", - User: commoncfg.SourceRef{Source: "embedded", Value: "user"}, - Password: commoncfg.SourceRef{Source: "embedded", Value: "pass"}, - }, - } - - _, closeFn, err := initSessionManager(t.Context(), cfg) - assert.Error(t, err) - assert.Nil(t, closeFn) - assert.Contains(t, err.Error(), "failed to create trust repository") -} - -func TestInitSessionManager_InvalidValkeyConfig(t *testing.T) { - cfg := &config.Config{ - Database: config.Database{ - Host: commoncfg.SourceRef{Source: "embedded", Value: "localhost"}, - Port: "5432", - Name: "testdb", - User: commoncfg.SourceRef{Source: "embedded", Value: "user"}, - Password: commoncfg.SourceRef{Source: "embedded", Value: "pass"}, - }, - ValKey: config.ValKey{ - Host: commoncfg.SourceRef{Source: "file", File: commoncfg.CredentialFile{Path: "/nonexistent/file"}}, - User: commoncfg.SourceRef{Source: "embedded", Value: "user"}, - Password: commoncfg.SourceRef{Source: "embedded", Value: "pass"}, - }, - } - - _, closeFn, err := initSessionManager(t.Context(), cfg) - assert.Error(t, err) - assert.Nil(t, closeFn) - // Will fail on either DB connection or valkey config - // Error details depend on which step fails -} - -func TestInitSessionManager_InvalidHTTPClientConfig(t *testing.T) { - cfg := &config.Config{ - Database: config.Database{ - Host: commoncfg.SourceRef{Source: "embedded", Value: "localhost"}, - Port: "5432", - Name: "testdb", - User: commoncfg.SourceRef{Source: "embedded", Value: "user"}, - Password: commoncfg.SourceRef{Source: "embedded", Value: "pass"}, - }, - ValKey: config.ValKey{ - Host: commoncfg.SourceRef{Source: "embedded", Value: "localhost:6379"}, - User: commoncfg.SourceRef{Source: "embedded", Value: "user"}, - Password: commoncfg.SourceRef{Source: "embedded", Value: "pass"}, - }, - SessionManager: config.SessionManager{ - ClientAuth: config.ClientAuth{ - Type: "invalid-type", - }, - }, - } - - _, closeFn, err := initSessionManager(t.Context(), cfg) - assert.Error(t, err) - assert.Nil(t, closeFn) - // Should fail on one of the earlier steps (DB or valkey) or on HTTP client - // Error details depend on which step fails -} - func TestPublicMain_InvalidCSRFSecret(t *testing.T) { cfg := &config.Config{ SessionManager: config.SessionManager{ @@ -368,7 +17,10 @@ func TestPublicMain_InvalidCSRFSecret(t *testing.T) { }, } - err := publicMain(t.Context(), cfg) + ctx, cancel := sessionmanager.NewContext(t.Context()) + defer cancel(nil) + + err := publicMain(ctx, cfg) assert.Error(t, err) assert.Contains(t, err.Error(), "loading csrf token from source ref") } @@ -380,48 +32,12 @@ func TestPublicMain_ShortCSRFSecret(t *testing.T) { }, } - err := publicMain(t.Context(), cfg) - assert.Error(t, err) - assert.Contains(t, err.Error(), "CSRF secret must be at least 32 bytes") -} - -func TestInternalMain_InvalidOIDCConfig(t *testing.T) { - cfg := &config.Config{ - Database: config.Database{ - Host: commoncfg.SourceRef{Source: "file", File: commoncfg.CredentialFile{Path: "/nonexistent/file"}}, - Port: "5432", - Name: "testdb", - User: commoncfg.SourceRef{Source: "embedded", Value: "user"}, - Password: commoncfg.SourceRef{Source: "embedded", Value: "pass"}, - }, - } - - err := internalMain(t.Context(), cfg) - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to create trust service") -} - -func TestInternalMain_InvalidValkeyConfig(t *testing.T) { - cfg := &config.Config{ - Database: config.Database{ - Host: commoncfg.SourceRef{Source: "embedded", Value: "localhost"}, - Port: "5432", - Name: "testdb", - User: commoncfg.SourceRef{Source: "embedded", Value: "user"}, - Password: commoncfg.SourceRef{Source: "embedded", Value: "pass"}, - }, - ValKey: config.ValKey{ - Host: commoncfg.SourceRef{Source: "file", File: commoncfg.CredentialFile{Path: "/nonexistent/file"}}, - User: commoncfg.SourceRef{Source: "embedded", Value: "user"}, - Password: commoncfg.SourceRef{Source: "embedded", Value: "pass"}, - }, - } + ctx, cancel := sessionmanager.NewContext(t.Context()) + defer cancel(nil) - err := internalMain(t.Context(), cfg) + err := publicMain(ctx, cfg) assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to create valkey client") - // Could fail on OIDC (DB connection) or valkey - // Error details depend on which step fails + assert.Contains(t, err.Error(), "CSRF secret must be at least 32 bytes") } func TestMain_InvalidCSRFSecret(t *testing.T) { @@ -440,39 +56,6 @@ func TestMain_PublicServerInvalidCSRF(t *testing.T) { SessionManager: config.SessionManager{ CSRFSecret: commoncfg.SourceRef{Source: "file", File: commoncfg.CredentialFile{Path: "/nonexistent/file"}}, }, - Database: config.Database{ - Host: commoncfg.SourceRef{Source: "embedded", Value: "localhost"}, - Port: "5432", - Name: "testdb", - User: commoncfg.SourceRef{Source: "embedded", Value: "user"}, - Password: commoncfg.SourceRef{Source: "embedded", Value: "pass"}, - }, - } - - err := Main(t.Context(), cfg) - assert.Error(t, err) -} - -func TestMain_InternalServerInvalidDatabase(t *testing.T) { - cfg := &config.Config{ - SessionManager: config.SessionManager{ - CSRFSecret: commoncfg.SourceRef{Source: "embedded", Value: "this-is-a-very-long-secret-that-is-at-least-32-bytes-long"}, - ClientAuth: config.ClientAuth{ - Type: clientAuthTypeInsecure, - }, - }, - Database: config.Database{ - Host: commoncfg.SourceRef{Source: "file", File: commoncfg.CredentialFile{Path: "/nonexistent/file"}}, - Port: "5432", - Name: "testdb", - User: commoncfg.SourceRef{Source: "embedded", Value: "user"}, - Password: commoncfg.SourceRef{Source: "embedded", Value: "pass"}, - }, - ValKey: config.ValKey{ - Host: commoncfg.SourceRef{Source: "embedded", Value: "localhost:6379"}, - User: commoncfg.SourceRef{Source: "embedded", Value: "user"}, - Password: commoncfg.SourceRef{Source: "embedded", Value: "pass"}, - }, } err := Main(t.Context(), cfg) diff --git a/internal/business/housekeeper.go b/internal/business/housekeeper.go index 197fc42a..fde823a5 100644 --- a/internal/business/housekeeper.go +++ b/internal/business/housekeeper.go @@ -7,31 +7,59 @@ import ( slogctx "github.com/veqryn/slog-context" + sessionmanager "github.com/openkcm/session-manager" "github.com/openkcm/session-manager/internal/config" + "github.com/openkcm/session-manager/internal/sessionwiring" ) // HousekeeperMain starts the house keeping jobs func HousekeeperMain(ctx context.Context, cfg *config.Config) error { - sessionManager, closeFn, err := initSessionManager(ctx, cfg) + c, cancelCause := sessionmanager.NewContext(ctx) + defer cancelCause(nil) + + c = config.WithContext(c, cfg) + + _, err := c.LoadModule(&cfg.Database) + if err != nil { + return fmt.Errorf("loading database module: %w", err) + } + + trustMod, err := c.LoadModule(&cfg.Trust) + if err != nil { + return fmt.Errorf("loading trust module: %w", err) + } + + if _, err := c.LoadModule(&cfg.ValKey); err != nil { + return fmt.Errorf("loading session-store module: %w", err) + } + + if _, err := c.LoadModule(&cfg.Credentials); err != nil { + return fmt.Errorf("loading credentials module: %w", err) + } + + //nolint:forcetypeassert + trust := trustMod.(sessionmanager.Trust) + + sessionManager, closeFn, err := sessionwiring.InitSessionManager(c, cfg, trust) if err != nil { return fmt.Errorf("failed to initialise the session manager: %w", err) } defer closeFn() // Start the housekeeper loop - c := time.Tick(cfg.Housekeeper.TriggerInterval) + tick := time.Tick(cfg.Housekeeper.TriggerInterval) refreshTriggerInterval := cfg.Housekeeper.TokenRefreshTriggerInterval concurrencyLimit := cfg.Housekeeper.ConcurrencyLimit for { - err := sessionManager.TriggerHousekeeping(ctx, concurrencyLimit, refreshTriggerInterval) + err := sessionManager.TriggerHousekeeping(c, concurrencyLimit, refreshTriggerInterval) if err != nil { slogctx.Error(ctx, "Error during session housekeeping", "error", err) } select { - case <-c: + case <-tick: continue - case <-ctx.Done(): + case <-c.Done(): return nil } } diff --git a/internal/business/housekeeper_test.go b/internal/business/housekeeper_test.go index 820ca9a4..984837ed 100644 --- a/internal/business/housekeeper_test.go +++ b/internal/business/housekeeper_test.go @@ -10,52 +10,8 @@ import ( "github.com/openkcm/session-manager/internal/config" ) -func TestHousekeeperMain_InvalidDatabaseConfig(t *testing.T) { - cfg := &config.Config{ - Database: config.Database{ - Host: commoncfg.SourceRef{Source: "file", File: commoncfg.CredentialFile{Path: "/nonexistent/file"}}, - Port: "5432", - Name: "testdb", - User: commoncfg.SourceRef{Source: "embedded", Value: "user"}, - Password: commoncfg.SourceRef{Source: "embedded", Value: "pass"}, - }, - } - - err := HousekeeperMain(t.Context(), cfg) - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to initialise the session manager") -} - -func TestHousekeeperMain_InvalidValkeyConfig(t *testing.T) { - cfg := &config.Config{ - Database: config.Database{ - Host: commoncfg.SourceRef{Source: "embedded", Value: "localhost"}, - Port: "5432", - Name: "testdb", - User: commoncfg.SourceRef{Source: "embedded", Value: "user"}, - Password: commoncfg.SourceRef{Source: "embedded", Value: "pass"}, - }, - ValKey: config.ValKey{ - Host: commoncfg.SourceRef{Source: "file", File: commoncfg.CredentialFile{Path: "/nonexistent/file"}}, - User: commoncfg.SourceRef{Source: "embedded", Value: "user"}, - Password: commoncfg.SourceRef{Source: "embedded", Value: "pass"}, - }, - } - - err := HousekeeperMain(t.Context(), cfg) - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to initialise the session manager") -} - func TestHousekeeperMain_CancelledContext(t *testing.T) { cfg := &config.Config{ - Database: config.Database{ - Host: commoncfg.SourceRef{Source: "embedded", Value: "localhost"}, - Port: "5432", - Name: "testdb", - User: commoncfg.SourceRef{Source: "embedded", Value: "user"}, - Password: commoncfg.SourceRef{Source: "embedded", Value: "pass"}, - }, ValKey: config.ValKey{ Host: commoncfg.SourceRef{Source: "embedded", Value: "localhost:6379"}, User: commoncfg.SourceRef{Source: "embedded", Value: "user"}, @@ -63,7 +19,7 @@ func TestHousekeeperMain_CancelledContext(t *testing.T) { }, SessionManager: config.SessionManager{ ClientAuth: config.ClientAuth{ - Type: clientAuthTypeInsecure, + Type: "insecure", }, }, } diff --git a/internal/business/migrate.go b/internal/business/migrate.go index 9c70059d..a03d5271 100644 --- a/internal/business/migrate.go +++ b/internal/business/migrate.go @@ -4,57 +4,41 @@ import ( "context" "fmt" - "github.com/XSAM/otelsql" - "github.com/pressly/goose/v3" - "github.com/samber/oops" - // Register pgx driver _ "github.com/jackc/pgx/v5/stdlib" slogctx "github.com/veqryn/slog-context" - semconv "go.opentelemetry.io/otel/semconv/v1.37.0" + sessionmanager "github.com/openkcm/session-manager" "github.com/openkcm/session-manager/internal/config" - migrations "github.com/openkcm/session-manager/sql" ) // MigrateMain starts the database migration func MigrateMain(ctx context.Context, cfg *config.Config) error { - const dialect = "pgx" - dbSystemName := semconv.DBSystemNamePostgreSQL - - connStr, err := config.MakeConnStr(cfg.Database) - if err != nil { - return fmt.Errorf("making connection string from config: %w", err) - } - - db, err := otelsql.Open(dialect, connStr, otelsql.WithAttributes(dbSystemName)) - if err != nil { - return oops.In("main").Wrapf(err, "opening DB connection") - } - - reg, err := otelsql.RegisterDBStatsMetrics(db, otelsql.WithAttributes(dbSystemName)) - if err != nil { - return fmt.Errorf("registering db stats metrics: %w", err) - } + c, cancel := sessionmanager.NewContext(ctx) + var err error defer func() { - err = reg.Unregister() - if err != nil { - slogctx.Error(ctx, "failed to unregister db stats metrics", "error", err) - } + cancel(err) }() - goose.SetBaseFS(migrations.FS) - - err = goose.SetDialect(dialect) + slogctx.Debug(c, "loading db") + _, err = c.LoadModule(&cfg.Database) if err != nil { - return fmt.Errorf("setting goose dialect: %w", err) + return fmt.Errorf("loading database module: %w", err) } - err = goose.UpContext(ctx, db, ".") + slogctx.Debug(c, "loading migrate") + mod, err := c.LoadModule(&cfg.Migrate) if err != nil { - return fmt.Errorf("applying migrations: %w", err) + return fmt.Errorf("loading migration module: %w", err) + } + + //nolint:forcetypeassert + migrate := mod.(sessionmanager.Migrate) + slogctx.Debug(c, "executing migration") + if err := migrate.Migrate(ctx); err != nil { + return fmt.Errorf("executing migrations: %w", err) } return nil diff --git a/internal/business/migrate_test.go b/internal/business/migrate_test.go deleted file mode 100644 index ea28d0b7..00000000 --- a/internal/business/migrate_test.go +++ /dev/null @@ -1,58 +0,0 @@ -package business - -import ( - "testing" - - "github.com/openkcm/common-sdk/pkg/commoncfg" - "github.com/stretchr/testify/assert" - - "github.com/openkcm/session-manager/internal/config" -) - -func TestMigrateMain_InvalidDatabaseConfig(t *testing.T) { - cfg := &config.Config{ - Database: config.Database{ - Host: commoncfg.SourceRef{Source: "file", File: commoncfg.CredentialFile{Path: "/nonexistent/file"}}, - Port: "5432", - Name: "testdb", - User: commoncfg.SourceRef{Source: "embedded", Value: "user"}, - Password: commoncfg.SourceRef{Source: "embedded", Value: "pass"}, - }, - } - - err := MigrateMain(t.Context(), cfg) - assert.Error(t, err) - assert.Contains(t, err.Error(), "making connection string from config") -} - -func TestMigrateMain_InvalidUserRef(t *testing.T) { - cfg := &config.Config{ - Database: config.Database{ - Host: commoncfg.SourceRef{Source: "embedded", Value: "localhost"}, - Port: "5432", - Name: "testdb", - User: commoncfg.SourceRef{Source: "file", File: commoncfg.CredentialFile{Path: "/nonexistent/file"}}, - Password: commoncfg.SourceRef{Source: "embedded", Value: "pass"}, - }, - } - - err := MigrateMain(t.Context(), cfg) - assert.Error(t, err) - assert.Contains(t, err.Error(), "making connection string from config") -} - -func TestMigrateMain_InvalidPasswordRef(t *testing.T) { - cfg := &config.Config{ - Database: config.Database{ - Host: commoncfg.SourceRef{Source: "embedded", Value: "localhost"}, - Port: "5432", - Name: "testdb", - User: commoncfg.SourceRef{Source: "embedded", Value: "user"}, - Password: commoncfg.SourceRef{Source: "file", File: commoncfg.CredentialFile{Path: "/nonexistent/file"}}, - }, - } - - err := MigrateMain(t.Context(), cfg) - assert.Error(t, err) - assert.Contains(t, err.Error(), "making connection string from config") -} diff --git a/internal/business/server/grpc_server.go b/internal/business/server/grpc_server.go deleted file mode 100644 index 64bfccae..00000000 --- a/internal/business/server/grpc_server.go +++ /dev/null @@ -1,59 +0,0 @@ -package server - -import ( - "context" - "net" - - "github.com/openkcm/common-sdk/pkg/commongrpc" - "github.com/samber/oops" - - oidcmappingv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/sessionmanager/oidcmapping/v1" - sessionv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/sessionmanager/session/v1" - slogctx "github.com/veqryn/slog-context" - - "github.com/openkcm/session-manager/internal/config" - "github.com/openkcm/session-manager/internal/grpc" -) - -func StartGRPCServer(ctx context.Context, cfg *config.Config, - oidcmappingsrv *grpc.OIDCMappingServer, - sessionsrv *grpc.SessionServer, -) error { - grpcServer := commongrpc.NewServer(ctx, &cfg.GRPC.GRPCServer) - - // Register OIDC mapping server for the regional tenant manager - oidcmappingv1.RegisterServiceServer(grpcServer, oidcmappingsrv) - // Register Session server for ExtAuthZ - sessionv1.RegisterServiceServer(grpcServer, sessionsrv) - - slogctx.Info(ctx, "Starting a listener", "address", cfg.GRPC.Address) - - listener, err := new(net.ListenConfig).Listen(ctx, "tcp", cfg.GRPC.Address) - if err != nil { - return oops.In("gRPC Server"). - WithContext(ctx). - Wrapf(err, "creating listener") - } - - slogctx.Info(ctx, "A listener started", "address", listener.Addr().String()) - - go func() { - slogctx.Info(ctx, "Serving a gRPC server", "address", listener.Addr().String()) - err := grpcServer.Serve(listener) - if err != nil { - slogctx.Error(ctx, "Failed to serve gRPC endpoint", "error", err) - } - - slogctx.Info(ctx, "Stopped gRPC server") - }() - - <-ctx.Done() - - shutdownCtx, cancel := context.WithTimeout(ctx, cfg.GRPC.ShutdownTimeout) - defer cancel() - - grpcServer.GracefulStop() - slogctx.Info(shutdownCtx, "Completed graceful shutdown of gRPC server") - - return nil -} diff --git a/internal/business/server/grpc_server_test.go b/internal/business/server/grpc_server_test.go deleted file mode 100644 index 10440854..00000000 --- a/internal/business/server/grpc_server_test.go +++ /dev/null @@ -1,52 +0,0 @@ -package server - -import ( - "context" - "testing" - "time" - - "github.com/openkcm/common-sdk/pkg/commoncfg" - "github.com/stretchr/testify/assert" - - "github.com/openkcm/session-manager/internal/config" - "github.com/openkcm/session-manager/internal/grpc" -) - -func TestStartGRPCServer_ContextCancellation(t *testing.T) { - t.Run("gracefully shuts down when context is cancelled", func(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - - cfg := &config.Config{ - GRPC: config.GRPCServer{ - GRPCServer: commoncfg.GRPCServer{ - Address: "localhost:0", // Use port 0 to get a random available port - }, - ShutdownTimeout: 1 * time.Second, - }, - } - - // Create minimal server instances - oidcmappingsrv := grpc.NewOIDCMappingServer(nil) - sessionsrv := grpc.NewSessionServer(ctx, nil, nil, 0, "") - - // Start the server in a goroutine - errChan := make(chan error, 1) - go func() { - errChan <- StartGRPCServer(ctx, cfg, oidcmappingsrv, sessionsrv) - }() - - // Give the server a moment to start - time.Sleep(100 * time.Millisecond) - - // Cancel the context to trigger shutdown - cancel() - - // Wait for shutdown to complete - select { - case err := <-errChan: - assert.NoError(t, err) - case <-time.After(5 * time.Second): - t.Fatal("Server did not shut down within timeout") - } - }) -} diff --git a/internal/business/server/openapi.go b/internal/business/server/openapi.go index 5bf9752e..4ea4b6db 100644 --- a/internal/business/server/openapi.go +++ b/internal/business/server/openapi.go @@ -16,8 +16,8 @@ import ( "github.com/openkcm/session-manager/internal/middleware" "github.com/openkcm/session-manager/internal/openapi" - "github.com/openkcm/session-manager/internal/serviceerr" "github.com/openkcm/session-manager/internal/session" + "github.com/openkcm/session-manager/pkg/serviceerr" ) // sessionManager defines the interface for session management operations diff --git a/internal/business/server/openapi_test.go b/internal/business/server/openapi_test.go index 59fccdd6..79d94706 100644 --- a/internal/business/server/openapi_test.go +++ b/internal/business/server/openapi_test.go @@ -15,8 +15,8 @@ import ( "github.com/openkcm/session-manager/internal/middleware" "github.com/openkcm/session-manager/internal/openapi" - "github.com/openkcm/session-manager/internal/serviceerr" "github.com/openkcm/session-manager/internal/session" + "github.com/openkcm/session-manager/pkg/serviceerr" ) const ( diff --git a/internal/cmdutils/cmdutils.go b/internal/cmdutils/cmdutils.go index 2f37ef7f..ebcd1864 100644 --- a/internal/cmdutils/cmdutils.go +++ b/internal/cmdutils/cmdutils.go @@ -3,11 +3,9 @@ package cmdutils import ( "context" "fmt" - "log/slog" "syscall" "time" - "github.com/openkcm/common-sdk/pkg/commoncfg" "github.com/openkcm/common-sdk/pkg/health" "github.com/openkcm/common-sdk/pkg/logger" "github.com/openkcm/common-sdk/pkg/otlp" @@ -35,7 +33,11 @@ func CobraCommand( Long: long, SilenceUsage: true, RunE: func(cmd *cobra.Command, _ []string) error { - cfg, err := loadConfig(buildInfo) + cfg, err := config.Load(buildInfo, + "/etc/session-manager/", + "$HOME/.session-manager/", + "./", + ) if err != nil { return fmt.Errorf("loading config: %w", err) } @@ -65,7 +67,7 @@ func run(ctx context.Context, withTelemetry, withStatusServer bool, fn func(cont return oops.In("main"). Wrapf(err, "Failed to initialise the logger") } - slogctx.Debug(ctx, "Starting the application", slog.Any("config", cfg)) + slogctx.Debug(ctx, "Starting the application") // OpenTelemetry if withTelemetry { @@ -95,33 +97,6 @@ func run(ctx context.Context, withTelemetry, withStatusServer bool, fn func(cont return nil } -func loadConfig(buildInfo string) (*config.Config, error) { - defaultValues := map[string]any{} - cfg := &config.Config{} - - err := commoncfg.LoadConfig( - cfg, - defaultValues, - "/etc/session-manager", - "$HOME/.session-manager", - ".", - ) - if err != nil { - return nil, fmt.Errorf("loading configuration: %w", err) - } - - // Update Version - err = commoncfg.UpdateConfigVersion( - &cfg.BaseConfig, - buildInfo, - ) - if err != nil { - return nil, fmt.Errorf("updating the version configuration: %w", err) - } - - return cfg, nil -} - func statusListener(ctx context.Context, state health.State) { subctx := slogctx.With(ctx, "status", state.Status) //nolint:fatcontext @@ -136,11 +111,6 @@ func statusListener(ctx context.Context, state health.State) { } func startStatusServer(ctx context.Context, cfg *config.Config) error { - connStr, err := config.MakeConnStr(cfg.Database) - if err != nil { - return fmt.Errorf("making connection string from config: %w", err) - } - liveness := status.WithLiveness( health.NewHandler( health.NewChecker(health.WithDisabledAutostart()), @@ -150,7 +120,6 @@ func startStatusServer(ctx context.Context, cfg *config.Config) error { healthOptions := []health.Option{ health.WithDisabledAutostart(), health.WithTimeout(healthStatusTimeout), - health.WithDatabaseChecker("pgx", connStr), health.WithStatusListener(statusListener), } @@ -160,8 +129,7 @@ func startStatusServer(ctx context.Context, cfg *config.Config) error { ), ) - err = status.Start(ctx, &cfg.BaseConfig, liveness, readiness) - if err != nil { + if err := status.Start(ctx, &cfg.BaseConfig, liveness, readiness); err != nil { return fmt.Errorf("starting status server: %w", err) } diff --git a/internal/cmdutils/cmdutils_test.go b/internal/cmdutils/cmdutils_test.go index a0cb5a3c..b1f46deb 100644 --- a/internal/cmdutils/cmdutils_test.go +++ b/internal/cmdutils/cmdutils_test.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "testing" - "time" "github.com/openkcm/common-sdk/pkg/health" "github.com/stretchr/testify/assert" @@ -121,29 +120,6 @@ func TestStatusListener(t *testing.T) { }) } -func TestStartStatusServer(t *testing.T) { - t.Run("returns error when connection string creation fails", func(t *testing.T) { - cfg := &config.Config{ - Database: config.Database{ - Name: "", - Port: "", - }, - } - ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) - defer cancel() - - err := startStatusServer(ctx, cfg) - assert.Error(t, err) - assert.Contains(t, err.Error(), "making connection string from config") - }) -} - -func TestHealthStatusTimeout(t *testing.T) { - t.Run("has correct value", func(t *testing.T) { - assert.Equal(t, 5*time.Second, healthStatusTimeout) - }) -} - func ExampleCobraCommand() { businessFunc := func(ctx context.Context, cfg *config.Config) error { fmt.Println("Running business logic") diff --git a/internal/config/config.go b/internal/config/config.go index 4d5cd63a..96690291 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -3,13 +3,19 @@ package config import ( + "fmt" + "reflect" "time" + "github.com/creasty/defaults" + "github.com/knadh/koanf/v2" "github.com/openkcm/common-sdk/pkg/commoncfg" + + sessionmanager "github.com/openkcm/session-manager" ) type Config struct { - commoncfg.BaseConfig `mapstructure:",squash" yaml:",inline"` + commoncfg.BaseConfig `yaml:",inline"` HTTP HTTPServer `yaml:"http"` GRPC GRPCServer `yaml:"grpc"` @@ -19,6 +25,94 @@ type Config struct { Migrate Migrate `yaml:"migrate"` SessionManager SessionManager `yaml:"sessionManager"` Housekeeper Housekeeper `yaml:"housekeeper"` + Trust Trust `yaml:"trust"` + Credentials Credentials `yaml:"credentials"` + + // Apps configures long-running components that satisfy the sessionmanager.App + // interface. The map key is an operator-chosen name. Each entry MUST set + // "module:" to the registered module ID; remaining fields are passed to the + // module via UnmarshalExtension. + Apps map[string]*App `yaml:"apps"` + // AppsOrder optionally overrides the start order of apps. Apps not listed + // here are started in parser-defined order after the listed ones. At + // shutdown, apps are stopped in the reverse of the order in which they were + // successfully started. + AppsOrder []string `yaml:"appsOrder"` +} + +// App is the per-entry configuration under the top-level apps: section. It +// implements sessionmanager.ExtensionConfig so it can be passed to LoadApp. +type App struct { + Mod string `yaml:"module"` + Services []*ServiceCfg `yaml:"services"` + koanf *koanf.Koanf +} + +func (c *App) setKoanf(ko *koanf.Koanf) { + c.koanf = ko +} + +func (c *App) Module() string { + return c.Mod +} + +func (c *App) UnmarshalExtension(into sessionmanager.Module) error { + if c.koanf == nil { + return nil + } + return unmarshalExtension(into, c.koanf) +} + +// ServiceCfg is a per-entry configuration under an App's services: list. It +// implements sessionmanager.ExtensionConfig so an App's Provision can pass it +// to ctx.LoadModule. +type ServiceCfg struct { + Mod string `yaml:"module"` + koanf *koanf.Koanf +} + +func (c *ServiceCfg) setKoanf(ko *koanf.Koanf) { + c.koanf = ko +} + +func (c *ServiceCfg) Module() string { + return c.Mod +} + +func (c *ServiceCfg) UnmarshalExtension(into sessionmanager.Module) error { + if c.koanf == nil { + return nil + } + return unmarshalExtension(into, c.koanf) +} + +type Trust struct { + Mod string `yaml:"module" default:"trust.module.oidc"` + koanf *koanf.Koanf +} + +func (c *Trust) setKoanf(ko *koanf.Koanf) { + c.koanf = ko +} + +func (c *Trust) Module() string { + return c.Mod +} + +func (c *Trust) UnmarshalExtension(into sessionmanager.Module) error { + return unmarshalExtension(into, c.koanf) +} + +func unmarshalExtension(out any, ko *koanf.Koanf) error { + if err := ko.UnmarshalWithConf("", out, koanfUnmarshalConf); err != nil { + return fmt.Errorf("unmarshaling into a structure: %w", err) + } + + setKoanf(reflect.ValueOf(out), ko) + if err := defaults.Set(out); err != nil { + return fmt.Errorf("setting defaults: %w", err) + } + return nil } type Housekeeper struct { @@ -37,25 +131,77 @@ type HTTPServer struct { } type GRPCServer struct { - commoncfg.GRPCServer `mapstructure:",squash" yaml:",inline"` + commoncfg.GRPCServer `yaml:",inline"` ShutdownTimeout time.Duration `yaml:"shutdownTimeout" default:"5s"` } type Database struct { - Name string `yaml:"name"` - Port string `yaml:"port"` - Host commoncfg.SourceRef `yaml:"host"` - User commoncfg.SourceRef `yaml:"user"` - Password commoncfg.SourceRef `yaml:"password"` + Mod string `yaml:"module" default:"database.module.pgxpool"` + + koanf *koanf.Koanf +} + +func (c *Database) setKoanf(ko *koanf.Koanf) { + c.koanf = ko +} + +func (c *Database) Module() string { + return c.Mod +} + +func (c *Database) UnmarshalExtension(into sessionmanager.Module) error { + return unmarshalExtension(into, c.koanf) } type ValKey struct { + Mod string `yaml:"module" default:"sessionstore.module.valkey"` Host commoncfg.SourceRef `yaml:"host"` User commoncfg.SourceRef `yaml:"user"` Password commoncfg.SourceRef `yaml:"password"` Prefix string `yaml:"prefix"` SecretRef commoncfg.SecretRef `yaml:"secretRef"` + + koanf *koanf.Koanf +} + +func (c *ValKey) setKoanf(ko *koanf.Koanf) { + c.koanf = ko +} + +func (c *ValKey) Module() string { + return c.Mod +} + +func (c *ValKey) UnmarshalExtension(into sessionmanager.Module) error { + if c.koanf == nil { + return nil + } + return unmarshalExtension(into, c.koanf) +} + +// Credentials is a thin top-level entry whose only purpose is to make the +// credentials module ID swappable. The actual auth-type/secret/mTLS data +// continues to live under sessionManager.clientAuth and the credentials +// module reads it via config.FromContext. +type Credentials struct { + Mod string `yaml:"module" default:"credentials.module.oauth2"` + koanf *koanf.Koanf +} + +func (c *Credentials) setKoanf(ko *koanf.Koanf) { + c.koanf = ko +} + +func (c *Credentials) Module() string { + return c.Mod +} + +func (c *Credentials) UnmarshalExtension(into sessionmanager.Module) error { + if c.koanf == nil { + return nil + } + return unmarshalExtension(into, c.koanf) } type SessionManager struct { @@ -63,15 +209,10 @@ type SessionManager struct { SessionDuration time.Duration `yaml:"sessionDuration" default:"12h"` // CallbackURL is the URL path for the OAuth2 callback endpoint, where we receive the authorization code. - CallbackURL string `yaml:"callbackURL" default:"/sm/callback"` - ClientAuth ClientAuth `yaml:"clientAuth"` - CSRFSecret commoncfg.SourceRef `yaml:"csrfSecret"` - CSRFSecretParsed []byte `yaml:"-"` - AdditionalQueryParametersAuthorize []string `yaml:"additionalQueryParametersAuthorize"` - AdditionalQueryParametersToken []string `yaml:"additionalQueryParametersToken"` - AdditionalQueryParametersIntrospect []string `yaml:"additionalQueryParametersIntrospect"` - AdditionalQueryParametersLogout []string `yaml:"additionalQueryParametersLogout"` - AdditionalAuthContextKeys []string `yaml:"additionalAuthContextKeys"` + CallbackURL string `yaml:"callbackURL" default:"/sm/callback"` + ClientAuth ClientAuth `yaml:"clientAuth"` + CSRFSecret commoncfg.SourceRef `yaml:"csrfSecret"` + CSRFSecretParsed []byte `yaml:"-"` // SessionCookieTemplate defines the template attributes for the session cookie. SessionCookieTemplate CookieTemplate `yaml:"sessionCookieTemplate"` // CSRFCookieTemplate defines the template attributes for the CSRF cookie. @@ -83,15 +224,6 @@ type SessionManager struct { // during the authorization flow and post logout. This is used to validate the redirect // URLs provided in the authorization request and post logout requests. AllowedRedirectBaseURLs []string `yaml:"allowedRedirectBaseURLs"` - - // Deprecated: use AllowedRedirectBaseURLs instead. - PostLogoutRedirectURL string `yaml:"postLogoutRedirectURL"` - // Deprecated: not used anymore. Kept for a helm issue with the migrate job. - RedirectURL string `yaml:"redirectURL" default:"/sm/redirect"` - // Deprecated: use AdditionalQueryParametersAuthorize instead. - AdditionalGetParametersAuthorize []string `yaml:"additionalGetParametersAuthorize"` - // Deprecated: use AdditionalQueryParametersToken instead. - AdditionalGetParametersToken []string `yaml:"additionalGetParametersToken"` } type CookieSameSiteValue string @@ -126,5 +258,18 @@ type ClientAuth struct { } type Migrate struct { - Source string `yaml:"source" default:"file://./sql"` + Mod string `yaml:"module" default:"trust.migration.module.oidc"` + koanf *koanf.Koanf +} + +func (c *Migrate) setKoanf(ko *koanf.Koanf) { + c.koanf = ko +} + +func (c *Migrate) Module() string { + return c.Mod +} + +func (c *Migrate) UnmarshalExtension(into sessionmanager.Module) error { + return unmarshalExtension(into, c.koanf) } diff --git a/internal/config/config_test.go b/internal/config/config_test.go new file mode 100644 index 00000000..a265d5ad --- /dev/null +++ b/internal/config/config_test.go @@ -0,0 +1,158 @@ +package config + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + sessionmanager "github.com/openkcm/session-manager" +) + +// fakeAppModule is used to verify that App.UnmarshalExtension routes +// per-app YAML fields into the target module struct. +type fakeAppModule struct { + TriggerInterval string `yaml:"triggerInterval"` + Endpoint string `yaml:"endpoint"` +} + +func (*fakeAppModule) Module() sessionmanager.ModuleInfo { + return sessionmanager.ModuleInfo{ID: "config_test.fake.app"} +} + +func TestLoad_AppsSection(t *testing.T) { + yaml := ` +apps: + housekeeper: + module: app.module.housekeeper + triggerInterval: 5m + audit-shipper: + module: app.module.audit-shipper + endpoint: https://example.invalid/audit +appsOrder: + - housekeeper + - audit-shipper +` + + dir := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(dir, configFile), []byte(yaml), 0o600)) + + cfg, err := Load("", dir) + require.NoError(t, err) + + require.Len(t, cfg.Apps, 2) + require.Contains(t, cfg.Apps, "housekeeper") + require.Contains(t, cfg.Apps, "audit-shipper") + + assert.Equal(t, "app.module.housekeeper", cfg.Apps["housekeeper"].Module()) + assert.Equal(t, "app.module.audit-shipper", cfg.Apps["audit-shipper"].Module()) + assert.Equal(t, []string{"housekeeper", "audit-shipper"}, cfg.AppsOrder) + + // Per-app fields must be reachable via UnmarshalExtension into the target + // module type — confirms the koanf subtree is wired up per entry. + hk := &fakeAppModule{} + require.NoError(t, cfg.Apps["housekeeper"].UnmarshalExtension(hk)) + assert.Equal(t, "5m", hk.TriggerInterval) + + as := &fakeAppModule{} + require.NoError(t, cfg.Apps["audit-shipper"].UnmarshalExtension(as)) + assert.Equal(t, "https://example.invalid/audit", as.Endpoint) +} + +func TestLoad_AppsAbsentIsNoop(t *testing.T) { + dir := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(dir, configFile), []byte("# empty\n"), 0o600)) + + cfg, err := Load("", dir) + require.NoError(t, err) + assert.Empty(t, cfg.Apps) + assert.Empty(t, cfg.AppsOrder) +} + +// fakeServiceModule mirrors the per-service YAML fields a gRPC service module +// would expose. Used to assert per-entry koanf subtree wiring under +// apps[].services[]. +type fakeServiceModule struct { + Trust string `yaml:"trust"` + AllowHttpScheme bool `yaml:"allowHttpScheme"` +} + +func (*fakeServiceModule) Module() sessionmanager.ModuleInfo { + return sessionmanager.ModuleInfo{ID: "config_test.fake.service"} +} + +func TestLoad_ValkeyDefaultModule(t *testing.T) { + dir := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(dir, configFile), []byte("valkey:\n prefix: foo\n"), 0o600)) + + cfg, err := Load("", dir) + require.NoError(t, err) + assert.Equal(t, "sessionstore.module.valkey", cfg.ValKey.Module()) + assert.Equal(t, "foo", cfg.ValKey.Prefix) +} + +func TestLoad_ValkeyCustomModule(t *testing.T) { + dir := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(dir, configFile), []byte("valkey:\n module: my.custom.sessionstore\n"), 0o600)) + + cfg, err := Load("", dir) + require.NoError(t, err) + assert.Equal(t, "my.custom.sessionstore", cfg.ValKey.Module()) +} + +func TestLoad_CredentialsDefaultModule(t *testing.T) { + dir := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(dir, configFile), []byte("# empty\n"), 0o600)) + + cfg, err := Load("", dir) + require.NoError(t, err) + assert.Equal(t, "credentials.module.oauth2", cfg.Credentials.Module()) +} + +func TestLoad_CredentialsCustomModule(t *testing.T) { + dir := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(dir, configFile), []byte("credentials:\n module: my.custom.credentials\n"), 0o600)) + + cfg, err := Load("", dir) + require.NoError(t, err) + assert.Equal(t, "my.custom.credentials", cfg.Credentials.Module()) +} + +func TestLoad_AppServicesPerEntryKoanf(t *testing.T) { + yaml := ` +apps: + grpc: + module: app.module.grpcserver + services: + - module: service.module.grpc.session + trust: trust.module.oidc + allowHttpScheme: true + - module: service.module.grpc.trustmapping + trust: trust.module.alt +` + + dir := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(dir, configFile), []byte(yaml), 0o600)) + + cfg, err := Load("", dir) + require.NoError(t, err) + + require.Contains(t, cfg.Apps, "grpc") + app := cfg.Apps["grpc"] + assert.Equal(t, "app.module.grpcserver", app.Module()) + require.Len(t, app.Services, 2) + + assert.Equal(t, "service.module.grpc.session", app.Services[0].Module()) + assert.Equal(t, "service.module.grpc.trustmapping", app.Services[1].Module()) + + svc0 := &fakeServiceModule{} + require.NoError(t, app.Services[0].UnmarshalExtension(svc0)) + assert.Equal(t, "trust.module.oidc", svc0.Trust) + assert.True(t, svc0.AllowHttpScheme) + + svc1 := &fakeServiceModule{} + require.NoError(t, app.Services[1].UnmarshalExtension(svc1)) + assert.Equal(t, "trust.module.alt", svc1.Trust) +} diff --git a/internal/config/connstr.go b/internal/config/connstr.go deleted file mode 100644 index 5e4f496c..00000000 --- a/internal/config/connstr.go +++ /dev/null @@ -1,27 +0,0 @@ -package config - -import ( - "fmt" - - "github.com/openkcm/common-sdk/pkg/commoncfg" -) - -func MakeConnStr(conf Database) (string, error) { - host, err := commoncfg.LoadValueFromSourceRef(conf.Host) - if err != nil { - return "", fmt.Errorf("loading db host: %w", err) - } - - user, err := commoncfg.LoadValueFromSourceRef(conf.User) - if err != nil { - return "", fmt.Errorf("loading db user: %w", err) - } - - password, err := commoncfg.LoadValueFromSourceRef(conf.Password) - if err != nil { - return "", fmt.Errorf("loading db password: %w", err) - } - - return fmt.Sprintf("host=%s user=%s password=%s dbname=%s port=%s", - host, user, string(password), conf.Name, conf.Port), nil -} diff --git a/internal/config/connstr_test.go b/internal/config/connstr_test.go deleted file mode 100644 index a22dec5e..00000000 --- a/internal/config/connstr_test.go +++ /dev/null @@ -1,113 +0,0 @@ -package config - -import ( - "fmt" - "testing" - - "github.com/openkcm/common-sdk/pkg/commoncfg" - "github.com/stretchr/testify/assert" -) - -func TestMakeConnStr(t *testing.T) { - tests := []struct { - name string - conf Database - wantConnStr string - assertErr assert.ErrorAssertionFunc - }{ - { - name: "Make connection string", - conf: Database{ - Host: commoncfg.SourceRef{ - Source: "embedded", - Value: "my_host", - }, - User: commoncfg.SourceRef{ - Source: "embedded", - Value: "my_user", - }, - Password: commoncfg.SourceRef{ - Source: "embedded", - Value: "my_password", - }, - Name: "my_db_name", - Port: "5432", - }, - wantConnStr: "host=my_host user=my_user password=my_password dbname=my_db_name port=5432", - assertErr: assert.NoError, - }, - { - name: "Error - invalid host source", - conf: Database{ - Host: commoncfg.SourceRef{ - Source: "invalid-source", - Value: "my_host", - }, - User: commoncfg.SourceRef{ - Source: "embedded", - Value: "my_user", - }, - Password: commoncfg.SourceRef{ - Source: "embedded", - Value: "my_password", - }, - Name: "my_db_name", - Port: "5432", - }, - wantConnStr: "", - assertErr: assert.Error, - }, - { - name: "Error - invalid user source", - conf: Database{ - Host: commoncfg.SourceRef{ - Source: "embedded", - Value: "my_host", - }, - User: commoncfg.SourceRef{ - Source: "invalid-source", - Value: "my_user", - }, - Password: commoncfg.SourceRef{ - Source: "embedded", - Value: "my_password", - }, - Name: "my_db_name", - Port: "5432", - }, - wantConnStr: "", - assertErr: assert.Error, - }, - { - name: "Error - invalid password source", - conf: Database{ - Host: commoncfg.SourceRef{ - Source: "embedded", - Value: "my_host", - }, - User: commoncfg.SourceRef{ - Source: "embedded", - Value: "my_user", - }, - Password: commoncfg.SourceRef{ - Source: "invalid-source", - Value: "my_password", - }, - Name: "my_db_name", - Port: "5432", - }, - wantConnStr: "", - assertErr: assert.Error, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - connStr, err := MakeConnStr(tt.conf) - if !tt.assertErr(t, err, fmt.Sprintf("MakeConnStr() error = %v", err)) || err != nil { - return - } - - assert.Equal(t, tt.wantConnStr, connStr, "MakeConnStr() = %v", connStr) - }) - } -} diff --git a/internal/config/context.go b/internal/config/context.go new file mode 100644 index 00000000..b86e7f7e --- /dev/null +++ b/internal/config/context.go @@ -0,0 +1,26 @@ +package config + +import ( + "context" + + sessionmanager "github.com/openkcm/session-manager" +) + +// configCtxKey is a private type so config attached via WithContext can only +// be retrieved by callers that share this package. +type configCtxKey struct{} + +// WithContext returns a sessionmanager.Context that carries cfg. Apps' +// Provision methods retrieve it via FromContext when they need top-level +// configuration that is not part of their own per-app config block (e.g. +// valkey credentials, audit endpoints). +func WithContext(ctx *sessionmanager.Context, cfg *Config) *sessionmanager.Context { + return ctx.WithValue(configCtxKey{}, cfg) +} + +// FromContext returns the *Config previously attached via WithContext. The +// boolean is false when no config has been attached. +func FromContext(ctx context.Context) (*Config, bool) { + cfg, ok := ctx.Value(configCtxKey{}).(*Config) + return cfg, ok +} diff --git a/internal/config/load.go b/internal/config/load.go new file mode 100644 index 00000000..0d7692e9 --- /dev/null +++ b/internal/config/load.go @@ -0,0 +1,134 @@ +package config + +import ( + "errors" + "fmt" + "os" + "path/filepath" + "reflect" + "strings" + "unicode" + + "github.com/creasty/defaults" + "github.com/go-viper/mapstructure/v2" + "github.com/knadh/koanf/providers/file" + "github.com/knadh/koanf/v2" + "github.com/openkcm/common-sdk/pkg/commoncfg" +) + +type koanfSetter interface { + setKoanf(ko *koanf.Koanf) +} + +const configFile = "config.yaml" + +var koanfUnmarshalConf = koanf.UnmarshalConf{ + Tag: "yaml", + DecoderConfig: &mapstructure.DecoderConfig{ + DecodeHook: mapstructure.ComposeDecodeHookFunc( + mapstructure.StringToTimeDurationHookFunc(), + mapstructure.TextUnmarshallerHookFunc()), + Metadata: nil, + WeaklyTypedInput: true, + SquashTagOption: "inline", + }, +} + +func Load(buildInfo string, paths ...string) (*Config, error) { + for i, path := range paths { + paths[i] = filepath.Join(path, configFile) + } + + ko := koanf.New(".") + var loaded bool + for _, path := range paths { + if err := ko.Load(file.Provider(path), yamlParser{}); err != nil { + if !errors.Is(err, os.ErrNotExist) { + return nil, fmt.Errorf("loading configuration from file %s: %w", path, err) + } + } else { + loaded = true + } + } + + if !loaded { + return nil, fmt.Errorf("no config file found at the paths %q: %w", strings.Join(paths, ", "), os.ErrNotExist) + } + + cfg := new(Config) + if err := ko.UnmarshalWithConf("", cfg, koanfUnmarshalConf); err != nil { + return nil, fmt.Errorf("unmarshaling configuration: %w", err) + } + + if err := defaults.Set(cfg); err != nil { + return nil, fmt.Errorf("setting defaults: %w", err) + } + + if buildInfo != "" { + if err := commoncfg.UpdateConfigVersion( + &cfg.BaseConfig, + buildInfo, + ); err != nil { + return nil, fmt.Errorf("updating the version configuration: %w", err) + } + } + + setKoanf(reflect.ValueOf(cfg), ko) + + return cfg, nil +} + +var koanfSetterType = reflect.TypeFor[koanfSetter]() + +func setKoanf(v reflect.Value, ko *koanf.Koanf) { + if v.Type().Implements(koanfSetterType) { + //nolint:forcetypeassert // Checked above + v.Interface().(koanfSetter).setKoanf(ko) + } + + elem := reflect.Indirect(v) + switch elem.Kind() { + case reflect.Struct: + for field, val := range elem.Fields() { + name, _, _ := strings.Cut(field.Tag.Get(koanfUnmarshalConf.Tag), ",") + if name == "" { + runes := []rune(field.Name) + runes[0] = unicode.ToLower(runes[0]) + name = string(runes) + } + + if val.Kind() != reflect.Pointer { + val = val.Addr() + } + + // Slice-of-pointer fields need per-entry sub-koanfs that koanf can + // only construct from the parent path via Slices(); a positional + // Cut() does not work for array indices. + indirect := reflect.Indirect(val) + if indirect.Kind() == reflect.Slice && indirect.Type().Elem().Kind() == reflect.Pointer { + subs := ko.Slices(name) + n := min(indirect.Len(), len(subs)) + for i := range n { + setKoanf(indirect.Index(i), subs[i]) + } + continue + } + + setKoanf(val, ko.Cut(name)) + } + case reflect.Map: + // Recurse into string-keyed map values so each entry receives its own + // koanf subtree. Map values returned by MapRange are not addressable, + // so we only descend when the value type is already a pointer. + if elem.Type().Key().Kind() != reflect.String { + return + } + if elem.Type().Elem().Kind() != reflect.Pointer { + return + } + iter := elem.MapRange() + for iter.Next() { + setKoanf(iter.Value(), ko.Cut(iter.Key().String())) + } + } +} diff --git a/internal/config/parser.go b/internal/config/parser.go new file mode 100644 index 00000000..d3d56cf3 --- /dev/null +++ b/internal/config/parser.go @@ -0,0 +1,23 @@ +package config + +import ( + "github.com/goccy/go-yaml" +) + +// yamlParser implements a yamlParser parser. +type yamlParser struct{} + +// Unmarshal parses the given YAML bytes. +func (yamlParser) Unmarshal(b []byte) (map[string]any, error) { + var out map[string]any + if err := yaml.UnmarshalWithOptions(b, &out); err != nil { + return nil, err + } + + return out, nil +} + +// Marshal marshals the given config map to YAML bytes. +func (yamlParser) Marshal(o map[string]any) ([]byte, error) { + return yaml.Marshal(o) +} diff --git a/internal/dbtest/postgrestest/postgres.go b/internal/dbtest/postgrestest/postgres.go index 5c836f59..6d6e6263 100644 --- a/internal/dbtest/postgrestest/postgres.go +++ b/internal/dbtest/postgrestest/postgres.go @@ -17,7 +17,7 @@ import ( slogctx "github.com/veqryn/slog-context" - migrations "github.com/openkcm/session-manager/sql" + "github.com/openkcm/session-manager/modules/oidctrust/migrations" ) const ( @@ -112,9 +112,9 @@ func prepareDB(ctx context.Context, dbPool *pgxpool.Pool, port network.Port) { migrateDB(ctx, port) b := new(pgx.Batch) - b.Queue(`INSERT INTO trust (tenant_id, blocked, issuer, jwks_uri, audiences, properties) VALUES ('tenant1-id', false, 'url-one', '', '{}', '{}');`) - b.Queue(`INSERT INTO trust (tenant_id, blocked, issuer, jwks_uri, audiences, properties) VALUES ('tenant2-id', false, 'url-two', '', '{}', '{}');`) - b.Queue(`INSERT INTO trust (tenant_id, blocked, issuer, jwks_uri, audiences, properties) VALUES ('tenant3-id', false, 'url-three', '', '{}', '{}');`) + b.Queue(`INSERT INTO trust (tenant_id, blocked, issuer, jwks_uri, audiences) VALUES ('tenant1-id', false, 'url-one', '', '{}');`) + b.Queue(`INSERT INTO trust (tenant_id, blocked, issuer, jwks_uri, audiences) VALUES ('tenant2-id', false, 'url-two', '', '{}');`) + b.Queue(`INSERT INTO trust (tenant_id, blocked, issuer, jwks_uri, audiences) VALUES ('tenant3-id', false, 'url-three', '', '{}');`) res := dbPool.SendBatch(ctx, b) err := res.Close() diff --git a/internal/grpc/oidcmapping.go b/internal/grpc/oidcmapping.go deleted file mode 100644 index 55221284..00000000 --- a/internal/grpc/oidcmapping.go +++ /dev/null @@ -1,129 +0,0 @@ -package grpc - -import ( - "context" - "errors" - - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" - - oidcmappingv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/sessionmanager/oidcmapping/v1" - slogctx "github.com/veqryn/slog-context" - - "github.com/openkcm/session-manager/internal/serviceerr" - "github.com/openkcm/session-manager/internal/trust" -) - -type OIDCMappingServer struct { - oidcmappingv1.UnimplementedServiceServer - - oidc *trust.Service -} - -func NewOIDCMappingServer(oidc *trust.Service) *OIDCMappingServer { - srv := &OIDCMappingServer{ - oidc: oidc, - } - - return srv -} - -func (srv *OIDCMappingServer) ApplyOIDCMapping(ctx context.Context, req *oidcmappingv1.ApplyOIDCMappingRequest) (*oidcmappingv1.ApplyOIDCMappingResponse, error) { - ctx = slogctx.With(ctx, - "tenantId", req.GetTenantId(), - "issuer", req.GetIssuer(), - "jwksUri", req.GetJwksUri(), - "audiences", req.GetAudiences(), - "properties", req.GetProperties(), - "client_id", req.GetClientId(), - ) - slogctx.Debug(ctx, "ApplyOIDCMapping called") - - response := &oidcmappingv1.ApplyOIDCMappingResponse{} - - mapping := trust.OIDCMapping{ - IssuerURL: req.GetIssuer(), - Blocked: false, - JWKSURI: req.GetJwksUri(), - Audiences: req.GetAudiences(), - Properties: req.GetProperties(), - ClientID: req.GetClientId(), - } - err := srv.oidc.ApplyMapping(ctx, req.GetTenantId(), mapping) - if err != nil { - slogctx.Error(ctx, "Could not apply OIDC mapping", "error", err) - if errors.Is(err, serviceerr.ErrNotFound) { - msg := serviceerr.ErrNotFound.Error() - response.Message = &msg - return response, nil - } - - return nil, status.Errorf(codes.Internal, "failed to apply OIDC mapping: %v", err) - } - - response.Success = true - - return response, nil -} - -// BlockOIDCMapping blocks the OIDC mapping for the specified tenant. -// It calls the underlying service to set the mapping as blocked. -// Returns a response containing an optional error message if blocking fails. -func (srv *OIDCMappingServer) BlockOIDCMapping(ctx context.Context, req *oidcmappingv1.BlockOIDCMappingRequest) (*oidcmappingv1.BlockOIDCMappingResponse, error) { - ctx = slogctx.With(ctx, "tenantId", req.GetTenantId()) - slogctx.Debug(ctx, "BlockOIDCMapping called") - - resp := &oidcmappingv1.BlockOIDCMappingResponse{} - err := srv.oidc.BlockMapping(ctx, req.GetTenantId()) - if err != nil { - slogctx.Error(ctx, "Could not block OIDC mapping", "error", err) - msg := err.Error() - resp.Message = &msg - return resp, status.Error(codes.Internal, "failed to block OIDC mapping: "+msg) - } - resp.Success = true - return resp, nil -} - -// RemoveOIDCMapping removes the OIDC configuration for the tenant. -// It calls the underlying service to remove the mapping. -// Returns a respose containing an optional error message if removing fails. -func (srv *OIDCMappingServer) RemoveOIDCMapping(ctx context.Context, req *oidcmappingv1.RemoveOIDCMappingRequest) (*oidcmappingv1.RemoveOIDCMappingResponse, error) { - ctx = slogctx.With(ctx, "tenantId", req.GetTenantId()) - slogctx.Debug(ctx, "RemoveOIDCMapping called") - - resp := &oidcmappingv1.RemoveOIDCMappingResponse{} - err := srv.oidc.RemoveMapping(ctx, req.GetTenantId()) - if err != nil { - if !errors.Is(err, serviceerr.ErrNotFound) { - slogctx.Error(ctx, "Could not remove OIDC mapping", "error", err) - msg := err.Error() - resp.Message = &msg - return resp, status.Error(codes.Internal, "failed to remove OIDC mapping: "+msg) - } else { - slogctx.Warn(ctx, "RemoveOIDCMapping is called but the tenant does not exist", "error", err) - } - } - - resp.Success = true - return resp, nil -} - -// UnblockOIDCMapping unblocks the OIDC mapping for the specified tenant. -// It calls the underlying service to set the mapping as unblocked. -// Returns a response containing an optional error message if unblocking fails. -func (srv *OIDCMappingServer) UnblockOIDCMapping(ctx context.Context, req *oidcmappingv1.UnblockOIDCMappingRequest) (*oidcmappingv1.UnblockOIDCMappingResponse, error) { - ctx = slogctx.With(ctx, "tenantId", req.GetTenantId()) - slogctx.Debug(ctx, "UnblockOIDCMapping called") - - resp := &oidcmappingv1.UnblockOIDCMappingResponse{} - err := srv.oidc.UnblockMapping(ctx, req.GetTenantId()) - if err != nil { - slogctx.Error(ctx, "Could not unblock OIDC mapping", "error", err) - msg := err.Error() - resp.Message = &msg - return resp, status.Error(codes.Internal, "failed to unblock OIDC mapping: "+msg) - } - resp.Success = true - return resp, nil -} diff --git a/internal/grpc/oidcmapping_test.go b/internal/grpc/oidcmapping_test.go deleted file mode 100644 index 703422b9..00000000 --- a/internal/grpc/oidcmapping_test.go +++ /dev/null @@ -1,421 +0,0 @@ -package grpc_test - -import ( - "errors" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" - - oidcmappingv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/sessionmanager/oidcmapping/v1" - - "github.com/openkcm/session-manager/internal/grpc" - "github.com/openkcm/session-manager/internal/serviceerr" - "github.com/openkcm/session-manager/internal/trust" - "github.com/openkcm/session-manager/internal/trust/trustmock" -) - -func TestNewOIDCMappingServer(t *testing.T) { - t.Run("creates server successfully", func(t *testing.T) { - repo := trustmock.NewInMemRepository() - svc := trust.NewService(repo) - server := grpc.NewOIDCMappingServer(svc) - - assert.NotNil(t, server) - }) -} - -func TestApplyOIDCMapping(t *testing.T) { - ctx := t.Context() - - t.Run("success - creates new mapping", func(t *testing.T) { - repo := trustmock.NewInMemRepository() - svc := trust.NewService(repo) - server := grpc.NewOIDCMappingServer(svc) - - jwksUri := "https://issuer.example.com/.well-known/jwks.json" - req := &oidcmappingv1.ApplyOIDCMappingRequest{ - TenantId: "tenant-123", - Issuer: "https://issuer.example.com", - JwksUri: &jwksUri, - Audiences: []string{"audience1", "audience2"}, - Properties: map[string]string{ - "prop1": "value1", - "prop2": "value2", - }, - } - - resp, err := server.ApplyOIDCMapping(ctx, req) - - require.NoError(t, err) - assert.NotNil(t, resp) - assert.True(t, resp.GetSuccess()) - assert.Empty(t, resp.GetMessage()) - }) - - t.Run("success - updates existing mapping", func(t *testing.T) { - existingMapping := trust.OIDCMapping{ - IssuerURL: "https://old-issuer.example.com", - JWKSURI: "https://old-issuer.example.com/jwks.json", - Audiences: []string{"old-audience"}, - } - repo := trustmock.NewInMemRepository( - trustmock.WithTrust("tenant-123", existingMapping), - ) - svc := trust.NewService(repo) - server := grpc.NewOIDCMappingServer(svc) - - jwksUri := "https://new-issuer.example.com/jwks.json" - req := &oidcmappingv1.ApplyOIDCMappingRequest{ - TenantId: "tenant-123", - Issuer: "https://new-issuer.example.com", - JwksUri: &jwksUri, - Audiences: []string{"new-audience"}, - } - - resp, err := server.ApplyOIDCMapping(ctx, req) - - require.NoError(t, err) - assert.NotNil(t, resp) - assert.True(t, resp.GetSuccess()) - }) - - t.Run("not found error - returns response with message", func(t *testing.T) { - repo := trustmock.NewInMemRepository( - trustmock.WithCreateError(serviceerr.ErrNotFound), - ) - svc := trust.NewService(repo) - server := grpc.NewOIDCMappingServer(svc) - - jwksUri := "https://issuer.example.com/jwks.json" - req := &oidcmappingv1.ApplyOIDCMappingRequest{ - TenantId: "tenant-123", - Issuer: "https://issuer.example.com", - JwksUri: &jwksUri, - } - - resp, err := server.ApplyOIDCMapping(ctx, req) - - require.NoError(t, err) - assert.NotNil(t, resp) - assert.False(t, resp.GetSuccess()) - require.NotNil(t, resp.GetMessage()) - assert.Equal(t, serviceerr.ErrNotFound.Error(), resp.GetMessage()) - }) - - t.Run("internal error - returns grpc error", func(t *testing.T) { - internalErr := errors.New("database connection failed") - repo := trustmock.NewInMemRepository( - trustmock.WithCreateError(internalErr), - ) - svc := trust.NewService(repo) - server := grpc.NewOIDCMappingServer(svc) - - jwksUri := "https://issuer.example.com/jwks.json" - req := &oidcmappingv1.ApplyOIDCMappingRequest{ - TenantId: "tenant-123", - Issuer: "https://issuer.example.com", - JwksUri: &jwksUri, - } - - resp, err := server.ApplyOIDCMapping(ctx, req) - - assert.Nil(t, resp) - require.Error(t, err) - - st, ok := status.FromError(err) - require.True(t, ok) - assert.Equal(t, codes.Internal, st.Code()) - assert.Contains(t, st.Message(), "failed to apply OIDC mapping") - }) - - t.Run("update error - returns grpc error", func(t *testing.T) { - existingMapping := trust.OIDCMapping{ - IssuerURL: "https://issuer.example.com", - } - updateErr := errors.New("update failed") - repo := trustmock.NewInMemRepository( - trustmock.WithTrust("tenant-123", existingMapping), - trustmock.WithUpdateError(updateErr), - ) - svc := trust.NewService(repo) - server := grpc.NewOIDCMappingServer(svc) - - jwksUri := "https://new-issuer.example.com/jwks.json" - req := &oidcmappingv1.ApplyOIDCMappingRequest{ - TenantId: "tenant-123", - Issuer: "https://new-issuer.example.com", - JwksUri: &jwksUri, - } - - resp, err := server.ApplyOIDCMapping(ctx, req) - - assert.Nil(t, resp) - require.Error(t, err) - - st, ok := status.FromError(err) - require.True(t, ok) - assert.Equal(t, codes.Internal, st.Code()) - }) -} - -func TestBlockOIDCMapping(t *testing.T) { - ctx := t.Context() - - t.Run("success - blocks existing mapping", func(t *testing.T) { - existingMapping := trust.OIDCMapping{ - IssuerURL: "https://issuer.example.com", - Blocked: false, - } - repo := trustmock.NewInMemRepository( - trustmock.WithTrust("tenant-123", existingMapping), - ) - svc := trust.NewService(repo) - server := grpc.NewOIDCMappingServer(svc) - - req := &oidcmappingv1.BlockOIDCMappingRequest{ - TenantId: "tenant-123", - } - - resp, err := server.BlockOIDCMapping(ctx, req) - - require.NoError(t, err) - assert.NotNil(t, resp) - assert.True(t, resp.GetSuccess()) - assert.Empty(t, resp.GetMessage()) - }) - - t.Run("success - already blocked", func(t *testing.T) { - existingMapping := trust.OIDCMapping{ - IssuerURL: "https://issuer.example.com", - Blocked: true, - } - repo := trustmock.NewInMemRepository( - trustmock.WithTrust("tenant-123", existingMapping), - ) - svc := trust.NewService(repo) - server := grpc.NewOIDCMappingServer(svc) - - req := &oidcmappingv1.BlockOIDCMappingRequest{ - TenantId: "tenant-123", - } - - resp, err := server.BlockOIDCMapping(ctx, req) - - require.NoError(t, err) - assert.NotNil(t, resp) - assert.True(t, resp.GetSuccess()) - }) - - t.Run("not found - returns success", func(t *testing.T) { - repo := trustmock.NewInMemRepository( - trustmock.WithGetError(serviceerr.ErrNotFound), - ) - svc := trust.NewService(repo) - server := grpc.NewOIDCMappingServer(svc) - - req := &oidcmappingv1.BlockOIDCMappingRequest{ - TenantId: "tenant-123", - } - - resp, err := server.BlockOIDCMapping(ctx, req) - - require.NoError(t, err) - assert.NotNil(t, resp) - assert.True(t, resp.GetSuccess()) - }) - - t.Run("error - returns grpc error with message", func(t *testing.T) { - internalErr := errors.New("database error") - repo := trustmock.NewInMemRepository( - trustmock.WithGetError(internalErr), - ) - svc := trust.NewService(repo) - server := grpc.NewOIDCMappingServer(svc) - - req := &oidcmappingv1.BlockOIDCMappingRequest{ - TenantId: "tenant-123", - } - - resp, err := server.BlockOIDCMapping(ctx, req) - - require.Error(t, err) - assert.NotNil(t, resp) - require.NotNil(t, resp.GetMessage()) - assert.Contains(t, resp.GetMessage(), "database error") - - st, ok := status.FromError(err) - require.True(t, ok) - assert.Equal(t, codes.Internal, st.Code()) - assert.Contains(t, st.Message(), "failed to block OIDC mapping") - }) -} - -func TestRemoveOIDCMapping(t *testing.T) { - ctx := t.Context() - - t.Run("success - removes existing mapping", func(t *testing.T) { - existingMapping := trust.OIDCMapping{ - IssuerURL: "https://issuer.example.com", - } - repo := trustmock.NewInMemRepository( - trustmock.WithTrust("tenant-123", existingMapping), - ) - svc := trust.NewService(repo) - server := grpc.NewOIDCMappingServer(svc) - - req := &oidcmappingv1.RemoveOIDCMappingRequest{ - TenantId: "tenant-123", - } - - resp, err := server.RemoveOIDCMapping(ctx, req) - - require.NoError(t, err) - assert.NotNil(t, resp) - assert.True(t, resp.GetSuccess()) - assert.Empty(t, resp.GetMessage()) - }) - - t.Run("error - returns grpc error with message", func(t *testing.T) { - deleteErr := errors.New("delete failed") - repo := trustmock.NewInMemRepository( - trustmock.WithDeleteError(deleteErr), - ) - svc := trust.NewService(repo) - server := grpc.NewOIDCMappingServer(svc) - - req := &oidcmappingv1.RemoveOIDCMappingRequest{ - TenantId: "tenant-123", - } - - resp, err := server.RemoveOIDCMapping(ctx, req) - - require.Error(t, err) - assert.NotNil(t, resp) - require.NotNil(t, resp.GetMessage()) - assert.Contains(t, resp.GetMessage(), "delete failed") - - st, ok := status.FromError(err) - require.True(t, ok) - assert.Equal(t, codes.Internal, st.Code()) - assert.Contains(t, st.Message(), "failed to remove OIDC mapping") - }) - - t.Run("error - delete is indempotent", func(t *testing.T) { - repo := trustmock.NewInMemRepository( - trustmock.WithDeleteError(serviceerr.ErrNotFound), - ) - svc := trust.NewService(repo) - server := grpc.NewOIDCMappingServer(svc) - - req := &oidcmappingv1.RemoveOIDCMappingRequest{ - TenantId: "tenant-123", - } - - resp, err := server.RemoveOIDCMapping(ctx, req) - - require.NoError(t, err) - assert.NotNil(t, resp) - assert.True(t, resp.GetSuccess()) - assert.Empty(t, resp.GetMessage()) - }) -} - -func TestUnblockOIDCMapping(t *testing.T) { - ctx := t.Context() - - t.Run("success - unblocks blocked mapping", func(t *testing.T) { - existingMapping := trust.OIDCMapping{ - IssuerURL: "https://issuer.example.com", - Blocked: true, - } - repo := trustmock.NewInMemRepository( - trustmock.WithTrust("tenant-123", existingMapping), - ) - svc := trust.NewService(repo) - server := grpc.NewOIDCMappingServer(svc) - - req := &oidcmappingv1.UnblockOIDCMappingRequest{ - TenantId: "tenant-123", - } - - resp, err := server.UnblockOIDCMapping(ctx, req) - - require.NoError(t, err) - assert.NotNil(t, resp) - assert.True(t, resp.GetSuccess()) - assert.Empty(t, resp.GetMessage()) - }) - - t.Run("success - already unblocked", func(t *testing.T) { - existingMapping := trust.OIDCMapping{ - IssuerURL: "https://issuer.example.com", - Blocked: false, - } - repo := trustmock.NewInMemRepository( - trustmock.WithTrust("tenant-123", existingMapping), - ) - svc := trust.NewService(repo) - server := grpc.NewOIDCMappingServer(svc) - - req := &oidcmappingv1.UnblockOIDCMappingRequest{ - TenantId: "tenant-123", - } - - resp, err := server.UnblockOIDCMapping(ctx, req) - - require.NoError(t, err) - assert.NotNil(t, resp) - assert.True(t, resp.GetSuccess()) - }) - - t.Run("not found - returns success", func(t *testing.T) { - repo := trustmock.NewInMemRepository( - trustmock.WithGetError(serviceerr.ErrNotFound), - ) - svc := trust.NewService(repo) - server := grpc.NewOIDCMappingServer(svc) - - req := &oidcmappingv1.UnblockOIDCMappingRequest{ - TenantId: "tenant-123", - } - - resp, err := server.UnblockOIDCMapping(ctx, req) - - require.NoError(t, err) - assert.NotNil(t, resp) - assert.True(t, resp.GetSuccess()) - }) - - t.Run("error - returns grpc error with message", func(t *testing.T) { - internalErr := errors.New("update failed") - existingMapping := trust.OIDCMapping{ - IssuerURL: "https://issuer.example.com", - Blocked: true, - } - repo := trustmock.NewInMemRepository( - trustmock.WithTrust("tenant-123", existingMapping), - trustmock.WithUpdateError(internalErr), - ) - svc := trust.NewService(repo) - server := grpc.NewOIDCMappingServer(svc) - - req := &oidcmappingv1.UnblockOIDCMappingRequest{ - TenantId: "tenant-123", - } - - resp, err := server.UnblockOIDCMapping(ctx, req) - - require.Error(t, err) - assert.NotNil(t, resp) - require.NotNil(t, resp.GetMessage()) - assert.Contains(t, resp.GetMessage(), "update failed") - - st, ok := status.FromError(err) - require.True(t, ok) - assert.Equal(t, codes.Internal, st.Code()) - assert.Contains(t, st.Message(), "failed to unblock OIDC mapping") - }) -} diff --git a/internal/grpc/options.go b/internal/grpc/options.go deleted file mode 100644 index 6b5909ee..00000000 --- a/internal/grpc/options.go +++ /dev/null @@ -1,23 +0,0 @@ -package grpc - -import "github.com/openkcm/session-manager/internal/credentials" - -type SessionServerOption func(*SessionServer) - -func WithQueryParametersIntrospect(params []string) SessionServerOption { - return func(s *SessionServer) { - s.queryParametersIntrospect = params - } -} - -func WithAllowHttpScheme(allow bool) SessionServerOption { - return func(s *SessionServer) { - s.allowHttpScheme = allow - } -} - -func WithTransportCredentials(b credentials.Builder) SessionServerOption { - return func(s *SessionServer) { - s.newCreds = b - } -} diff --git a/internal/session/housekeeper.go b/internal/session/housekeeper.go index e70cc9c9..dde2297d 100644 --- a/internal/session/housekeeper.go +++ b/internal/session/housekeeper.go @@ -96,12 +96,14 @@ func (m *Manager) housekeepSession(ctx context.Context, s Session, refreshTrigge // refreshAccessToken refreshes the access token for the given session using its refresh token. func (m *Manager) refreshAccessToken(ctx context.Context, s Session) error { - mapping, err := m.trustRepo.Get(ctx, s.TenantID) + trust, err := m.trust.Get(ctx, s.TenantID) if err != nil { - return fmt.Errorf("could not get trust mapping: %w", err) + return fmt.Errorf("could not get trust: %w", err) } - openidConf, err := m.getOpenIDConfig(ctx, mapping.IssuerURL) + oidc := trust.GetOidc() + + openidConf, err := m.getOpenIDConfig(ctx, oidc.GetIssuer()) if err != nil { return fmt.Errorf("could not get OpenID configuration: %w", err) } @@ -109,12 +111,6 @@ func (m *Manager) refreshAccessToken(ctx context.Context, s Session) error { data := url.Values{} data.Set("grant_type", "refresh_token") data.Set("refresh_token", s.RefreshToken) - for _, parameter := range m.queryParametersToken { - value, ok := mapping.Properties[parameter] - if ok { - data.Set(parameter, value) - } - } req, err := http.NewRequestWithContext(ctx, http.MethodPost, openidConf.TokenEndpoint, bytes.NewBufferString(data.Encode())) if err != nil { @@ -122,7 +118,7 @@ func (m *Manager) refreshAccessToken(ctx context.Context, s Session) error { } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - client := m.httpClient(mapping) + client := m.httpClient(oidc) resp, err := client.Do(req) if err != nil { return err diff --git a/internal/session/housekeeper_test.go b/internal/session/housekeeper_test.go index e466a63b..9ee88f28 100644 --- a/internal/session/housekeeper_test.go +++ b/internal/session/housekeeper_test.go @@ -2,6 +2,7 @@ package session_test import ( "encoding/json" + "errors" "net/http" "net/http/httptest" "testing" @@ -10,12 +11,14 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + oidcv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/oidc/v1" + trustv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/v1" + "github.com/openkcm/session-manager/internal/config" - "github.com/openkcm/session-manager/internal/serviceerr" "github.com/openkcm/session-manager/internal/session" sessionmock "github.com/openkcm/session-manager/internal/session/mock" - "github.com/openkcm/session-manager/internal/trust" - "github.com/openkcm/session-manager/internal/trust/trustmock" + mocktrust "github.com/openkcm/session-manager/modules/oidctrust/mocks" + "github.com/openkcm/session-manager/pkg/serviceerr" ) func TestDeleteIdleSessions(t *testing.T) { @@ -92,7 +95,6 @@ func TestRefreshAccessToken(t *testing.T) { assert.Equal(t, "refresh_token", r.Form.Get("grant_type")) assert.Equal(t, "old-refresh-token", r.Form.Get("refresh_token")) assert.Equal(t, "test-client-id", r.Form.Get("client_id")) - assert.Equal(t, "param-value", r.Form.Get("test-param")) w.Header().Set("Content-Type", "application/json") _ = json.NewEncoder(w).Encode(map[string]any{ @@ -104,14 +106,15 @@ func TestRefreshAccessToken(t *testing.T) { defer tokenServer.Close() tokenServerURL = tokenServer.URL + "/token" - mapping := trust.OIDCMapping{ - IssuerURL: discoveryServerURL, - Properties: map[string]string{ - "test-param": "param-value", - }, - } + trustData := trustv1.Trust_builder{ + TenantId: new(tenantID), + Oidc: oidcv1.OIDC_builder{ + Issuer: new(discoveryServerURL), + }.Build(), + }.Build() - oidcRepo := trustmock.NewInMemRepository(trustmock.WithTrust(tenantID, mapping)) + oidcRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(trustData)) + trust := newTrust(oidcRepo) sess := session.Session{ ID: sessionID, @@ -130,13 +133,12 @@ func TestRefreshAccessToken(t *testing.T) { ClientAuth: config.ClientAuth{ ClientID: "test-client-id", }, - AdditionalQueryParametersToken: []string{"test-param"}, - CSRFSecretParsed: []byte(testCSRFSecret), + CSRFSecretParsed: []byte(testCSRFSecret), } manager, err := session.NewManager(ctx, cfg, - oidcRepo, + trust, sessions, nil, session.WithAllowHttpScheme(true), @@ -154,8 +156,9 @@ func TestRefreshAccessToken(t *testing.T) { assert.Equal(t, "new-refresh-token", updatedSess.RefreshToken) }) - t.Run("Error - trust mapping not found", func(t *testing.T) { - oidcRepo := trustmock.NewInMemRepository() + t.Run("Error - trust not found", func(t *testing.T) { + oidcRepo := mocktrust.NewInMemRepository() + trust := newTrust(oidcRepo) sess := session.Session{ ID: sessionID, @@ -176,7 +179,7 @@ func TestRefreshAccessToken(t *testing.T) { CSRFSecretParsed: []byte(testCSRFSecret), } - manager, err := session.NewManager(ctx, cfg, oidcRepo, sessions, nil) + manager, err := session.NewManager(ctx, cfg, trust, sessions, nil) require.NoError(t, err) // Trigger housekeeping - should log error but not fail @@ -205,12 +208,14 @@ func TestRefreshAccessToken(t *testing.T) { defer tokenServer.Close() tokenServerURL = tokenServer.URL + "/token" - mapping := trust.OIDCMapping{ - IssuerURL: discoveryServerURL, - Properties: map[string]string{}, - } + trustData := trustv1.Trust_builder{ + TenantId: new(tenantID), + Oidc: oidcv1.OIDC_builder{ + Issuer: new(discoveryServerURL), + }.Build(), + }.Build() - oidcRepo := trustmock.NewInMemRepository(trustmock.WithTrust(tenantID, mapping)) + oidcRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(trustData)) sess := session.Session{ ID: sessionID, @@ -234,7 +239,7 @@ func TestRefreshAccessToken(t *testing.T) { manager, err := session.NewManager(ctx, cfg, - oidcRepo, + newTrust(oidcRepo), sessions, nil, session.WithAllowHttpScheme(true), @@ -263,12 +268,14 @@ func TestRefreshAccessToken(t *testing.T) { defer discoveryServer.Close() discoveryServerURL = discoveryServer.URL - mapping := trust.OIDCMapping{ - IssuerURL: discoveryServer.URL, - Properties: map[string]string{}, // Missing required parameter - } + trustData := trustv1.Trust_builder{ + TenantId: new(tenantID), + Oidc: oidcv1.OIDC_builder{ + Issuer: new(discoveryServer.URL), + }.Build(), + }.Build() - oidcRepo := trustmock.NewInMemRepository(trustmock.WithTrust(tenantID, mapping)) + oidcRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(trustData)) sess := session.Session{ ID: sessionID, @@ -286,13 +293,12 @@ func TestRefreshAccessToken(t *testing.T) { ClientAuth: config.ClientAuth{ ClientID: "test-client-id", }, - AdditionalQueryParametersToken: []string{"missing-param"}, - CSRFSecretParsed: []byte(testCSRFSecret), + CSRFSecretParsed: []byte(testCSRFSecret), } manager, err := session.NewManager(ctx, cfg, - oidcRepo, + newTrust(oidcRepo), sessions, nil, session.WithAllowHttpScheme(true), @@ -367,4 +373,135 @@ func TestHousekeepSession_ErrorCases(t *testing.T) { require.NoError(t, err) assert.Equal(t, "original-token", updatedSess.AccessToken) }) + + t.Run("IsActive returns error — session is skipped, housekeeping does not fail", func(t *testing.T) { + sessionID := "test-session-id-isactive-err" + + sess := session.Session{ + ID: sessionID, + TenantID: "test-tenant", + AccessTokenExpiry: time.Now().Add(2 * time.Hour), + Expiry: time.Now().Add(2 * time.Hour), + } + + sessions := sessionmock.NewInMemRepository( + sessionmock.WithSession(sess), + sessionmock.WithIsActiveError(errors.New("valkey unavailable")), + ) + + cfg := &config.SessionManager{ + CSRFSecretParsed: []byte(testCSRFSecret), + } + + manager, err := session.NewManager(ctx, cfg, nil, sessions, nil) + require.NoError(t, err) + + // Housekeeping must not surface the IsActive error. + err = manager.TriggerHousekeeping(ctx, 1, time.Hour) + require.NoError(t, err) + + // Session must still exist — no deletion was attempted. + _, err = sessions.LoadSession(ctx, sessionID) + require.NoError(t, err) + }) + + t.Run("DeleteSession returns error — housekeeping continues without failing", func(t *testing.T) { + sessionID := "test-session-id-delete-err" + + sess := session.Session{ + ID: sessionID, + TenantID: "test-tenant", + AccessTokenExpiry: time.Now().Add(2 * time.Hour), + Expiry: time.Now().Add(2 * time.Hour), + } + + // Inactive session (no BumpActive) + DeleteSession always errors. + sessions := sessionmock.NewInMemRepository( + sessionmock.WithSession(sess), + sessionmock.WithDeleteSessionError(errors.New("delete failed")), + ) + + cfg := &config.SessionManager{ + CSRFSecretParsed: []byte(testCSRFSecret), + } + + manager, err := session.NewManager(ctx, cfg, nil, sessions, nil) + require.NoError(t, err) + + // Housekeeping must not surface the DeleteSession error. + err = manager.TriggerHousekeeping(ctx, 1, time.Hour) + require.NoError(t, err) + }) + + t.Run("StoreSession error during token refresh — housekeeping continues without failing", func(t *testing.T) { + tenantID := "test-tenant-store-err" + sessionID := "test-session-id-store-err" + + var discoveryServerURL string + discoveryServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "issuer": discoveryServerURL, + "token_endpoint": discoveryServerURL + "/token", + }) + })) + defer discoveryServer.Close() + discoveryServerURL = discoveryServer.URL + + // Token server returns a valid refresh response. + tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "new-access-token", + "refresh_token": "new-refresh-token", + "expires_in": 3600, + }) + })) + defer tokenServer.Close() + + trustData := trustv1.Trust_builder{ + TenantId: new(tenantID), + Oidc: oidcv1.OIDC_builder{ + Issuer: new(discoveryServerURL), + }.Build(), + }.Build() + + oidcRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(trustData)) + trust := newTrust(oidcRepo) + + sess := session.Session{ + ID: sessionID, + TenantID: tenantID, + RefreshToken: "old-refresh-token", + AccessToken: "old-access-token", + AccessTokenExpiry: time.Now().Add(10 * time.Second), // Near expiry. + Expiry: time.Now().Add(1 * time.Hour), + } + + sessions := sessionmock.NewInMemRepository( + sessionmock.WithSession(sess), + sessionmock.WithStoreSessionError(errors.New("store failed")), + ) + err := sessions.BumpActive(ctx, sessionID, time.Hour) + require.NoError(t, err) + + cfg := &config.SessionManager{ + ClientAuth: config.ClientAuth{ClientID: "client-id"}, + CSRFSecretParsed: []byte(testCSRFSecret), + } + + manager, err := session.NewManager(ctx, cfg, trust, sessions, nil, + session.WithAllowHttpScheme(true), + ) + require.NoError(t, err) + + // Housekeeping must not surface the StoreSession error. + err = manager.TriggerHousekeeping(ctx, 1, 1*time.Minute) + require.NoError(t, err) + + // Access token must remain unchanged — store failed. + updatedSess, err := sessions.LoadSession(ctx, sessionID) + require.NoError(t, err) + assert.Equal(t, "old-access-token", updatedSess.AccessToken) + }) } diff --git a/internal/session/import_test.go b/internal/session/import_test.go new file mode 100644 index 00000000..cfe5ca6f --- /dev/null +++ b/internal/session/import_test.go @@ -0,0 +1,12 @@ +package session_test + +import ( + _ "unsafe" + + sessionmanager "github.com/openkcm/session-manager" + "github.com/openkcm/session-manager/modules/oidctrust" + _ "github.com/openkcm/session-manager/modules/standard" +) + +//go:linkname newTrust github.com/openkcm/session-manager/modules/oidctrust.newOIDCTrustModuleWithRepo +func newTrust(r oidctrust.TrustRepository) sessionmanager.Trust diff --git a/internal/session/manager.go b/internal/session/manager.go index 12c2e124..455696d7 100644 --- a/internal/session/manager.go +++ b/internal/session/manager.go @@ -19,16 +19,19 @@ import ( "github.com/jellydator/ttlcache/v3" "github.com/openkcm/common-sdk/pkg/csrf" "github.com/openkcm/common-sdk/pkg/oidc" + "google.golang.org/protobuf/proto" + flowv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/oidc/flow/v1" + oidcv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/oidc/v1" otlpaudit "github.com/openkcm/common-sdk/pkg/otlp/audit" slogctx "github.com/veqryn/slog-context" + sessionmanager "github.com/openkcm/session-manager" "github.com/openkcm/session-manager/internal/config" "github.com/openkcm/session-manager/internal/credentials" "github.com/openkcm/session-manager/internal/debugtools" "github.com/openkcm/session-manager/internal/pkce" - "github.com/openkcm/session-manager/internal/serviceerr" - "github.com/openkcm/session-manager/internal/trust" + "github.com/openkcm/session-manager/pkg/serviceerr" ) const defaultWKOCCacheExpiration = 30 * time.Minute @@ -40,20 +43,16 @@ const ( ) type Manager struct { - trustRepo trust.OIDCMappingRepository - sessions Repository - pkce pkce.Source - audit *otlpaudit.AuditLogger - newCreds credentials.Builder - - sessionDuration time.Duration - idleSessionTimeout time.Duration - callbackURL *url.URL - clientID string - queryParametersAuth []string - queryParametersToken []string - authContextKeys []string - queryParametersLogout []string + trust sessionmanager.Trust + sessions Repository + pkce pkce.Source + audit *otlpaudit.AuditLogger + newCreds credentials.Builder + + sessionDuration time.Duration + idleSessionTimeout time.Duration + callbackURL *url.URL + clientID string sessionCookieTemplate config.CookieTemplate csrfCookieTemplate config.CookieTemplate @@ -70,7 +69,7 @@ type Manager struct { func NewManager( ctx context.Context, cfg *config.SessionManager, - trustRepo trust.OIDCMappingRepository, + trust sessionmanager.Trust, sessionsRepo Repository, auditLogger *otlpaudit.AuditLogger, opts ...ManagerOption, @@ -81,15 +80,11 @@ func NewManager( } m := &Manager{ - trustRepo: trustRepo, + trust: trust, sessions: sessionsRepo, audit: auditLogger, sessionDuration: cfg.SessionDuration, idleSessionTimeout: cfg.IdleSessionTimeout, - queryParametersAuth: cfg.AdditionalQueryParametersAuthorize, - queryParametersToken: cfg.AdditionalQueryParametersToken, - authContextKeys: cfg.AdditionalAuthContextKeys, - queryParametersLogout: cfg.AdditionalQueryParametersLogout, sessionCookieTemplate: cfg.SessionCookieTemplate, csrfCookieTemplate: cfg.CSRFCookieTemplate, loginCSRFCookieTemplate: cfg.LoginCSRFCookieTemplate, @@ -117,12 +112,14 @@ func NewManager( // MakeAuthURI returns an OIDC authentication URI. func (m *Manager) MakeAuthURI(ctx context.Context, tenantID, fingerprint, requestURI string) (string, string, error) { - mapping, err := m.trustRepo.Get(ctx, tenantID) + trust, err := m.trust.Get(ctx, tenantID) if err != nil { - return "", "", fmt.Errorf("getting trust mapping: %w", err) + return "", "", fmt.Errorf("getting trust: %w", err) } - openidConf, err := m.getOpenIDConfig(ctx, mapping.IssuerURL) + oidc := trust.GetOidc() + + openidConf, err := m.getOpenIDConfig(ctx, oidc.GetIssuer()) if err != nil { return "", "", fmt.Errorf("getting an openid config: %w", err) } @@ -146,7 +143,7 @@ func (m *Manager) MakeAuthURI(ctx context.Context, tenantID, fingerprint, reques return "", "", fmt.Errorf("storing session: %w", err) } - u, err := m.authURI(openidConf, state, pkce, mapping) + u, err := m.authURI(openidConf, state, pkce, oidc) if err != nil { return "", "", fmt.Errorf("generating auth uri: %w", err) } @@ -158,7 +155,7 @@ func (m *Manager) LoadState(ctx context.Context, stateID string) (State, error) return m.sessions.LoadState(ctx, stateID) } -func (m *Manager) authURI(openidConf *oidc.Configuration, state State, pkce pkce.PKCE, mapping trust.OIDCMapping) (string, error) { +func (m *Manager) authURI(openidConf *oidc.Configuration, state State, pkce pkce.PKCE, oidc *oidcv1.OIDC) (string, error) { u, err := url.Parse(openidConf.AuthorizationEndpoint) if err != nil { return "", fmt.Errorf("parsing authorisation endpoint url: %w", err) @@ -167,16 +164,15 @@ func (m *Manager) authURI(openidConf *oidc.Configuration, state State, pkce pkce q := u.Query() q.Set("scope", "openid profile email groups") q.Set("response_type", "code") - q.Set("client_id", m.getClientID(mapping)) + q.Set("client_id", m.getClientID(oidc)) q.Set("state", state.ID) q.Set("code_challenge", pkce.Challenge) q.Set("code_challenge_method", pkce.Method) q.Set("redirect_uri", m.callbackURL.String()) - for _, parameter := range m.queryParametersAuth { - value, ok := mapping.Properties[parameter] - if ok { - q.Set(parameter, value) - } + + //nolint:forcetypeassert + for _, param := range proto.GetExtension(oidc, flowv1.E_AuthAttributes).([]*flowv1.Attribute) { + q.Set(param.GetKey(), param.GetValue()) } u.RawQuery = q.Encode() @@ -230,19 +226,21 @@ func (m *Manager) FinaliseOIDCLogin(ctx context.Context, stateID, code, fingerpr return OIDCSessionData{}, serviceerr.ErrFingerprintMismatch } - mapping, err := m.trustRepo.Get(ctx, state.TenantID) + trust, err := m.trust.Get(ctx, state.TenantID) if err != nil { - m.sendUserLoginFailureAudit(ctx, metadata, state.TenantID, "failed to get trust mapping") - return OIDCSessionData{}, fmt.Errorf("getting trust mapping: %w", err) + m.sendUserLoginFailureAudit(ctx, metadata, state.TenantID, "failed to get trust") + return OIDCSessionData{}, fmt.Errorf("getting trust: %w", err) } - openidConf, err := m.getOpenIDConfig(ctx, mapping.IssuerURL) + oidc := trust.GetOidc() + + openidConf, err := m.getOpenIDConfig(ctx, oidc.GetIssuer()) if err != nil { m.sendUserLoginFailureAudit(ctx, metadata, state.TenantID, "failed to get openid configuration") return OIDCSessionData{}, fmt.Errorf("getting openid configuration: %w", err) } - tokens, err := m.exchangeCode(ctx, openidConf, code, state.PKCEVerifier, mapping) + tokens, err := m.exchangeCode(ctx, openidConf, code, state.PKCEVerifier, oidc) if err != nil { m.sendUserLoginFailureAudit(ctx, metadata, state.TenantID, "failed to exchange code for tokens") return OIDCSessionData{}, fmt.Errorf("exchanging code for tokens: %w", err) @@ -299,14 +297,13 @@ func (m *Manager) FinaliseOIDCLogin(ctx context.Context, stateID, code, fingerpr // prepare the auth context used by ExtAuthZ authContext := map[string]string{ - "issuer": mapping.IssuerURL, - "client_id": m.getClientID(mapping), + "issuer": oidc.GetIssuer(), + "client_id": m.getClientID(oidc), } - for _, parameter := range m.authContextKeys { - value, ok := mapping.Properties[parameter] - if ok { - authContext[parameter] = value - } + + //nolint:forcetypeassert + for _, param := range proto.GetExtension(oidc, flowv1.E_AuthContext).([]*flowv1.Attribute) { + authContext[param.GetKey()] = param.GetValue() } session := Session{ @@ -315,7 +312,7 @@ func (m *Manager) FinaliseOIDCLogin(ctx context.Context, stateID, code, fingerpr ProviderID: customClaims.SID, Fingerprint: fingerprint, CSRFToken: csrfToken, - Issuer: mapping.IssuerURL, + Issuer: oidc.GetIssuer(), Claims: Claims{ Subject: standardClaims.Subject, UserUUID: customClaims.UserUUID, @@ -376,15 +373,16 @@ func (m *Manager) Logout(ctx context.Context, sessionID, postLogoutRedirectURL s ctx = slogctx.With(ctx, "tenantId", session.TenantID) - mapping, err := m.trustRepo.Get(ctx, session.TenantID) + trust, err := m.trust.Get(ctx, session.TenantID) if err != nil { - slogctx.Error(ctx, "failed to get trust mapping for a tenant", "error", err) - return "", fmt.Errorf("getting trust mapping: %w", err) + slogctx.Error(ctx, "failed to get trust for a tenant", "error", err) + return "", fmt.Errorf("getting trust: %w", err) } - ctx = slogctx.With(ctx, "issuerUrl", mapping.IssuerURL) + oidc := trust.GetOidc() + ctx = slogctx.With(ctx, "issuer", oidc.GetIssuer()) - oidcConf, err := m.getOpenIDConfig(ctx, mapping.IssuerURL) + oidcConf, err := m.getOpenIDConfig(ctx, oidc.GetIssuer()) if err != nil { slogctx.Warn(ctx, "failed to get oidc configuration", "error", err) return "", fmt.Errorf("getting oidc configuration: %w", err) @@ -407,14 +405,12 @@ func (m *Manager) Logout(ctx context.Context, sessionID, postLogoutRedirectURL s } vals := make(url.Values, 2) - vals.Set("client_id", m.getClientID(mapping)) + vals.Set("client_id", m.getClientID(oidc)) vals.Set("post_logout_redirect_uri", postLogoutRedirectURL) - for _, parameter := range m.queryParametersLogout { - value, ok := mapping.Properties[parameter] - if ok { - vals.Set(parameter, value) - } + //nolint:forcetypeassert + for _, param := range proto.GetExtension(oidc, flowv1.E_LogoutAttributes).([]*flowv1.Attribute) { + vals.Set(param.GetKey(), param.GetValue()) } redirectURL.RawQuery = vals.Encode() @@ -473,12 +469,14 @@ func (m *Manager) BCLogout(ctx context.Context, logoutJWT string) error { return nil } - mapping, err := m.trustRepo.Get(ctx, session.TenantID) + trust, err := m.trust.Get(ctx, session.TenantID) if err != nil { - return fmt.Errorf("getting trust mapping: %w", err) + return fmt.Errorf("getting trust: %w", err) } - oidcConf, err := m.getOpenIDConfig(ctx, mapping.IssuerURL) + oidc := trust.GetOidc() + + oidcConf, err := m.getOpenIDConfig(ctx, oidc.GetIssuer()) if err != nil { return fmt.Errorf("getting oidc config: %w", err) } @@ -611,8 +609,8 @@ func (m *Manager) verifyAccessToken(accessToken, atHash string, idToken *jwt.JSO return nil } -func (m *Manager) httpClient(mapping trust.OIDCMapping) *http.Client { - creds := m.newCreds(m.getClientID(mapping)) +func (m *Manager) httpClient(oidc *oidcv1.OIDC) *http.Client { + creds := m.newCreds(m.getClientID(oidc)) transport := creds.Transport() if debugSettingSMDumpTransport.Value() == "1" { transport = debugtools.NewTransport(transport) @@ -623,25 +621,23 @@ func (m *Manager) httpClient(mapping trust.OIDCMapping) *http.Client { } } -func (m *Manager) getClientID(mapping trust.OIDCMapping) string { - if mapping.ClientID != "" { - return mapping.ClientID +func (m *Manager) getClientID(oidc *oidcv1.OIDC) string { + if clientID := oidc.GetClientId(); clientID != "" { + return clientID } return m.clientID } -func (m *Manager) exchangeCode(ctx context.Context, openidConf *oidc.Configuration, code, codeVerifier string, mapping trust.OIDCMapping) (tokenResponse, error) { +func (m *Manager) exchangeCode(ctx context.Context, openidConf *oidc.Configuration, code, codeVerifier string, oidc *oidcv1.OIDC) (tokenResponse, error) { data := url.Values{} data.Set("grant_type", "authorization_code") data.Set("code", code) data.Set("code_verifier", codeVerifier) data.Set("redirect_uri", m.callbackURL.String()) - for _, parameter := range m.queryParametersToken { - value, ok := mapping.Properties[parameter] - if ok { - data.Set(parameter, value) - } + //nolint:forcetypeassert + for _, param := range proto.GetExtension(oidc, flowv1.E_TokenAttributes).([]*flowv1.Attribute) { + data.Set(param.GetKey(), param.GetValue()) } req, err := http.NewRequestWithContext(ctx, http.MethodPost, openidConf.TokenEndpoint, strings.NewReader(data.Encode())) @@ -650,7 +646,7 @@ func (m *Manager) exchangeCode(ctx context.Context, openidConf *oidc.Configurati } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - client := m.httpClient(mapping) + client := m.httpClient(oidc) resp, err := client.Do(req) if err != nil { return tokenResponse{}, fmt.Errorf("executing request: %w", err) diff --git a/internal/session/manager_cookie_test.go b/internal/session/manager_cookie_test.go index 9927f821..5fac7436 100644 --- a/internal/session/manager_cookie_test.go +++ b/internal/session/manager_cookie_test.go @@ -7,11 +7,13 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + oidcv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/oidc/v1" + trustv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/v1" + "github.com/openkcm/session-manager/internal/config" "github.com/openkcm/session-manager/internal/session" sessionmock "github.com/openkcm/session-manager/internal/session/mock" - "github.com/openkcm/session-manager/internal/trust" - "github.com/openkcm/session-manager/internal/trust/trustmock" + mocktrust "github.com/openkcm/session-manager/modules/oidctrust/mocks" ) func TestManager_MakeSessionCookie(t *testing.T) { @@ -426,33 +428,35 @@ func TestManager_Logout(t *testing.T) { tests := []struct { name string cfg *config.SessionManager - setupOIDCRepo func(t *testing.T) *trustmock.Repository + setupOIDCRepo func(t *testing.T) *mocktrust.Repository setupSession func(*sessionmock.Repository) wantErr bool wantErrMessage string wantURL string }{ { - name: "Success - redirect to postLogoutURL when no end session endpoint", + name: "Success - redirect to postLogoutURL", cfg: &config.SessionManager{ CSRFSecretParsed: []byte(testCSRFSecret), ClientAuth: config.ClientAuth{ ClientID: testClientID, }, }, - setupOIDCRepo: func(t *testing.T) *trustmock.Repository { + setupOIDCRepo: func(t *testing.T) *mocktrust.Repository { t.Helper() // StartOIDCServer doesn't include end_session_endpoint, so it will fall back to postLogoutURL oidcServer := StartOIDCServer(t, false) t.Cleanup(oidcServer.Close) - mapping := trust.OIDCMapping{ - IssuerURL: oidcServer.URL, - JWKSURI: oidcServer.URL + "/jwks", - Audiences: []string{"test"}, - Properties: map[string]string{}, - } - return trustmock.NewInMemRepository(trustmock.WithTrust(tenantID, mapping)) + trust := trustv1.Trust_builder{ + TenantId: new(tenantID), + Oidc: oidcv1.OIDC_builder{ + Issuer: new(oidcServer.URL), + JwksUri: new(oidcServer.URL + "/jwks"), + Audiences: []string{"test"}, + }.Build(), + }.Build() + return mocktrust.NewInMemRepository(mocktrust.WithTrust(trust)) }, setupSession: func(repo *sessionmock.Repository) { //nolint:errcheck @@ -467,15 +471,14 @@ func TestManager_Logout(t *testing.T) { { name: "Error - session not found", cfg: &config.SessionManager{ - CSRFSecretParsed: []byte(testCSRFSecret), - PostLogoutRedirectURL: postLogoutURL, + CSRFSecretParsed: []byte(testCSRFSecret), ClientAuth: config.ClientAuth{ ClientID: testClientID, }, }, - setupOIDCRepo: func(t *testing.T) *trustmock.Repository { + setupOIDCRepo: func(t *testing.T) *mocktrust.Repository { t.Helper() - return trustmock.NewInMemRepository() + return mocktrust.NewInMemRepository() }, setupSession: func(repo *sessionmock.Repository) { // Don't store session - will cause error @@ -491,14 +494,9 @@ func TestManager_Logout(t *testing.T) { tt.setupSession(sessionRepo) oidcRepo := tt.setupOIDCRepo(t) + trust := newTrust(oidcRepo) - m, err := session.NewManager(ctx, - tt.cfg, - oidcRepo, - sessionRepo, - nil, - session.WithAllowHttpScheme(true), - ) + m, err := session.NewManager(ctx, tt.cfg, trust, sessionRepo, nil, session.WithAllowHttpScheme(true)) require.NoError(t, err) logoutURL, err := m.Logout(t.Context(), sessionID, postLogoutURL) diff --git a/internal/session/manager_test.go b/internal/session/manager_test.go index cae05e7b..f37866b6 100644 --- a/internal/session/manager_test.go +++ b/internal/session/manager_test.go @@ -20,14 +20,15 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + oidcv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/oidc/v1" + trustv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/v1" otlpaudit "github.com/openkcm/common-sdk/pkg/otlp/audit" "github.com/openkcm/session-manager/internal/config" "github.com/openkcm/session-manager/internal/credentials" "github.com/openkcm/session-manager/internal/session" sessionmock "github.com/openkcm/session-manager/internal/session/mock" - "github.com/openkcm/session-manager/internal/trust" - "github.com/openkcm/session-manager/internal/trust/trustmock" + mocktrust "github.com/openkcm/session-manager/modules/oidctrust/mocks" ) const ( @@ -49,20 +50,19 @@ func TestManager_Auth(t *testing.T) { auditServer := StartAuditServer(t) defer auditServer.Close() - oidcMapping := trust.OIDCMapping{ - IssuerURL: oidcServer.URL, - Blocked: false, - JWKSURI: "http://jwks.example.com", - Audiences: []string{requestURI}, - Properties: map[string]string{ - "paramAuth1": "paramAuth1", - "paramToken1": "paramToken1", - }, - } + oidcTrust := trustv1.Trust_builder{ + TenantId: new(tenantID), + Blocked: new(false), + Oidc: oidcv1.OIDC_builder{ + Issuer: new(oidcServer.URL), + JwksUri: new("http://jwks.example.com"), + Audiences: []string{requestURI}, + }.Build(), + }.Build() tests := []struct { name string - oidc *trustmock.Repository + oidc *mocktrust.Repository sessions *sessionmock.Repository requestURI string cfg *config.SessionManager @@ -70,17 +70,16 @@ func TestManager_Auth(t *testing.T) { fingerprint string wantURL string errAssert assert.ErrorAssertionFunc - mapping trust.OIDCMapping + trust *trustv1.Trust }{ { name: "Success", - oidc: trustmock.NewInMemRepository(trustmock.WithTrust(tenantID, oidcMapping)), + oidc: mocktrust.NewInMemRepository(mocktrust.WithTrust(oidcTrust)), sessions: sessionmock.NewInMemRepository(), requestURI: requestURI, cfg: &config.SessionManager{ - SessionDuration: time.Hour, - CallbackURL: callbackURL, - AdditionalQueryParametersAuthorize: []string{"paramAuth1"}, + SessionDuration: time.Hour, + CallbackURL: callbackURL, ClientAuth: config.ClientAuth{ ClientID: testClientID, }, @@ -88,14 +87,14 @@ func TestManager_Auth(t *testing.T) { }, tenantID: tenantID, fingerprint: "fingerprint", - wantURL: oidcServer.URL + "/oauth2/authorize?client_id=my-client-id&code_challenge=someChallenge&code_challenge_method=S256&redirect_uri=" + callbackURL + "&response_type=code&scope=openid+profile+email+groups&state=someState¶mAuth1=paramAuth1", + wantURL: oidcServer.URL + "/oauth2/authorize?client_id=my-client-id&code_challenge=someChallenge&code_challenge_method=S256&redirect_uri=" + callbackURL + "&response_type=code&scope=openid+profile+email+groups&state=someState", errAssert: assert.NoError, }, { - name: "Get trust mapping error", - oidc: trustmock.NewInMemRepository( - trustmock.WithTrust(tenantID, oidcMapping), - trustmock.WithGetError(errors.New("failed to get trust mapping")), + name: "Get trust error", + oidc: mocktrust.NewInMemRepository( + mocktrust.WithTrust(oidcTrust), + mocktrust.WithGetError(errors.New("failed to get trust")), ), sessions: sessionmock.NewInMemRepository(), requestURI: requestURI, @@ -111,7 +110,7 @@ func TestManager_Auth(t *testing.T) { }, { name: "Save state error", - oidc: trustmock.NewInMemRepository(trustmock.WithTrust(tenantID, oidcMapping)), + oidc: mocktrust.NewInMemRepository(mocktrust.WithTrust(oidcTrust)), sessions: sessionmock.NewInMemRepository(sessionmock.WithStoreStateError(errors.New("failed to save state"))), requestURI: requestURI, cfg: &config.SessionManager{ @@ -143,7 +142,7 @@ func TestManager_Auth(t *testing.T) { m, err := session.NewManager(ctx, tt.cfg, - tt.oidc, + newTrust(tt.oidc), tt.sessions, auditLogger, session.WithAllowHttpScheme(true), @@ -156,7 +155,7 @@ func TestManager_Auth(t *testing.T) { } // Validate that the data has been inserted into the repository - assert.Equal(t, oidcMapping, tt.oidc.TGet(tt.tenantID), "Trust mapping has not been inserted") + assert.Equal(t, oidcTrust, tt.oidc.TGet(tt.tenantID), "Trust has not been inserted") // Check the returned URL u, err := url.Parse(got) @@ -232,7 +231,7 @@ func TestManager_FinaliseOIDCLogin(t *testing.T) { tests := []struct { name string - oidc *trustmock.Repository + oidc *mocktrust.Repository sessions *sessionmock.Repository stateID string code string @@ -246,17 +245,15 @@ func TestManager_FinaliseOIDCLogin(t *testing.T) { }{ { name: "Success", - oidc: trustmock.NewInMemRepository(), + oidc: mocktrust.NewInMemRepository(), sessions: sessionmock.NewInMemRepository(sessionmock.WithState(validState)), stateID: stateID, code: code, fingerprint: fingerprint, cfg: &config.SessionManager{ - SessionDuration: time.Hour, - CallbackURL: callbackURL, - AdditionalQueryParametersToken: []string{"queryParamToken1"}, - AdditionalAuthContextKeys: []string{"authContextKey1"}, - CSRFSecretParsed: []byte(testCSRFSecret), + SessionDuration: time.Hour, + CallbackURL: callbackURL, + CSRFSecretParsed: []byte(testCSRFSecret), }, wantSessionID: true, wantCSRFToken: true, @@ -265,7 +262,7 @@ func TestManager_FinaliseOIDCLogin(t *testing.T) { }, { name: "State load error", - oidc: trustmock.NewInMemRepository(), + oidc: mocktrust.NewInMemRepository(), sessions: sessionmock.NewInMemRepository(sessionmock.WithLoadStateError(errors.New("state not found"))), stateID: stateID, code: code, @@ -281,7 +278,7 @@ func TestManager_FinaliseOIDCLogin(t *testing.T) { }, { name: "State expired", - oidc: trustmock.NewInMemRepository(), + oidc: mocktrust.NewInMemRepository(), sessions: sessionmock.NewInMemRepository(sessionmock.WithState(expiredState)), stateID: stateID, code: code, @@ -296,7 +293,7 @@ func TestManager_FinaliseOIDCLogin(t *testing.T) { }, { name: "Fingerprint mismatch", - oidc: trustmock.NewInMemRepository(), + oidc: mocktrust.NewInMemRepository(), sessions: sessionmock.NewInMemRepository(sessionmock.WithState(mismatchState)), stateID: stateID, code: code, @@ -310,8 +307,8 @@ func TestManager_FinaliseOIDCLogin(t *testing.T) { errAssert: assert.Error, }, { - name: "Trust mapping get error", - oidc: trustmock.NewInMemRepository(trustmock.WithGetError(errors.New("trust mapping not found"))), + name: "Trust get error", + oidc: mocktrust.NewInMemRepository(mocktrust.WithGetError(errors.New("trust not found"))), sessions: sessionmock.NewInMemRepository(sessionmock.WithState(validState)), stateID: stateID, code: code, @@ -326,7 +323,7 @@ func TestManager_FinaliseOIDCLogin(t *testing.T) { }, { name: "Token exchange error", - oidc: trustmock.NewInMemRepository(), + oidc: mocktrust.NewInMemRepository(), sessions: sessionmock.NewInMemRepository(sessionmock.WithState(validState)), stateID: stateID, code: code, @@ -357,22 +354,21 @@ func TestManager_FinaliseOIDCLogin(t *testing.T) { jwksURI, err := url.JoinPath(oidcServer.URL, "/.well-known/jwks.json") require.NoError(t, err) - localOIDCMapping := trust.OIDCMapping{ - IssuerURL: oidcServer.URL, - Blocked: false, - JWKSURI: jwksURI, - Audiences: []string{requestURI}, - Properties: map[string]string{ - "queryParamToken1": "queryParamToken1", - "authContextKey1": "authContextValue1", - }, - } + localOIDCTrust := trustv1.Trust_builder{ + TenantId: new(tenantID), + Blocked: new(false), + Oidc: oidcv1.OIDC_builder{ + Issuer: new(oidcServer.URL), + JwksUri: new(jwksURI), + Audiences: []string{requestURI}, + }.Build(), + }.Build() - tt.oidc.TAdd(tenantID, localOIDCMapping) + tt.oidc.TAdd(localOIDCTrust) m, err := session.NewManager(ctx, tt.cfg, - tt.oidc, + newTrust(tt.oidc), tt.sessions, auditLogger, session.WithAllowHttpScheme(true), @@ -464,7 +460,7 @@ func TestManager_BCLogout(t *testing.T) { name string cfg *config.SessionManager jwt string - setupMock func(*trustmock.Repository, *sessionmock.Repository) + setupMock func(*mocktrust.Repository, *sessionmock.Repository) errAssert assert.ErrorAssertionFunc }{ { @@ -478,10 +474,13 @@ func TestManager_BCLogout(t *testing.T) { Events: map[string]struct{}{"http://schemas.openid.net/event/backchannel-logout": {}}, SessionID: "sid-1", }), - setupMock: func(oidcs *trustmock.Repository, sessions *sessionmock.Repository) { - _ = oidcs.Create(context.Background(), "tid-1", trust.OIDCMapping{ - IssuerURL: jwksSrv.URL, - }) + setupMock: func(oidcs *mocktrust.Repository, sessions *sessionmock.Repository) { + _ = oidcs.Create(context.Background(), trustv1.Trust_builder{ + TenantId: new("tid-1"), + Oidc: oidcv1.OIDC_builder{ + Issuer: new(jwksSrv.URL), + }.Build(), + }.Build()) _ = sessions.StoreSession(context.Background(), session.Session{ID: "sid-1", TenantID: "tid-1"}) }, errAssert: assert.NoError, @@ -499,7 +498,8 @@ func TestManager_BCLogout(t *testing.T) { auditLogger, err := otlpaudit.NewLogger(&commoncfg.Audit{Endpoint: auditServer.URL}) require.NoError(t, err) - oidcMock := trustmock.NewInMemRepository() + oidcMock := mocktrust.NewInMemRepository() + trust := newTrust(oidcMock) sessionMock := sessionmock.NewInMemRepository() rt := localRoundTripper{ @@ -522,7 +522,7 @@ func TestManager_BCLogout(t *testing.T) { m, err := session.NewManager(ctx, tt.cfg, - oidcMock, + trust, sessionMock, auditLogger, session.WithTransportCredentials(newTCBuilder(rt)), @@ -547,13 +547,13 @@ func TestManager_LogoutEdgeCases(t *testing.T) { tests := []struct { name string sessionID string - setupMock func(*trustmock.Repository, *sessionmock.Repository) + setupMock func(*mocktrust.Repository, *sessionmock.Repository) errAssert assert.ErrorAssertionFunc }{ { name: "Session not found", sessionID: "non-existent", - setupMock: func(oidcs *trustmock.Repository, sessions *sessionmock.Repository) { + setupMock: func(oidcs *mocktrust.Repository, sessions *sessionmock.Repository) { _ = sessions.StoreSession(context.Background(), session.Session{ ID: sessionID, TenantID: tenantID, @@ -562,9 +562,9 @@ func TestManager_LogoutEdgeCases(t *testing.T) { errAssert: assert.Error, }, { - name: "Trust mapping not found", + name: "Trust not found", sessionID: sessionID, - setupMock: func(oidcs *trustmock.Repository, sessions *sessionmock.Repository) { + setupMock: func(oidcs *mocktrust.Repository, sessions *sessionmock.Repository) { _ = sessions.StoreSession(context.Background(), session.Session{ ID: sessionID, TenantID: tenantID, @@ -584,7 +584,8 @@ func TestManager_LogoutEdgeCases(t *testing.T) { auditLogger, err := otlpaudit.NewLogger(&commoncfg.Audit{Endpoint: auditServer.URL}) require.NoError(t, err) - oidcMock := trustmock.NewInMemRepository() + oidcMock := mocktrust.NewInMemRepository() + trust := newTrust(oidcMock) sessionMock := sessionmock.NewInMemRepository() tt.setupMock(oidcMock, sessionMock) @@ -596,7 +597,7 @@ func TestManager_LogoutEdgeCases(t *testing.T) { }, } - m, err := session.NewManager(ctx, cfg, oidcMock, sessionMock, auditLogger) + m, err := session.NewManager(ctx, cfg, trust, sessionMock, auditLogger) require.NoError(t, err) _, err = m.Logout(ctx, tt.sessionID, postLogoutURL) @@ -644,13 +645,13 @@ func TestManager_BCLogout_ErrorCases(t *testing.T) { tests := []struct { name string jwt string - setupMock func(*trustmock.Repository, *sessionmock.Repository) + setupMock func(*mocktrust.Repository, *sessionmock.Repository) errAssert assert.ErrorAssertionFunc }{ { name: "Invalid JWT", jwt: "invalid.jwt.token", - setupMock: func(oidcs *trustmock.Repository, sessions *sessionmock.Repository) { + setupMock: func(oidcs *mocktrust.Repository, sessions *sessionmock.Repository) { }, errAssert: assert.Error, }, @@ -663,7 +664,7 @@ func TestManager_BCLogout_ErrorCases(t *testing.T) { Events: map[string]struct{}{"http://invalid-event": {}}, SessionID: "sid-1", }), - setupMock: func(oidcs *trustmock.Repository, sessions *sessionmock.Repository) { + setupMock: func(oidcs *mocktrust.Repository, sessions *sessionmock.Repository) { }, errAssert: assert.Error, }, @@ -674,7 +675,7 @@ func TestManager_BCLogout_ErrorCases(t *testing.T) { }{ Events: map[string]struct{}{"http://schemas.openid.net/event/backchannel-logout": {}}, }), - setupMock: func(oidcs *trustmock.Repository, sessions *sessionmock.Repository) { + setupMock: func(oidcs *mocktrust.Repository, sessions *sessionmock.Repository) { }, errAssert: assert.Error, }, @@ -687,7 +688,7 @@ func TestManager_BCLogout_ErrorCases(t *testing.T) { Events: map[string]struct{}{"http://schemas.openid.net/event/backchannel-logout": {}}, SessionID: "non-existent-session", }), - setupMock: func(oidcs *trustmock.Repository, sessions *sessionmock.Repository) { + setupMock: func(oidcs *mocktrust.Repository, sessions *sessionmock.Repository) { }, errAssert: assert.NoError, }, @@ -703,7 +704,8 @@ func TestManager_BCLogout_ErrorCases(t *testing.T) { auditLogger, err := otlpaudit.NewLogger(&commoncfg.Audit{Endpoint: auditServer.URL}) require.NoError(t, err) - oidcMock := trustmock.NewInMemRepository() + oidcMock := mocktrust.NewInMemRepository() + trust := newTrust(oidcMock) sessionMock := sessionmock.NewInMemRepository() rt := localRoundTripper{ @@ -722,7 +724,7 @@ func TestManager_BCLogout_ErrorCases(t *testing.T) { CSRFSecretParsed: []byte(testCSRFSecret), } - m, err := session.NewManager(ctx, cfg, oidcMock, sessionMock, auditLogger, session.WithTransportCredentials(newTCBuilder(rt))) + m, err := session.NewManager(ctx, cfg, trust, sessionMock, auditLogger, session.WithTransportCredentials(newTCBuilder(rt))) require.NoError(t, err) err = m.BCLogout(ctx, tt.jwt) @@ -744,7 +746,9 @@ func TestManager_NewManager_Error(t *testing.T) { CSRFSecretParsed: []byte(testCSRFSecret), } - m, err := session.NewManager(ctx, cfg, trustmock.NewInMemRepository(), sessionmock.NewInMemRepository(), auditLogger) + trust := newTrust(mocktrust.NewInMemRepository()) + + m, err := session.NewManager(ctx, cfg, trust, sessionmock.NewInMemRepository(), auditLogger) assert.Error(t, err) assert.Nil(t, m) assert.Contains(t, err.Error(), "parsing callback URL") diff --git a/internal/session/mock/repository.go b/internal/session/mock/repository.go index 71c85a31..459dc75e 100644 --- a/internal/session/mock/repository.go +++ b/internal/session/mock/repository.go @@ -4,8 +4,8 @@ import ( "context" "time" - "github.com/openkcm/session-manager/internal/serviceerr" "github.com/openkcm/session-manager/internal/session" + "github.com/openkcm/session-manager/pkg/serviceerr" ) type RepositoryOption func(*Repository) diff --git a/internal/session/valkey/repository.go b/internal/session/valkey/repository.go index 58ef4cf5..92050b35 100644 --- a/internal/session/valkey/repository.go +++ b/internal/session/valkey/repository.go @@ -10,8 +10,8 @@ import ( slogctx "github.com/veqryn/slog-context" - "github.com/openkcm/session-manager/internal/serviceerr" "github.com/openkcm/session-manager/internal/session" + "github.com/openkcm/session-manager/pkg/serviceerr" ) type ObjectType string diff --git a/internal/session/valkey/store.go b/internal/session/valkey/store.go index 1750951b..b92622ce 100644 --- a/internal/session/valkey/store.go +++ b/internal/session/valkey/store.go @@ -11,7 +11,7 @@ import ( "github.com/valkey-io/valkey-go" - "github.com/openkcm/session-manager/internal/serviceerr" + "github.com/openkcm/session-manager/pkg/serviceerr" ) type store struct { diff --git a/internal/sessionwiring/sessionwiring.go b/internal/sessionwiring/sessionwiring.go new file mode 100644 index 00000000..935f92af --- /dev/null +++ b/internal/sessionwiring/sessionwiring.go @@ -0,0 +1,93 @@ +// Package sessionwiring centralises the construction of the long-lived +// session.Manager that the HTTP API server and the housekeeper subcommand +// share. The Valkey-backed session.Repository and OAuth2 credentials.Builder +// it needs are no longer built here; both come from the module registry, +// loaded by business.Main (or the housekeeper subcommand) before this is +// invoked. +package sessionwiring + +import ( + "context" + "fmt" + + otlpaudit "github.com/openkcm/common-sdk/pkg/otlp/audit" + + sessionmanager "github.com/openkcm/session-manager" + "github.com/openkcm/session-manager/internal/config" + "github.com/openkcm/session-manager/internal/credentials" + "github.com/openkcm/session-manager/internal/session" +) + +// credentialsBuilder is the interface satisfied by a credentials module +// (e.g. credentials.module.oauth2). Defined locally so this package does not +// need to import the credentials module. +type credentialsBuilder interface { + Builder() credentials.Builder +} + +// InitSessionManager builds a session.Manager from the supplied config and +// trust module, using session repository and credential modules already loaded +// in ctx. The returned closeFn is a no-op kept for API compatibility — the +// underlying valkey client is owned by the sessionstore module and closed by +// the framework's reverse-load-order shutdown. +func InitSessionManager(ctx *sessionmanager.Context, cfg *config.Config, trust sessionmanager.Trust) (_ *session.Manager, closeFn func(), _ error) { + repo, err := SessionRepository(ctx, cfg) + if err != nil { + return nil, nil, fmt.Errorf("getting session repository: %w", err) + } + + credsBuilder, err := CredsBuilder(ctx, cfg) + if err != nil { + return nil, nil, fmt.Errorf("getting credentials builder: %w", err) + } + + auditLogger, err := otlpaudit.NewLogger(&cfg.Audit) + if err != nil { + return nil, nil, fmt.Errorf("failed to create audit logger: %w", err) + } + + sessManager, err := session.NewManager(ctx, + &cfg.SessionManager, + trust, + repo, + auditLogger, + session.WithTransportCredentials(credsBuilder), + ) + if err != nil { + return nil, nil, fmt.Errorf("failed to create session manager: %w", err) + } + + return sessManager, func() {}, nil +} + +// SessionRepository resolves the session repository module loaded under the +// ID configured in cfg.ValKey.Module() and returns its session.Repository. +func SessionRepository(ctx *sessionmanager.Context, cfg *config.Config) (session.Repository, error) { + mod, err := ctx.GetModule(cfg.ValKey.Module()) + if err != nil { + return nil, fmt.Errorf("getting session-store module %q: %w", cfg.ValKey.Module(), err) + } + repo, ok := mod.(session.Repository) + if !ok { + return nil, fmt.Errorf("module %q does not implement session.Repository", cfg.ValKey.Module()) + } + return repo, nil +} + +// CredsBuilder resolves the credentials module loaded under the ID configured +// in cfg.Credentials.Module() and returns its credentials.Builder. +func CredsBuilder(ctx *sessionmanager.Context, cfg *config.Config) (credentials.Builder, error) { + mod, err := ctx.GetModule(cfg.Credentials.Module()) + if err != nil { + return nil, fmt.Errorf("getting credentials module %q: %w", cfg.Credentials.Module(), err) + } + cb, ok := mod.(credentialsBuilder) + if !ok { + return nil, fmt.Errorf("module %q does not expose Builder()", cfg.Credentials.Module()) + } + return cb.Builder(), nil +} + +// Reference to context.Context to keep imports stable for callers using +// (ctx context.Context) signatures. +var _ context.Context = (*sessionmanager.Context)(nil) diff --git a/internal/trust/mapping.go b/internal/trust/mapping.go deleted file mode 100644 index b273807e..00000000 --- a/internal/trust/mapping.go +++ /dev/null @@ -1,25 +0,0 @@ -package trust - -type OIDCMapping struct { - IssuerURL string - Blocked bool - JWKSURI string - Audiences []string - Properties map[string]string - - // ClientID is a client_id property used for authentication. - // It is an optional value for the trust config. If the trust's client id is not specified, - // the application-global client id is used. - ClientID string -} - -func (p *OIDCMapping) GetIntrospectParameters(keys []string) map[string]string { - params := make(map[string]string, len(keys)) - for _, parameter := range keys { - value, ok := p.Properties[parameter] - if ok { - params[parameter] = value - } - } - return params -} diff --git a/internal/trust/mapping_test.go b/internal/trust/mapping_test.go deleted file mode 100644 index 9b4c6274..00000000 --- a/internal/trust/mapping_test.go +++ /dev/null @@ -1,95 +0,0 @@ -package trust - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestOIDCMappingr_GetIntrospectParameters(t *testing.T) { - tests := []struct { - name string - oidcMapping OIDCMapping - keys []string - wantParams map[string]string - }{ - { - name: "returns matching parameters", - oidcMapping: OIDCMapping{ - Properties: map[string]string{ - "client_id": "my-client-id", - "client_secret": "my-secret", - "scope": "openid", - }, - }, - keys: []string{"client_id", "client_secret"}, - wantParams: map[string]string{ - "client_id": "my-client-id", - "client_secret": "my-secret", - }, - }, - { - name: "skips missing parameters", - oidcMapping: OIDCMapping{ - Properties: map[string]string{ - "client_id": "my-client-id", - }, - }, - keys: []string{"client_id", "missing_key"}, - wantParams: map[string]string{ - "client_id": "my-client-id", - }, - }, - { - name: "returns empty map when no keys provided", - oidcMapping: OIDCMapping{ - Properties: map[string]string{ - "client_id": "my-client-id", - }, - }, - keys: []string{}, - wantParams: map[string]string{}, - }, - { - name: "returns empty map when properties is nil", - oidcMapping: OIDCMapping{ - Properties: nil, - }, - keys: []string{"client_id"}, - wantParams: map[string]string{}, - }, - { - name: "returns empty map when no keys match", - oidcMapping: OIDCMapping{ - Properties: map[string]string{ - "client_id": "my-client-id", - }, - }, - keys: []string{"non_existent_key"}, - wantParams: map[string]string{}, - }, - { - name: "handles all keys matching", - oidcMapping: OIDCMapping{ - Properties: map[string]string{ - "key1": "value1", - "key2": "value2", - "key3": "value3", - }, - }, - keys: []string{"key1", "key2", "key3"}, - wantParams: map[string]string{ - "key1": "value1", - "key2": "value2", - "key3": "value3", - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := tt.oidcMapping.GetIntrospectParameters(tt.keys) - assert.Equal(t, tt.wantParams, got) - }) - } -} diff --git a/internal/trust/repository.go b/internal/trust/repository.go deleted file mode 100644 index a81d3472..00000000 --- a/internal/trust/repository.go +++ /dev/null @@ -1,11 +0,0 @@ -package trust - -import "context" - -// OIDCMappingRepository allows to read OIDC mapping data for a tenant stored in the context. -type OIDCMappingRepository interface { - Get(ctx context.Context, tenantID string) (OIDCMapping, error) - Create(ctx context.Context, tenantID string, mapping OIDCMapping) error - Delete(ctx context.Context, tenantID string) error - Update(ctx context.Context, tenantID string, mapping OIDCMapping) error -} diff --git a/internal/trust/service.go b/internal/trust/service.go deleted file mode 100644 index 8c4a2686..00000000 --- a/internal/trust/service.go +++ /dev/null @@ -1,95 +0,0 @@ -package trust - -import ( - "context" - "errors" - "fmt" - - "github.com/openkcm/session-manager/internal/serviceerr" -) - -type Service struct { - repository OIDCMappingRepository -} - -func NewService(repo OIDCMappingRepository) *Service { - return &Service{ - repository: repo, - } -} - -func (s *Service) ApplyMapping(ctx context.Context, tenantID string, mapping OIDCMapping) error { - _, err := s.repository.Get(ctx, tenantID) - if err != nil { - err = s.repository.Create(ctx, tenantID, mapping) - if err != nil { - return fmt.Errorf("creating mapping for tenant: %w", err) - } - } else { - err = s.repository.Update(ctx, tenantID, mapping) - if err != nil { - return fmt.Errorf("updating mapping for tenant: %w", err) - } - } - - return nil -} - -// BlockMapping sets the Blocked flag to true for the OIDC mapping associated with the given tenantID. -// If the mapping is already blocked, it does nothing. -// Returns an error if the mapping cannot be retrieved or updated. -func (s *Service) BlockMapping(ctx context.Context, tenantID string) error { - mapping, err := s.repository.Get(ctx, tenantID) - if err != nil { - if errors.Is(err, serviceerr.ErrNotFound) { - return nil - } - return fmt.Errorf("getting mapping for tenant: %w", err) - } - if mapping.Blocked { - return nil - } - mapping.Blocked = true - err = s.repository.Update(ctx, tenantID, mapping) - if err != nil { - if errors.Is(err, serviceerr.ErrNotFound) { - return nil - } - return fmt.Errorf("updating mapping for blocking tenant: %w", err) - } - return nil -} - -func (s *Service) RemoveMapping(ctx context.Context, tenantID string) error { - err := s.repository.Delete(ctx, tenantID) - if err != nil { - return fmt.Errorf("deleting mapping for tenant: %w", err) - } - - return nil -} - -// UnblockMapping sets the Blocked flag to false for the OIDC mapping associated with the given tenantID. -// If the mapping is not blocked, it does nothing. -// Returns an error if the mapping cannot be retrieved or updated. -func (s *Service) UnblockMapping(ctx context.Context, tenantID string) error { - mapping, err := s.repository.Get(ctx, tenantID) - if err != nil { - if errors.Is(err, serviceerr.ErrNotFound) { - return nil - } - return fmt.Errorf("getting mapping for tenant: %w", err) - } - if !mapping.Blocked { - return nil - } - mapping.Blocked = false - err = s.repository.Update(ctx, tenantID, mapping) - if err != nil { - if errors.Is(err, serviceerr.ErrNotFound) { - return nil - } - return fmt.Errorf("updating mapping for unblocking tenant: %w", err) - } - return nil -} diff --git a/internal/trust/service_test.go b/internal/trust/service_test.go deleted file mode 100644 index f216d22e..00000000 --- a/internal/trust/service_test.go +++ /dev/null @@ -1,540 +0,0 @@ -package trust_test - -import ( - "context" - "os" - "testing" - - "github.com/gofrs/uuid/v5" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - slogctx "github.com/veqryn/slog-context" - - "github.com/openkcm/session-manager/internal/trust" -) - -var repo trust.OIDCMappingRepository - -const ( - requestURI = "http://cmk.example.com/ui" - jwksURI = "http://jwks.example.com" -) - -func TestMain(m *testing.M) { - ctx := context.Background() - r, err := createRepo(ctx) - if err != nil { - slogctx.Error(ctx, "error while creating repo", "error", err) - } - - repo = r - - code := m.Run() - os.Exit(code) -} - -func TestService_ApplyMapping(t *testing.T) { - ctx := t.Context() - - t.Run("success if", func(t *testing.T) { - t.Run("the mapping does not exist", func(t *testing.T) { - expTenantID := uuid.Must(uuid.NewV4()).String() - expMapping := trust.OIDCMapping{ - IssuerURL: uuid.Must(uuid.NewV4()).String(), - JWKSURI: jwksURI, - Audiences: []string{requestURI}, - } - - wrapper := &RepoWrapper{Repo: repo} - subj := trust.NewService(wrapper) - - err := subj.ApplyMapping(ctx, expTenantID, expMapping) - assert.NoError(t, err) - - actMapping, err := wrapper.Repo.Get(ctx, expTenantID) - assert.NoError(t, err) - assert.Equal(t, expMapping, actMapping) - }) - - t.Run("the mapping exists", func(t *testing.T) { - expTenantID := uuid.Must(uuid.NewV4()).String() - expMapping := trust.OIDCMapping{ - IssuerURL: uuid.Must(uuid.NewV4()).String(), - JWKSURI: jwksURI, - Audiences: []string{requestURI}, - } - - wrapper := &RepoWrapper{Repo: repo} - subj := trust.NewService(wrapper) - - err := subj.ApplyMapping(ctx, expTenantID, expMapping) - assert.NoError(t, err) - - expUpdatedMapping := trust.OIDCMapping{ - IssuerURL: expMapping.IssuerURL, - JWKSURI: "http://updated-jwks.example.com", - Audiences: []string{requestURI, "http://new-aud.example.com"}, - } - - err = subj.ApplyMapping(ctx, expTenantID, expUpdatedMapping) - assert.NoError(t, err) - - actMapping, err := wrapper.Repo.Get(ctx, expTenantID) - assert.NoError(t, err) - assert.Equal(t, expUpdatedMapping, actMapping) - }) - }) - - t.Run("should return error if", func(t *testing.T) { - t.Run("Create returns an error", func(t *testing.T) { - expTenantID := uuid.Must(uuid.NewV4()).String() - expMapping := trust.OIDCMapping{ - IssuerURL: uuid.Must(uuid.NewV4()).String(), - JWKSURI: jwksURI, - Audiences: []string{requestURI}, - } - - wrapper := &RepoWrapper{Repo: repo} - noOfCalls := 0 - wrapper.MockCreate = func(ctx context.Context, tenantID string, mapping trust.OIDCMapping) error { - assert.Equal(t, expTenantID, tenantID) - assert.Equal(t, expMapping, mapping) - noOfCalls++ - return assert.AnError - } - - subj := trust.NewService(wrapper) - err := subj.ApplyMapping(ctx, expTenantID, expMapping) - - assert.ErrorIs(t, err, assert.AnError) - assert.Equal(t, 1, noOfCalls) - }) - - t.Run("Update returns an error", func(t *testing.T) { - expTenantID := uuid.Must(uuid.NewV4()).String() - expMapping := trust.OIDCMapping{ - IssuerURL: uuid.Must(uuid.NewV4()).String(), - JWKSURI: jwksURI, - Audiences: []string{requestURI}, - } - - wrapper := &RepoWrapper{Repo: repo} - noOfCalls := 0 - wrapper.MockUpdate = func(ctx context.Context, tenantID string, mapping trust.OIDCMapping) error { - assert.Equal(t, expTenantID, tenantID) - assert.Equal(t, expMapping, mapping) - noOfCalls++ - return assert.AnError - } - subj := trust.NewService(wrapper) - - err := subj.ApplyMapping(ctx, expTenantID, expMapping) - assert.NoError(t, err) - err = subj.ApplyMapping(ctx, expTenantID, expMapping) - - assert.ErrorIs(t, err, assert.AnError) - assert.Equal(t, 1, noOfCalls) - }) - }) -} - -func TestService_BlockMapping(t *testing.T) { - ctx := t.Context() - - t.Run("success if ", func(t *testing.T) { - t.Run("the mapping is unblocked", func(t *testing.T) { - // given - expTenantID := uuid.Must(uuid.NewV4()).String() - expUnblockedMapping := trust.OIDCMapping{ - IssuerURL: uuid.Must(uuid.NewV4()).String(), - Blocked: false, - JWKSURI: jwksURI, - Audiences: []string{requestURI}, - } - - wrapper := &RepoWrapper{Repo: repo} - err := wrapper.Repo.Create(ctx, expTenantID, expUnblockedMapping) - require.NoError(t, err) - subj := trust.NewService(wrapper) - - // when - err = subj.BlockMapping(ctx, expTenantID) - - // then - assert.NoError(t, err) - - actMapping, err := wrapper.Repo.Get(ctx, expTenantID) - assert.NoError(t, err) - assert.True(t, actMapping.Blocked) - assert.Equal(t, expUnblockedMapping.IssuerURL, actMapping.IssuerURL) - assert.Equal(t, expUnblockedMapping.Audiences, actMapping.Audiences) - assert.Equal(t, expUnblockedMapping.JWKSURI, actMapping.JWKSURI) - }) - - t.Run("the mapping is blocked then it should not call Update", func(t *testing.T) { - // given - expTenantID := uuid.Must(uuid.NewV4()).String() - expBlockedMapping := trust.OIDCMapping{ - IssuerURL: uuid.Must(uuid.NewV4()).String(), - Blocked: true, - JWKSURI: jwksURI, - Audiences: []string{requestURI}, - } - repoWrapper := &RepoWrapper{Repo: repo} - err := repoWrapper.Repo.Create(ctx, expTenantID, expBlockedMapping) - require.NoError(t, err) - - noOfUpdateCalls := 0 - repoWrapper.MockUpdate = func(ctx context.Context, tenantID string, mapping trust.OIDCMapping) error { - noOfUpdateCalls++ - return assert.AnError - } - subj := trust.NewService(repoWrapper) - - // when - err = subj.BlockMapping(t.Context(), expTenantID) - - // then - assert.NoError(t, err) - assert.Equal(t, 0, noOfUpdateCalls) - - actMapping, err := repoWrapper.Repo.Get(ctx, expTenantID) - assert.NoError(t, err) - assert.Equal(t, expBlockedMapping, actMapping) - }) - t.Run("the mapping is not found during the Update", func(t *testing.T) { - // given - expTenantID := uuid.Must(uuid.NewV4()).String() - expBlockedMapping := trust.OIDCMapping{ - IssuerURL: uuid.Must(uuid.NewV4()).String(), - Blocked: false, - JWKSURI: jwksURI, - Audiences: []string{requestURI}, - } - repoWrapper := &RepoWrapper{Repo: repo} - err := repoWrapper.Repo.Create(ctx, expTenantID, expBlockedMapping) - require.NoError(t, err) - - noOfUpdateCalls := 0 - repoWrapper.MockUpdate = func(ctx context.Context, tenantID string, mapping trust.OIDCMapping) error { - noOfUpdateCalls++ - // delete the mapping before updating to return an error - err := repoWrapper.Repo.Delete(ctx, expTenantID) - assert.NoError(t, err) - return nil - } - subj := trust.NewService(repoWrapper) - - // when - err = subj.BlockMapping(t.Context(), expTenantID) - - // then - assert.NoError(t, err) - assert.Equal(t, 1, noOfUpdateCalls) - }) - t.Run("the mapping is not found", func(t *testing.T) { - // given - expTenantID := uuid.Must(uuid.NewV4()).String() - repoWrapper := &RepoWrapper{Repo: repo} - - subj := trust.NewService(repoWrapper) - - // when - err := subj.BlockMapping(t.Context(), expTenantID) - - // then - assert.NoError(t, err) - }) - }) - - t.Run("should return error", func(t *testing.T) { - t.Run("if Get returns an error", func(t *testing.T) { - // given - expTenantID := uuid.Must(uuid.NewV4()).String() - repoWrapper := &RepoWrapper{Repo: repo} - - noOfGetCalls := 0 - repoWrapper.MockGet = func(ctx context.Context, tenantID string) (trust.OIDCMapping, error) { - assert.Equal(t, expTenantID, tenantID) - noOfGetCalls++ - return trust.OIDCMapping{}, assert.AnError - } - subj := trust.NewService(repoWrapper) - - // when - err := subj.BlockMapping(t.Context(), expTenantID) - - // then - assert.ErrorIs(t, err, assert.AnError) - assert.Equal(t, 1, noOfGetCalls) - }) - - t.Run("if Update returns an error", func(t *testing.T) { - // given - expTenantID := uuid.Must(uuid.NewV4()).String() - expMapping := trust.OIDCMapping{ - IssuerURL: uuid.Must(uuid.NewV4()).String(), - Blocked: false, - JWKSURI: jwksURI, - Audiences: []string{requestURI}, - } - repoWrapper := &RepoWrapper{Repo: repo} - err := repoWrapper.Repo.Create(ctx, expTenantID, expMapping) - require.NoError(t, err) - - noOfUpdateCalls := 0 - repoWrapper.MockUpdate = func(ctx context.Context, tenantID string, mapping trust.OIDCMapping) error { - assert.Equal(t, expTenantID, tenantID) - noOfUpdateCalls++ - return assert.AnError - } - subj := trust.NewService(repoWrapper) - - // when - err = subj.BlockMapping(t.Context(), expTenantID) - - // then - assert.ErrorIs(t, err, assert.AnError) - assert.Equal(t, 1, noOfUpdateCalls) - - actMapping, err := repoWrapper.Repo.Get(ctx, expTenantID) - assert.NoError(t, err) - assert.Equal(t, expMapping, actMapping) - }) - }) -} - -func TestService_UnblockMapping(t *testing.T) { - ctx := t.Context() - - t.Run("success if ", func(t *testing.T) { - t.Run("the mapping is blocked", func(t *testing.T) { - // given - expTenantID := uuid.Must(uuid.NewV4()).String() - expBlockedMapping := trust.OIDCMapping{ - IssuerURL: uuid.Must(uuid.NewV4()).String(), - Blocked: true, - JWKSURI: jwksURI, - Audiences: []string{requestURI}, - } - - wrapper := &RepoWrapper{Repo: repo} - err := wrapper.Repo.Create(ctx, expTenantID, expBlockedMapping) - require.NoError(t, err) - subj := trust.NewService(wrapper) - - // when - err = subj.UnblockMapping(t.Context(), expTenantID) - - // then - assert.NoError(t, err) - - actMapping, err := wrapper.Repo.Get(ctx, expTenantID) - assert.NoError(t, err) - assert.False(t, actMapping.Blocked) - assert.Equal(t, expBlockedMapping.IssuerURL, actMapping.IssuerURL) - assert.Equal(t, expBlockedMapping.Audiences, actMapping.Audiences) - assert.Equal(t, expBlockedMapping.JWKSURI, actMapping.JWKSURI) - }) - - t.Run("the mapping is unblocked then it should not call Update", func(t *testing.T) { - // given - expTenantID := uuid.Must(uuid.NewV4()).String() - expUnblockedMapping := trust.OIDCMapping{ - IssuerURL: uuid.Must(uuid.NewV4()).String(), - Blocked: false, - JWKSURI: jwksURI, - Audiences: []string{requestURI}, - } - repoWrapper := &RepoWrapper{Repo: repo} - err := repoWrapper.Repo.Create(ctx, expTenantID, expUnblockedMapping) - require.NoError(t, err) - - noOfUpdateCalls := 0 - repoWrapper.MockUpdate = func(ctx context.Context, tenantID string, mapping trust.OIDCMapping) error { - noOfUpdateCalls++ - return assert.AnError - } - subj := trust.NewService(repoWrapper) - - // when - err = subj.UnblockMapping(t.Context(), expTenantID) - - // then - assert.NoError(t, err) - assert.Equal(t, 0, noOfUpdateCalls) - - actMapping, err := repoWrapper.Repo.Get(ctx, expTenantID) - assert.NoError(t, err) - assert.False(t, actMapping.Blocked) - }) - t.Run("the mapping is not found during the Update", func(t *testing.T) { - // given - expTenantID := uuid.Must(uuid.NewV4()).String() - expUnblockedMapping := trust.OIDCMapping{ - IssuerURL: uuid.Must(uuid.NewV4()).String(), - Blocked: true, - JWKSURI: jwksURI, - Audiences: []string{requestURI}, - } - repoWrapper := &RepoWrapper{Repo: repo} - err := repoWrapper.Repo.Create(ctx, expTenantID, expUnblockedMapping) - require.NoError(t, err) - - noOfUpdateCalls := 0 - repoWrapper.MockUpdate = func(ctx context.Context, tenantID string, mapping trust.OIDCMapping) error { - noOfUpdateCalls++ - // delete the mapping before updating to return an error - err := repoWrapper.Repo.Delete(ctx, expTenantID) - assert.NoError(t, err) - return nil - } - subj := trust.NewService(repoWrapper) - - // when - err = subj.UnblockMapping(t.Context(), expTenantID) - - // then - assert.NoError(t, err) - assert.Equal(t, 1, noOfUpdateCalls) - }) - t.Run("the mapping is not found", func(t *testing.T) { - // given - expTenantID := uuid.Must(uuid.NewV4()).String() - repoWrapper := &RepoWrapper{Repo: repo} - - subj := trust.NewService(repoWrapper) - - // when - err := subj.UnblockMapping(t.Context(), expTenantID) - - // then - assert.NoError(t, err) - }) - }) - t.Run("should return error", func(t *testing.T) { - t.Run("if Get returns an error", func(t *testing.T) { - // given - expTenantID := uuid.Must(uuid.NewV4()).String() - mockRepo := &RepoWrapper{Repo: repo} - - noOfGetTenantCalls := 0 - mockRepo.MockGet = func(ctx context.Context, tenantID string) (trust.OIDCMapping, error) { - assert.Equal(t, expTenantID, tenantID) - noOfGetTenantCalls++ - return trust.OIDCMapping{}, assert.AnError - } - subj := trust.NewService(mockRepo) - - // when - err := subj.UnblockMapping(t.Context(), expTenantID) - - // then - assert.ErrorIs(t, err, assert.AnError) - assert.Equal(t, 1, noOfGetTenantCalls) - }) - - t.Run("if Update returns an error", func(t *testing.T) { - // given - expTenantIDtoUpdate := uuid.Must(uuid.NewV4()).String() - expBlockedMapping := trust.OIDCMapping{ - IssuerURL: uuid.Must(uuid.NewV4()).String(), - Blocked: true, - JWKSURI: jwksURI, - Audiences: []string{requestURI}, - } - repoWrapper := &RepoWrapper{Repo: repo} - err := repoWrapper.Repo.Create(ctx, expTenantIDtoUpdate, expBlockedMapping) - require.NoError(t, err) - - noOfUpdateCalls := 0 - repoWrapper.MockUpdate = func(ctx context.Context, tenantID string, mapping trust.OIDCMapping) error { - assert.Equal(t, expTenantIDtoUpdate, tenantID) - noOfUpdateCalls++ - return assert.AnError - } - subj := trust.NewService(repoWrapper) - - // when - err = subj.UnblockMapping(t.Context(), expTenantIDtoUpdate) - - // then - assert.ErrorIs(t, err, assert.AnError) - assert.Equal(t, 1, noOfUpdateCalls) - - actMapping, err := repoWrapper.Repo.Get(ctx, expTenantIDtoUpdate) - assert.NoError(t, err) - assert.Equal(t, expBlockedMapping, actMapping) - }) - }) -} - -func TestService_RemoveMapping(t *testing.T) { - ctx := t.Context() - - t.Run("success if", func(t *testing.T) { - t.Run("the mapping exists", func(t *testing.T) { - // given - expTenantID := uuid.Must(uuid.NewV4()).String() - expMapping := trust.OIDCMapping{ - IssuerURL: uuid.Must(uuid.NewV4()).String(), - JWKSURI: jwksURI, - Audiences: []string{requestURI}, - } - - wrapper := &RepoWrapper{Repo: repo} - err := wrapper.Repo.Create(ctx, expTenantID, expMapping) - require.NoError(t, err) - - subj := trust.NewService(wrapper) - - // when - err = subj.RemoveMapping(ctx, expTenantID) - - // then - assert.NoError(t, err) - - // verify the mapping was deleted - _, err = wrapper.Repo.Get(ctx, expTenantID) - assert.Error(t, err) - }) - }) - - t.Run("should return error if", func(t *testing.T) { - t.Run("the mapping does not exist", func(t *testing.T) { - // given - expTenantID := uuid.Must(uuid.NewV4()).String() - wrapper := &RepoWrapper{Repo: repo} - subj := trust.NewService(wrapper) - - // when - err := subj.RemoveMapping(ctx, expTenantID) - - // then - assert.Error(t, err) - }) - - t.Run("Delete returns an error", func(t *testing.T) { - // given - expTenantID := uuid.Must(uuid.NewV4()).String() - wrapper := &RepoWrapper{Repo: repo} - - noOfDeleteCalls := 0 - wrapper.MockDelete = func(ctx context.Context, tenantID string) error { - assert.Equal(t, expTenantID, tenantID) - noOfDeleteCalls++ - return assert.AnError - } - - subj := trust.NewService(wrapper) - - // when - err := subj.RemoveMapping(ctx, expTenantID) - - // then - assert.ErrorIs(t, err, assert.AnError) - assert.Equal(t, 1, noOfDeleteCalls) - }) - }) -} diff --git a/internal/trust/trustsql/errors.go b/internal/trust/trustsql/errors.go deleted file mode 100644 index 31eff64b..00000000 --- a/internal/trust/trustsql/errors.go +++ /dev/null @@ -1,18 +0,0 @@ -package trustsql - -import ( - "errors" - - "github.com/jackc/pgx/v5/pgconn" - - "github.com/openkcm/session-manager/internal/serviceerr" -) - -func handlePgError(err error) (error, bool) { - var pgErr *pgconn.PgError - if errors.As(err, &pgErr) && pgErr.Code == "23505" { - return serviceerr.ErrConflict, true - } - - return err, false -} diff --git a/internal/trust/trustsql/errors_test.go b/internal/trust/trustsql/errors_test.go deleted file mode 100644 index 7d339eb3..00000000 --- a/internal/trust/trustsql/errors_test.go +++ /dev/null @@ -1,47 +0,0 @@ -package trustsql - -import ( - "errors" - "testing" - - "github.com/jackc/pgx/v5/pgconn" - "github.com/stretchr/testify/assert" - - "github.com/openkcm/session-manager/internal/serviceerr" -) - -var errUnknown = errors.New("unknown error") - -func Test_handlePgError(t *testing.T) { - tests := []struct { - name string - inputErr error - errTarget error - wantOk bool - }{ - { - name: "23505 error", - inputErr: &pgconn.PgError{Code: "23505"}, - errTarget: serviceerr.ErrConflict, - wantOk: true, - }, - { - name: "Unknown error", - inputErr: errUnknown, - errTarget: errUnknown, - wantOk: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - gotErr, ok := handlePgError(tt.inputErr) - if !assert.ErrorIsf(t, gotErr, tt.errTarget, "handlePgError() error %v", gotErr) { - return - } - - if !assert.Equalf(t, tt.wantOk, ok, "handlePgError() OK = %v, want = %v", ok, tt.wantOk) { - return - } - }) - } -} diff --git a/internal/trust/trustsql/internal/queries/models.go b/internal/trust/trustsql/internal/queries/models.go deleted file mode 100644 index 8b643452..00000000 --- a/internal/trust/trustsql/internal/queries/models.go +++ /dev/null @@ -1,20 +0,0 @@ -// Code generated by sqlc. DO NOT EDIT. -// versions: -// sqlc v1.31.1 - -package queries - -import ( - "github.com/jackc/pgx/v5/pgtype" -) - -type Trust struct { - TenantID string `db:"tenant_id"` - Blocked bool `db:"blocked"` - Issuer string `db:"issuer"` - JwksUri string `db:"jwks_uri"` - Audiences []string `db:"audiences"` - Properties []byte `db:"properties"` - CreatedAt pgtype.Timestamp `db:"created_at"` - ClientID pgtype.Text `db:"client_id"` -} diff --git a/internal/trust/trustsql/internal/queries/queries.sql.go b/internal/trust/trustsql/internal/queries/queries.sql.go deleted file mode 100644 index 9f71b1e8..00000000 --- a/internal/trust/trustsql/internal/queries/queries.sql.go +++ /dev/null @@ -1,141 +0,0 @@ -// Code generated by sqlc. DO NOT EDIT. -// versions: -// sqlc v1.31.1 -// source: queries.sql - -package queries - -import ( - "context" - - "github.com/jackc/pgx/v5/pgtype" -) - -const createOIDCMapping = `-- name: CreateOIDCMapping :exec -INSERT INTO trust ( - tenant_id, - blocked, - issuer, - jwks_uri, - audiences, - properties, - client_id) -VALUES ( - $1, - $2, - $3, - $4, - COALESCE($5::text[], '{}'::text[]), - $6, - $7) -` - -type CreateOIDCMappingParams struct { - TenantID string `db:"tenant_id"` - Blocked bool `db:"blocked"` - Issuer string `db:"issuer"` - JwksUri string `db:"jwks_uri"` - Audiences []string `db:"audiences"` - Properties []byte `db:"properties"` - ClientID pgtype.Text `db:"client_id"` -} - -func (q *Queries) CreateOIDCMapping(ctx context.Context, arg CreateOIDCMappingParams) error { - _, err := q.db.Exec(ctx, createOIDCMapping, - arg.TenantID, - arg.Blocked, - arg.Issuer, - arg.JwksUri, - arg.Audiences, - arg.Properties, - arg.ClientID, - ) - return err -} - -const deleteOIDCMapping = `-- name: DeleteOIDCMapping :execrows -DELETE FROM trust -WHERE tenant_id = $1 -` - -func (q *Queries) DeleteOIDCMapping(ctx context.Context, tenantID string) (int64, error) { - result, err := q.db.Exec(ctx, deleteOIDCMapping, tenantID) - if err != nil { - return 0, err - } - return result.RowsAffected(), nil -} - -const getOIDCMapping = `-- name: GetOIDCMapping :one -SELECT - issuer, - blocked, - jwks_uri, - audiences, - properties, - client_id -FROM trust -WHERE tenant_id = $1 -` - -type GetOIDCMappingRow struct { - Issuer string `db:"issuer"` - Blocked bool `db:"blocked"` - JwksUri string `db:"jwks_uri"` - Audiences []string `db:"audiences"` - Properties []byte `db:"properties"` - ClientID pgtype.Text `db:"client_id"` -} - -func (q *Queries) GetOIDCMapping(ctx context.Context, tenantID string) (GetOIDCMappingRow, error) { - row := q.db.QueryRow(ctx, getOIDCMapping, tenantID) - var i GetOIDCMappingRow - err := row.Scan( - &i.Issuer, - &i.Blocked, - &i.JwksUri, - &i.Audiences, - &i.Properties, - &i.ClientID, - ) - return i, err -} - -const updateOIDCMapping = `-- name: UpdateOIDCMapping :execrows -UPDATE trust -SET - blocked = $1, - issuer = $2, - jwks_uri = $3, - audiences = COALESCE($4::text[], '{}'::text[]), - properties = $5, - client_id = $6 -WHERE - tenant_id = $7 -` - -type UpdateOIDCMappingParams struct { - Blocked bool `db:"blocked"` - Issuer string `db:"issuer"` - JwksUri string `db:"jwks_uri"` - Audiences []string `db:"audiences"` - Properties []byte `db:"properties"` - ClientID pgtype.Text `db:"client_id"` - TenantID string `db:"tenant_id"` -} - -func (q *Queries) UpdateOIDCMapping(ctx context.Context, arg UpdateOIDCMappingParams) (int64, error) { - result, err := q.db.Exec(ctx, updateOIDCMapping, - arg.Blocked, - arg.Issuer, - arg.JwksUri, - arg.Audiences, - arg.Properties, - arg.ClientID, - arg.TenantID, - ) - if err != nil { - return 0, err - } - return result.RowsAffected(), nil -} diff --git a/internal/trust/trustsql/repository.go b/internal/trust/trustsql/repository.go deleted file mode 100644 index b2d71625..00000000 --- a/internal/trust/trustsql/repository.go +++ /dev/null @@ -1,157 +0,0 @@ -package trustsql - -import ( - "context" - "encoding/json" - "errors" - "fmt" - - "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgxpool" - "go.opentelemetry.io/otel" - - "github.com/openkcm/session-manager/internal/serviceerr" - "github.com/openkcm/session-manager/internal/trust" - "github.com/openkcm/session-manager/internal/trust/trustsql/internal/queries" -) - -type Repository struct { - db *pgxpool.Pool - queries *queries.Queries -} - -func NewRepository(db *pgxpool.Pool) *Repository { - return &Repository{ - db: db, - queries: queries.New(db), - } -} - -func (r *Repository) Get(ctx context.Context, tenantID string) (trust.OIDCMapping, error) { - tracer := otel.GetTracerProvider() - ctx, span := tracer.Tracer("").Start(ctx, "get_oidc_mapping_sql") - defer span.End() - - row, err := r.queries.GetOIDCMapping(ctx, tenantID) - if err != nil { - span.RecordError(err) - if errors.Is(err, pgx.ErrNoRows) { - return trust.OIDCMapping{}, serviceerr.ErrNotFound - } - - return trust.OIDCMapping{}, err - } - - properties := make(map[string]string) - if len(row.Properties) > 0 { - err := json.Unmarshal(row.Properties, &properties) - if err != nil { - return trust.OIDCMapping{}, fmt.Errorf("unmarshalling properties: %w", err) - } - } - - return trust.OIDCMapping{ - IssuerURL: row.Issuer, - Blocked: row.Blocked, - JWKSURI: row.JwksUri, - Audiences: row.Audiences, - Properties: properties, - ClientID: row.ClientID.String, - }, nil -} - -func (r *Repository) Create(ctx context.Context, tenantID string, mapping trust.OIDCMapping) error { - tracer := otel.GetTracerProvider() - ctx, span := tracer.Tracer("").Start(ctx, "create_oidc_mapping_sql") - defer span.End() - - properties, err := r.marshalProperties(mapping) - if err != nil { - return fmt.Errorf("marshaling properties: %w", err) - } - - if err := r.queries.CreateOIDCMapping(ctx, queries.CreateOIDCMappingParams{ - TenantID: tenantID, - Blocked: mapping.Blocked, - Issuer: mapping.IssuerURL, - JwksUri: mapping.JWKSURI, - Audiences: mapping.Audiences, - Properties: properties, - ClientID: pgTextOrNull(mapping.ClientID), - }); err != nil { - span.RecordError(err) - if err, ok := handlePgError(err); ok { - return err - } - - return fmt.Errorf("inserting into trust: %w", err) - } - - return nil -} - -func (r *Repository) Delete(ctx context.Context, tenantID string) error { - tracer := otel.GetTracerProvider() - ctx, span := tracer.Tracer("").Start(ctx, "delete_oidc_mapping_sql") - defer span.End() - - affected, err := r.queries.DeleteOIDCMapping(ctx, tenantID) - if err != nil { - span.RecordError(err) - return fmt.Errorf("executing sql query: %w", err) - } - - if affected == 0 { - return serviceerr.ErrNotFound - } - - return nil -} - -func (r *Repository) Update(ctx context.Context, tenantID string, mapping trust.OIDCMapping) error { - tracer := otel.GetTracerProvider() - ctx, span := tracer.Tracer("").Start(ctx, "update_oidc_mapping_sql") - defer span.End() - - properties, err := r.marshalProperties(mapping) - if err != nil { - span.RecordError(err) - return err - } - - affected, err := r.queries.UpdateOIDCMapping(ctx, queries.UpdateOIDCMappingParams{ - Blocked: mapping.Blocked, - Issuer: mapping.IssuerURL, - JwksUri: mapping.JWKSURI, - Audiences: mapping.Audiences, - Properties: properties, - ClientID: pgTextOrNull(mapping.ClientID), - TenantID: tenantID, - }) - if err != nil { - span.RecordError(err) - return fmt.Errorf("updating trust: %w", err) - } - - if affected == 0 { - return serviceerr.ErrNotFound - } - - return nil -} - -func (r *Repository) marshalProperties(mapping trust.OIDCMapping) ([]byte, error) { - propsBytes, err := json.Marshal(mapping.Properties) - if err != nil { - return nil, fmt.Errorf("marshaling json: %w", err) - } - return propsBytes, nil -} - -func pgTextOrNull(s string) pgtype.Text { - return pgtype.Text{ - String: s, - Valid: s != "", - } -} diff --git a/internal/trust/trustsql/repository_test.go b/internal/trust/trustsql/repository_test.go deleted file mode 100644 index 74169969..00000000 --- a/internal/trust/trustsql/repository_test.go +++ /dev/null @@ -1,302 +0,0 @@ -package trustsql_test - -import ( - "context" - "errors" - "fmt" - "os" - "testing" - - "github.com/google/go-cmp/cmp" - "github.com/jackc/pgx/v5/pgxpool" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/openkcm/session-manager/internal/dbtest/postgrestest" - "github.com/openkcm/session-manager/internal/serviceerr" - "github.com/openkcm/session-manager/internal/trust" - "github.com/openkcm/session-manager/internal/trust/trustsql" -) - -var dbPool *pgxpool.Pool - -func TestMain(m *testing.M) { - ctx := context.Background() - - pool, _, terminate := postgrestest.Start(ctx) - defer terminate(ctx) - - dbPool = pool - - code := m.Run() - os.Exit(code) -} - -func TestRepository_Get(t *testing.T) { - tests := []struct { - name string - tenantID string - wantMapping trust.OIDCMapping - assertErr assert.ErrorAssertionFunc - }{ - { - name: "Success", - tenantID: "tenant1-id", - wantMapping: trust.OIDCMapping{ - IssuerURL: "url-one", - Blocked: false, - JWKSURI: "", - Audiences: make([]string, 0), - Properties: make(map[string]string), - }, - assertErr: assert.NoError, - }, - { - name: "Error does not exist", - tenantID: "does-not-exist", - assertErr: assert.Error, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - r := trustsql.NewRepository(dbPool) - - gotMapping, err := r.Get(t.Context(), tt.tenantID) - if !tt.assertErr(t, err, fmt.Sprintf("Repository.Get() error %v", err)) || err != nil { - assert.Zerof(t, gotMapping, "Repository.Get() extected zero value if an error is returned, got %v", gotMapping) - return - } - - assert.Equal(t, tt.wantMapping, gotMapping, "Repository.Get()") - }) - } -} - -func TestRepository_Create(t *testing.T) { - tests := []struct { - name string - tenantID string - mapping trust.OIDCMapping - assertErr assert.ErrorAssertionFunc - }{ - { - name: "Create succeeds", - tenantID: "tenant-id-create-success", - mapping: trust.OIDCMapping{ - IssuerURL: "http://oidc-success.example.com", - Blocked: false, - JWKSURI: "jwks.example.com", - Audiences: []string{"cmk.example.com"}, - Properties: map[string]string{ - "prop1": "prop1val", - }, - }, - assertErr: assert.NoError, - }, - { - name: "Duplicate", - tenantID: "tenant1-id", - mapping: trust.OIDCMapping{ - IssuerURL: "url-one", - Blocked: false, - JWKSURI: "jwks.example.com", - Audiences: []string{"cmk.example.com"}, - Properties: map[string]string{ - "prop1": "prop1val", - }, - }, - assertErr: assert.Error, - }, - { - name: "Create without JWKSURI and Audiences succeeds", - tenantID: "tenant-id-create-without-jwks-aud-success", - mapping: trust.OIDCMapping{ - IssuerURL: "http://oidc-success-2.example.com", - Blocked: false, - Audiences: []string{}, - }, - assertErr: assert.NoError, - }, - { - name: "Create without JWKSURI succeeds", - tenantID: "tenant-id-create-without-jwks-success", - mapping: trust.OIDCMapping{ - IssuerURL: "http://oidc-success-3.example.com", - Blocked: false, - Audiences: []string{"cmk.example.com"}, - }, - assertErr: assert.NoError, - }, - { - name: "Create without Audiences succeeds", - tenantID: "tenant-id-create-without-aud-success", - mapping: trust.OIDCMapping{ - IssuerURL: "http://oidc-success-4.example.com", - Blocked: false, - JWKSURI: "jwks.example.com", - Audiences: []string{}, - }, - assertErr: assert.NoError, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Given - r := trustsql.NewRepository(dbPool) - - // When - err := r.Create(t.Context(), tt.tenantID, tt.mapping) - if !tt.assertErr(t, err, fmt.Sprintf("Repository.Create() error %v", err)) || err != nil { - return - } - - // Then - mapping, err := r.Get(t.Context(), tt.tenantID) - require.NoError(t, err) - - if diff := cmp.Diff(tt.mapping, mapping); diff != "" { - t.Fatalf("Unexpected mapping in the database (-want, +got):\n%s", diff) - } - }) - } -} - -func TestRepository_Delete(t *testing.T) { - const tenantID = "tenant-id-delete-success" - - mapping := trust.OIDCMapping{ - IssuerURL: "http://oidc-to-delete.example.com", - Blocked: false, - JWKSURI: "jwks.example.com", - Audiences: []string{"cmk.example.com"}, - } - - r := trustsql.NewRepository(dbPool) - err := r.Create(t.Context(), tenantID, mapping) - require.NoError(t, err, "Inserting test data") - - tests := []struct { - name string - tenantID string - mapping trust.OIDCMapping - assertErr assert.ErrorAssertionFunc - }{ - { - name: "Delete tenant", - tenantID: tenantID, - mapping: mapping, - assertErr: assert.NoError, - }, - { - name: "Error does not exist", - tenantID: "does-not-exist", - mapping: trust.OIDCMapping{IssuerURL: "does-not-exist"}, - assertErr: assert.Error, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := r.Delete(t.Context(), tt.tenantID) - if !tt.assertErr(t, err, fmt.Sprintf("Repository.Delete() error %v", err)) || err != nil { - return - } - - gotMapping, err := r.Get(t.Context(), tt.tenantID) - if !errors.Is(err, serviceerr.ErrNotFound) { - t.Error("The mapping is expected to be deleted") - } - assert.Zero(t, gotMapping, "The mapping is expected to be deleted, instead a value is returned") - }) - } -} - -func TestRepository_Update(t *testing.T) { - const tenantID = "tenant-id-update-success" - - mapping := trust.OIDCMapping{ - IssuerURL: "http://oidc-to-update.example.com", - Blocked: false, - JWKSURI: "jwks.example.com", - Audiences: []string{"cmk.example.com"}, - } - - r := trustsql.NewRepository(dbPool) - err := r.Create(t.Context(), tenantID, mapping) - require.NoError(t, err, "Inserting test data") - - tests := []struct { - name string - tenantID string - mapping trust.OIDCMapping - assertErr assert.ErrorAssertionFunc - }{ - { - name: "Update succeeds", - tenantID: tenantID, - mapping: trust.OIDCMapping{ - IssuerURL: mapping.IssuerURL, - Blocked: true, - JWKSURI: "jwks-updated.example.com", - Audiences: append(mapping.Audiences, "new-audience.example.com"), - }, - assertErr: assert.NoError, - }, - { - name: "Does not exist", - tenantID: "does-not-exist", - mapping: trust.OIDCMapping{ - IssuerURL: "does-not-exist", - Blocked: true, - JWKSURI: "jwks-updated.example.com", - Audiences: append(mapping.Audiences, "new-audience.example.com"), - }, - assertErr: assert.Error, - }, - { - name: "Update without JWKSURI and Audiences succeeds", - tenantID: tenantID, - mapping: trust.OIDCMapping{ - IssuerURL: mapping.IssuerURL, - Blocked: true, - Audiences: []string{}, - }, - assertErr: assert.NoError, - }, - { - name: "Update without JWKSURI succeeds", - tenantID: tenantID, - mapping: trust.OIDCMapping{ - IssuerURL: mapping.IssuerURL, - Blocked: true, - Audiences: append(mapping.Audiences, "new-audience.example.com"), - }, - assertErr: assert.NoError, - }, - { - name: "Update without Audiences succeeds", - tenantID: tenantID, - mapping: trust.OIDCMapping{ - IssuerURL: mapping.IssuerURL, - Blocked: true, - JWKSURI: "jwks-updated.example.com", - Audiences: []string{}, - }, - assertErr: assert.NoError, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := r.Update(t.Context(), tt.tenantID, tt.mapping) - if !tt.assertErr(t, err, fmt.Sprintf("Repository.Update() error %v", err)) || err != nil { - return - } - - gotMapping, err := r.Get(t.Context(), tt.tenantID) - require.NoError(t, err) - - if diff := cmp.Diff(tt.mapping, gotMapping); diff != "" { - t.Fatalf("Unexpected mapping in the database (-want, +got):\n%s", diff) - } - }) - } -} diff --git a/modules.go b/modules.go new file mode 100644 index 00000000..280652e3 --- /dev/null +++ b/modules.go @@ -0,0 +1,62 @@ +package sessionmanager + +import ( + "fmt" + "iter" + "maps" + "sync" +) + +var ( + modules = make(map[string]ModuleInfo) + modulesMu sync.RWMutex +) + +func RegisterModule(module Module) { + modulesMu.Lock() + defer modulesMu.Unlock() + + info := module.Module() + + if _, ok := modules[info.ID]; ok { + panic(`module "` + info.ID + `" has already been registered`) + } + + modules[info.ID] = info +} + +func GetModule(id string) (ModuleInfo, error) { + modulesMu.RLock() + defer modulesMu.RUnlock() + mod, ok := modules[id] + if !ok { + return ModuleInfo{}, fmt.Errorf("module %q is not registered", id) + } + + return mod, nil +} + +func Modules() iter.Seq[ModuleInfo] { + modulesMu.RLock() + defer modulesMu.RUnlock() + + return maps.Values(maps.Clone(modules)) +} + +type Module interface { + Module() ModuleInfo +} + +type ModuleInfo struct { + ID string + New func() Module +} + +type Provisioner interface { + Provision(ctx *Context) error +} + +type App interface { + Start() error + Stop() error +} diff --git a/modules/app/grpcserver/module.go b/modules/app/grpcserver/module.go new file mode 100644 index 00000000..613f974f --- /dev/null +++ b/modules/app/grpcserver/module.go @@ -0,0 +1,160 @@ +// Package grpcserver provides the app.module.grpcserver app module: a +// long-running gRPC server that hosts service modules registered through its +// services: config block. The Service interface is exported here so that +// service modules can satisfy it without importing google.golang.org/grpc +// from the top-level sessionmanager package. +package grpcserver + +import ( + "context" + "errors" + "fmt" + "net" + "sync" + + "github.com/openkcm/common-sdk/pkg/commongrpc" + "google.golang.org/grpc" + + slogctx "github.com/veqryn/slog-context" + + sessionmanager "github.com/openkcm/session-manager" + "github.com/openkcm/session-manager/internal/config" +) + +const moduleID = "app.module.grpcserver" + +// Service is the interface that every gRPC service module loaded under an +// app.module.grpcserver entry must satisfy. The grpc app calls Register on +// each child service, in declaration order, before invoking Serve. +type Service interface { + Register(s grpc.ServiceRegistrar) +} + +func init() { + sessionmanager.RegisterModule(new(Module)) +} + +func newModule() sessionmanager.Module { + return new(Module) +} + +// Module is the gRPC server app. Its lifecycle: +// - Provision: load every service module listed under services: from the +// config; type-assert each against Service; collect them in declaration +// order. +// - Start: build the underlying *grpc.Server via commongrpc.NewServer using +// the top-level cfg.GRPC block, register every collected service onto it, +// listen on cfg.GRPC.Address, and begin Serve in a goroutine. +// - Stop: GracefulStop bounded by cfg.GRPC.ShutdownTimeout; if the timeout +// fires, fall back to a forceful Stop. +type Module struct { + Mod string `yaml:"module"` + Services []*config.ServiceCfg `yaml:"services"` + + ctx context.Context //nolint:containedctx + cfg *config.Config + services []Service + + server *grpc.Server + listener net.Listener + + stopOnce sync.Once + stopErr error + + serveDone chan struct{} +} + +func (m *Module) Module() sessionmanager.ModuleInfo { + return sessionmanager.ModuleInfo{ + ID: moduleID, + New: newModule, + } +} + +func (m *Module) Provision(ctx *sessionmanager.Context) error { + cfg, ok := config.FromContext(ctx) + if !ok { + return errors.New("config not found in context") + } + m.ctx = ctx + m.cfg = cfg + + if len(m.Services) == 0 { + return errors.New("app.module.grpcserver requires at least one service under services") + } + + m.services = make([]Service, 0, len(m.Services)) + for i, svcCfg := range m.Services { + mod, err := ctx.LoadModule(svcCfg) + if err != nil { + return fmt.Errorf("loading service[%d] %q: %w", i, svcCfg.Module(), err) + } + svc, ok := mod.(Service) + if !ok { + return fmt.Errorf("service[%d] module %q does not implement grpcserver.Service", i, svcCfg.Module()) + } + m.services = append(m.services, svc) + } + + return nil +} + +func (m *Module) Start() error { + m.server = commongrpc.NewServer(m.ctx, &m.cfg.GRPC.GRPCServer) + + for _, svc := range m.services { + svc.Register(m.server) + } + + listener, err := new(net.ListenConfig).Listen(m.ctx, "tcp", m.cfg.GRPC.Address) + if err != nil { + return fmt.Errorf("listening on %s: %w", m.cfg.GRPC.Address, err) + } + m.listener = listener + + slogctx.Info(m.ctx, "Starting a gRPC listener", "address", listener.Addr().String()) + + m.serveDone = make(chan struct{}) + go func() { + defer close(m.serveDone) + if err := m.server.Serve(listener); err != nil { + slogctx.Error(m.ctx, "gRPC server stopped with error", "error", err) + } + }() + + return nil +} + +func (m *Module) Stop() error { + m.stopOnce.Do(func() { + if m.server == nil { + return + } + + gracefulDone := make(chan struct{}) + go func() { + m.server.GracefulStop() + close(gracefulDone) + }() + + timeout := m.cfg.GRPC.ShutdownTimeout + if timeout <= 0 { + <-gracefulDone + } else { + tctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + select { + case <-gracefulDone: + case <-tctx.Done(): + slogctx.Warn(m.ctx, "gRPC graceful stop exceeded timeout; forcing Stop", "timeout", timeout) + m.server.Stop() + <-gracefulDone + } + } + + if m.serveDone != nil { + <-m.serveDone + } + }) + return m.stopErr +} diff --git a/modules/app/grpcserver/module_test.go b/modules/app/grpcserver/module_test.go new file mode 100644 index 00000000..be8995e9 --- /dev/null +++ b/modules/app/grpcserver/module_test.go @@ -0,0 +1,131 @@ +package grpcserver_test + +import ( + "net" + "sync" + "testing" + "time" + + "github.com/openkcm/common-sdk/pkg/commoncfg" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + + sessionmanager "github.com/openkcm/session-manager" + "github.com/openkcm/session-manager/internal/config" + "github.com/openkcm/session-manager/modules/app/grpcserver" +) + +// fakeService satisfies grpcserver.Service. It records that Register was +// called on a non-nil ServiceRegistrar. The test then exercises Start/Stop +// without registering any real proto services. +type fakeService struct { + mu sync.Mutex + registered bool +} + +func (f *fakeService) Module() sessionmanager.ModuleInfo { + return sessionmanager.ModuleInfo{ + ID: "test.fake.service", + New: func() sessionmanager.Module { return f }, + } +} + +func (f *fakeService) Register(_ grpc.ServiceRegistrar) { + f.mu.Lock() + defer f.mu.Unlock() + f.registered = true +} + +// notService is a Module that does NOT implement grpcserver.Service. +type notService struct{ id string } + +func (n *notService) Module() sessionmanager.ModuleInfo { + return sessionmanager.ModuleInfo{ + ID: n.id, + New: func() sessionmanager.Module { return n }, + } +} + +func newCtx(t *testing.T) (*sessionmanager.Context, *config.Config) { + t.Helper() + cfg := &config.Config{} + cfg.GRPC = config.GRPCServer{ + GRPCServer: commoncfg.GRPCServer{Address: "127.0.0.1:0"}, + ShutdownTimeout: 2 * time.Second, + } + // Find a free port the deterministic way. + l, err := new(net.ListenConfig).Listen(t.Context(), "tcp", "127.0.0.1:0") + require.NoError(t, err) + addr := l.Addr().String() + require.NoError(t, l.Close()) + cfg.GRPC.Address = addr + + ctx, cancel := sessionmanager.NewContext(t.Context()) + t.Cleanup(func() { cancel(nil) }) + ctx = config.WithContext(ctx, cfg) + return ctx, cfg +} + +func TestModule_StartRegistersServicesAndStops(t *testing.T) { + ctx, _ := newCtx(t) + + fakeID := "test.fake.service." + t.Name() + svc := &fakeService{} + sessionmanager.RegisterModule(&customMod{id: fakeID, mod: svc}) + + m := &grpcserver.Module{ + Services: []*config.ServiceCfg{newSvcCfg(fakeID)}, + } + require.NoError(t, m.Provision(ctx)) + require.NoError(t, m.Start()) + + assert.True(t, svc.registered, "service must be registered before Serve") + + require.NoError(t, m.Stop()) + + // Stop is idempotent. + require.NoError(t, m.Stop()) +} + +func TestModule_NonServiceUnderServicesIsRejected(t *testing.T) { + ctx, _ := newCtx(t) + + id := "test.notservice." + t.Name() + sessionmanager.RegisterModule(&customMod{id: id, mod: ¬Service{id: id}}) + + m := &grpcserver.Module{ + Services: []*config.ServiceCfg{newSvcCfg(id)}, + } + err := m.Provision(ctx) + require.Error(t, err) + assert.Contains(t, err.Error(), "does not implement") +} + +func TestModule_EmptyServicesRejected(t *testing.T) { + ctx, _ := newCtx(t) + m := &grpcserver.Module{} + err := m.Provision(ctx) + require.Error(t, err) + assert.Contains(t, err.Error(), "at least one service") +} + +// customMod is a tiny ExtensionConfig + module registration helper. +type customMod struct { + id string + mod sessionmanager.Module +} + +func (c *customMod) Module() sessionmanager.ModuleInfo { + return sessionmanager.ModuleInfo{ + ID: c.id, + New: func() sessionmanager.Module { return c.mod }, + } +} + +// newSvcCfg builds a config.ServiceCfg pointing at the given module ID. +// We bypass koanf entirely; the module just needs Module() to return the ID. +func newSvcCfg(modID string) *config.ServiceCfg { + c := &config.ServiceCfg{Mod: modID} + return c +} diff --git a/modules/credentials/oauth2/module.go b/modules/credentials/oauth2/module.go new file mode 100644 index 00000000..5a95747e --- /dev/null +++ b/modules/credentials/oauth2/module.go @@ -0,0 +1,90 @@ +// Package oauth2 provides the credentials.module.oauth2 module: a +// credentials.Builder that produces transport credentials for OAuth2/OIDC +// client authentication. Source data lives under sessionManager.clientAuth in +// the top-level config and is read via config.FromContext. +package oauth2 + +import ( + "errors" + "fmt" + "log/slog" + + "github.com/openkcm/common-sdk/pkg/commoncfg" + + sessionmanager "github.com/openkcm/session-manager" + "github.com/openkcm/session-manager/internal/config" + "github.com/openkcm/session-manager/internal/credentials" +) + +const moduleID = "credentials.module.oauth2" + +const ( + authMTLS = "mtls" + authClientSecret = "client_secret" + authClientSecretPost = "client_secret_post" + authInsecure = "insecure" +) + +func init() { + sessionmanager.RegisterModule(new(Module)) +} + +func newModule() sessionmanager.Module { + return new(Module) +} + +// Module is the credentials.module.oauth2 module. It exposes a +// credentials.Builder constructed from sessionManager.clientAuth. +type Module struct { + Mod string `yaml:"module"` + + builder credentials.Builder +} + +func (m *Module) Module() sessionmanager.ModuleInfo { + return sessionmanager.ModuleInfo{ + ID: moduleID, + New: newModule, + } +} + +func (m *Module) Provision(ctx *sessionmanager.Context) error { + cfg, ok := config.FromContext(ctx) + if !ok { + return errors.New("config not found in context") + } + + clientAuth := cfg.SessionManager.ClientAuth + switch clientAuth.Type { + case authMTLS: + tlsConfig, err := commoncfg.LoadMTLSConfig(clientAuth.MTLS) + if err != nil { + return fmt.Errorf("loading mTLS config: %w", err) + } + m.builder = func(clientID string) credentials.TransportCredentials { + return credentials.NewTLS(clientID, tlsConfig) + } + case authClientSecret, authClientSecretPost: + secret, err := commoncfg.LoadValueFromSourceRef(clientAuth.ClientSecret) + if err != nil { + return fmt.Errorf("loading client secret: %w", err) + } + m.builder = func(clientID string) credentials.TransportCredentials { + return credentials.NewClientSecretPost(clientID, string(secret)) + } + case authInsecure: + slog.Warn("insecure credentials are used. Do not use this in production") + m.builder = func(clientID string) credentials.TransportCredentials { + return credentials.NewInsecure(clientID) + } + default: + return fmt.Errorf("unknown client auth type %q", clientAuth.Type) + } + + return nil +} + +// Builder returns the credentials.Builder produced during Provision. +func (m *Module) Builder() credentials.Builder { + return m.builder +} diff --git a/modules/credentials/oauth2/module_test.go b/modules/credentials/oauth2/module_test.go new file mode 100644 index 00000000..d811526a --- /dev/null +++ b/modules/credentials/oauth2/module_test.go @@ -0,0 +1,81 @@ +package oauth2_test + +import ( + "testing" + + "github.com/openkcm/common-sdk/pkg/commoncfg" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + sessionmanager "github.com/openkcm/session-manager" + "github.com/openkcm/session-manager/internal/config" + credentialsoauth2 "github.com/openkcm/session-manager/modules/credentials/oauth2" +) + +func TestModule_RegistrationAndID(t *testing.T) { + info, err := sessionmanager.GetModule("credentials.module.oauth2") + require.NoError(t, err) + assert.Equal(t, "credentials.module.oauth2", info.ID) + + mod := info.New() + require.NotNil(t, mod) + _, ok := mod.(*credentialsoauth2.Module) + assert.True(t, ok, "New() must return *Module") +} + +func provisionWithAuth(t *testing.T, auth config.ClientAuth) (*credentialsoauth2.Module, error) { + t.Helper() + cfg := &config.Config{} + cfg.SessionManager.ClientAuth = auth + + ctx, cancel := sessionmanager.NewContext(t.Context()) + t.Cleanup(func() { cancel(nil) }) + ctx = config.WithContext(ctx, cfg) + + m := new(credentialsoauth2.Module) + return m, m.Provision(ctx) +} + +func TestModule_ProvisionInsecure(t *testing.T) { + m, err := provisionWithAuth(t, config.ClientAuth{Type: "insecure"}) + require.NoError(t, err) + require.NotNil(t, m.Builder()) + + creds := m.Builder()("client-id") + assert.NotNil(t, creds) +} + +func TestModule_ProvisionClientSecret(t *testing.T) { + m, err := provisionWithAuth(t, config.ClientAuth{ + Type: "client_secret", + ClientSecret: commoncfg.SourceRef{Source: "embedded", Value: "shh"}, + }) + require.NoError(t, err) + creds := m.Builder()("cid") + assert.NotNil(t, creds) +} + +func TestModule_ProvisionClientSecretPost(t *testing.T) { + m, err := provisionWithAuth(t, config.ClientAuth{ + Type: "client_secret_post", + ClientSecret: commoncfg.SourceRef{Source: "embedded", Value: "shh"}, + }) + require.NoError(t, err) + creds := m.Builder()("cid") + assert.NotNil(t, creds) +} + +func TestModule_ProvisionUnknownTypeFails(t *testing.T) { + _, err := provisionWithAuth(t, config.ClientAuth{Type: "totally-bogus"}) + require.Error(t, err) + assert.Contains(t, err.Error(), "totally-bogus") +} + +func TestModule_ProvisionWithoutConfigFails(t *testing.T) { + ctx, cancel := sessionmanager.NewContext(t.Context()) + defer cancel(nil) + + m := new(credentialsoauth2.Module) + err := m.Provision(ctx) + require.Error(t, err) +} diff --git a/modules/database/pgxpool/module.go b/modules/database/pgxpool/module.go new file mode 100644 index 00000000..f8865972 --- /dev/null +++ b/modules/database/pgxpool/module.go @@ -0,0 +1,107 @@ +package pgxpool + +import ( + "context" + "database/sql" + "fmt" + + "github.com/exaring/otelpgx" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgxpool" + "github.com/jackc/pgx/v5/stdlib" + "github.com/openkcm/common-sdk/pkg/commoncfg" + + sessionmanager "github.com/openkcm/session-manager" +) + +const moduleID = "database.module.pgxpool" + +func newModule() sessionmanager.Module { + return new(PostgresModule) +} + +func init() { + sessionmanager.RegisterModule(new(PostgresModule)) +} + +type PostgresModule struct { + Mod string `yaml:"module"` + Name string `yaml:"name"` + Port string `yaml:"port"` + Host commoncfg.SourceRef `yaml:"host"` + User commoncfg.SourceRef `yaml:"user"` + Password commoncfg.SourceRef `yaml:"password"` + + db *pgxpool.Pool +} + +func (m *PostgresModule) STDAdapter() *sql.DB { + return stdlib.OpenDBFromPool(m.db) +} + +func (m *PostgresModule) Exec(ctx context.Context, sql string, args ...any) (pgconn.CommandTag, error) { + return m.db.Exec(ctx, sql, args...) +} + +func (m *PostgresModule) Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) { + return m.db.Query(ctx, sql, args...) +} + +func (m *PostgresModule) QueryRow(ctx context.Context, sql string, args ...any) pgx.Row { + return m.db.QueryRow(ctx, sql, args...) +} + +func (m *PostgresModule) Module() sessionmanager.ModuleInfo { + return sessionmanager.ModuleInfo{ + ID: moduleID, + New: newModule, + } +} + +func (m *PostgresModule) Provision(ctx *sessionmanager.Context) error { + connStr, err := m.makeConnStr() + if err != nil { + return fmt.Errorf("making dsn from config: %w", err) + } + + pgxpoolCfg, err := pgxpool.ParseConfig(connStr) + if err != nil { + return fmt.Errorf("parsing pgxpool config: %w", err) + } + + pgxpoolCfg.ConnConfig.Tracer = otelpgx.NewTracer() + + m.db, err = pgxpool.NewWithConfig(ctx, pgxpoolCfg) + if err != nil { + return fmt.Errorf("failed to initialise pgxpool connection: %w", err) + } + + if err := otelpgx.RecordStats(m.db); err != nil { + return fmt.Errorf("recording database stat: %w", err) + } + + return nil +} + +func (m *PostgresModule) makeConnStr() (string, error) { + host, err := commoncfg.LoadValueFromSourceRef(m.Host) + if err != nil { + return "", fmt.Errorf("loading db host: %w", err) + } + + user, err := commoncfg.LoadValueFromSourceRef(m.User) + if err != nil { + return "", fmt.Errorf("loading db user: %w", err) + } + + password, err := commoncfg.LoadValueFromSourceRef(m.Password) + if err != nil { + return "", fmt.Errorf("loading db password: %w", err) + } + + return fmt.Sprintf("host=%s user=%s password=%s dbname=%s port=%s", + host, user, string(password), m.Name, m.Port), nil +} + +var _ sessionmanager.Database = (*PostgresModule)(nil) diff --git a/modules/grpc/oidcmapping/import_test.go b/modules/grpc/oidcmapping/import_test.go new file mode 100644 index 00000000..51d6268c --- /dev/null +++ b/modules/grpc/oidcmapping/import_test.go @@ -0,0 +1,12 @@ +package oidcmapping_test + +import ( + _ "unsafe" + + sessionmanager "github.com/openkcm/session-manager" + "github.com/openkcm/session-manager/modules/oidctrust" + _ "github.com/openkcm/session-manager/modules/standard" +) + +//go:linkname newTrust github.com/openkcm/session-manager/modules/oidctrust.newOIDCTrustModuleWithRepo +func newTrust(r oidctrust.TrustRepository) sessionmanager.Trust diff --git a/modules/grpc/oidcmapping/module.go b/modules/grpc/oidcmapping/module.go new file mode 100644 index 00000000..5aaf629b --- /dev/null +++ b/modules/grpc/oidcmapping/module.go @@ -0,0 +1,62 @@ +// Package oidcmapping provides the service.module.grpc.oidcmapping module: +// a gRPC service module that registers the legacy +// kms.api.cmk.sessionmanager.oidcmapping.v1.Service proto onto a +// grpc.ServiceRegistrar supplied by app.module.grpcserver. It exists for +// backward compatibility with clients that have not migrated to the +// kms.api.cmk.sessionmanager.trustmapping.v1.Service proto. +package oidcmapping + +import ( + "fmt" + + "google.golang.org/grpc" + + oidcmappingv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/sessionmanager/oidcmapping/v1" + + sessionmanager "github.com/openkcm/session-manager" +) + +const moduleID = "service.module.grpc.oidcmapping" + +func init() { + sessionmanager.RegisterModule(new(Module)) +} + +func newModule() sessionmanager.Module { + return new(Module) +} + +// Module is the service.module.grpc.oidcmapping module. It owns a Server that +// adapts the legacy oidcmapping proto onto sessionmanager.Trust and resolves +// its single dependency by ID via ctx.GetModule. +type Module struct { + Mod string `yaml:"module"` + Trust string `yaml:"trust" default:"trust.module.oidc"` + + server *Server +} + +func (m *Module) Module() sessionmanager.ModuleInfo { + return sessionmanager.ModuleInfo{ + ID: moduleID, + New: newModule, + } +} + +func (m *Module) Provision(ctx *sessionmanager.Context) error { + trustMod, err := ctx.GetModule(m.Trust) + if err != nil { + return fmt.Errorf("getting trust module %q: %w", m.Trust, err) + } + trust, ok := trustMod.(sessionmanager.Trust) + if !ok { + return fmt.Errorf("module %q does not implement sessionmanager.Trust", m.Trust) + } + + m.server = NewServer(trust) + return nil +} + +func (m *Module) Register(s grpc.ServiceRegistrar) { + oidcmappingv1.RegisterServiceServer(s, m.server) +} diff --git a/modules/grpc/oidcmapping/module_test.go b/modules/grpc/oidcmapping/module_test.go new file mode 100644 index 00000000..f730b323 --- /dev/null +++ b/modules/grpc/oidcmapping/module_test.go @@ -0,0 +1,77 @@ +package oidcmapping_test + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + trustv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/v1" + + sessionmanager "github.com/openkcm/session-manager" + ommod "github.com/openkcm/session-manager/modules/grpc/oidcmapping" + _ "github.com/openkcm/session-manager/modules/standard" +) + +type stubTrust struct{} + +func (stubTrust) Apply(context.Context, *trustv1.Trust) error { return nil } +func (stubTrust) Block(context.Context, string) error { return nil } +func (stubTrust) Remove(context.Context, string) error { return nil } +func (stubTrust) Unblock(context.Context, string) error { return nil } + +var errStubTrustGet = errors.New("stub get not implemented") + +func (stubTrust) Get(context.Context, string) (*trustv1.Trust, error) { + return nil, errStubTrustGet +} + +type stubTrustModule struct { + stubTrust + + id string +} + +func (s *stubTrustModule) Module() sessionmanager.ModuleInfo { + return sessionmanager.ModuleInfo{ + ID: s.id, + New: func() sessionmanager.Module { return s }, + } +} + +func TestModule_Registration(t *testing.T) { + info, err := sessionmanager.GetModule("service.module.grpc.oidcmapping") + require.NoError(t, err) + assert.Equal(t, "service.module.grpc.oidcmapping", info.ID) +} + +func TestModule_ProvisionResolvesCustomTrust(t *testing.T) { + id := "trust.module.test." + t.Name() + sessionmanager.RegisterModule(&stubTrustModule{id: id}) + + ctx, cancel := sessionmanager.NewContext(t.Context()) + defer cancel(nil) + + _, err := ctx.LoadModule(&extConfig{moduleID: id}) + require.NoError(t, err) + + m := &ommod.Module{Trust: id} + require.NoError(t, m.Provision(ctx)) +} + +func TestModule_ProvisionMissingTrustFails(t *testing.T) { + ctx, cancel := sessionmanager.NewContext(t.Context()) + defer cancel(nil) + + m := &ommod.Module{Trust: "no.such.trust.module"} + err := m.Provision(ctx) + require.Error(t, err) + assert.Contains(t, err.Error(), "no.such.trust.module") +} + +type extConfig struct{ moduleID string } + +func (c *extConfig) Module() string { return c.moduleID } +func (c *extConfig) UnmarshalExtension(_ sessionmanager.Module) error { return nil } diff --git a/modules/grpc/oidcmapping/server.go b/modules/grpc/oidcmapping/server.go new file mode 100644 index 00000000..ae55e673 --- /dev/null +++ b/modules/grpc/oidcmapping/server.go @@ -0,0 +1,123 @@ +package oidcmapping + +import ( + "context" + "errors" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + oidcmappingv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/sessionmanager/oidcmapping/v1" + oidcv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/oidc/v1" + trustv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/v1" + slogctx "github.com/veqryn/slog-context" + + sessionmanager "github.com/openkcm/session-manager" + "github.com/openkcm/session-manager/pkg/serviceerr" +) + +type Server struct { + oidcmappingv1.UnimplementedServiceServer + + trust sessionmanager.Trust +} + +func NewServer(trust sessionmanager.Trust) *Server { + return &Server{trust: trust} +} + +func (srv *Server) ApplyOIDCMapping(ctx context.Context, req *oidcmappingv1.ApplyOIDCMappingRequest) (*oidcmappingv1.ApplyOIDCMappingResponse, error) { + oidcBuilder := oidcv1.OIDC_builder{ + Issuer: new(req.GetIssuer()), + Audiences: req.GetAudiences(), + } + if req.JwksUri != nil { + oidcBuilder.JwksUri = new(req.GetJwksUri()) + } + if req.ClientId != nil { + oidcBuilder.ClientId = new(req.GetClientId()) + } + oidc := oidcBuilder.Build() + + trust := trustv1.Trust_builder{ + TenantId: new(req.GetTenantId()), + Oidc: oidc, + }.Build() + + ctx = slogctx.With(ctx, + "tenantId", trust.GetTenantId(), + "issuer", oidc.GetIssuer(), + "jwksUri", oidc.GetJwksUri(), + "audiences", oidc.GetAudiences(), + "client_id", oidc.GetClientId(), + ) + + slogctx.Debug(ctx, "ApplyOIDCMapping called") + + response := &oidcmappingv1.ApplyOIDCMappingResponse{} + + if err := srv.trust.Apply(ctx, trust); err != nil { + slogctx.Error(ctx, "Could not apply trust", "error", err) + if errors.Is(err, serviceerr.ErrNotFound) { + msg := serviceerr.ErrNotFound.Error() + response.Message = &msg + return response, nil + } + + return nil, status.Errorf(codes.Internal, "failed to apply trust: %v", err) + } + + response.Success = true + return response, nil +} + +func (srv *Server) RemoveOIDCMapping(ctx context.Context, req *oidcmappingv1.RemoveOIDCMappingRequest) (*oidcmappingv1.RemoveOIDCMappingResponse, error) { + ctx = slogctx.With(ctx, "tenantId", req.GetTenantId()) + slogctx.Debug(ctx, "RemoveOIDCMapping called") + + resp := &oidcmappingv1.RemoveOIDCMappingResponse{} + if err := srv.trust.Remove(ctx, req.GetTenantId()); err != nil { + if !errors.Is(err, serviceerr.ErrNotFound) { + slogctx.Error(ctx, "Could not remove trust", "error", err) + msg := err.Error() + resp.Message = &msg + return resp, status.Error(codes.Internal, "failed to remove trust: "+msg) + } + slogctx.Warn(ctx, "RemoveOIDCMapping is called but the tenant does not exist", "error", err) + } + + resp.Success = true + return resp, nil +} + +func (srv *Server) BlockOIDCMapping(ctx context.Context, req *oidcmappingv1.BlockOIDCMappingRequest) (*oidcmappingv1.BlockOIDCMappingResponse, error) { + ctx = slogctx.With(ctx, "tenantId", req.GetTenantId()) + slogctx.Debug(ctx, "BlockOIDCMapping called") + + resp := &oidcmappingv1.BlockOIDCMappingResponse{} + if err := srv.trust.Block(ctx, req.GetTenantId()); err != nil { + slogctx.Error(ctx, "Could not block trust", "error", err) + msg := err.Error() + resp.Message = &msg + return resp, status.Error(codes.Internal, "failed to block trust: "+msg) + } + + resp.Success = true + return resp, nil +} + +func (srv *Server) UnblockOIDCMapping(ctx context.Context, req *oidcmappingv1.UnblockOIDCMappingRequest) (*oidcmappingv1.UnblockOIDCMappingResponse, error) { + ctx = slogctx.With(ctx, "tenantId", req.GetTenantId()) + slogctx.Debug(ctx, "UnblockOIDCMapping called") + + resp := &oidcmappingv1.UnblockOIDCMappingResponse{} + if err := srv.trust.Unblock(ctx, req.GetTenantId()); err != nil { + slogctx.Error(ctx, "Could not unblock trust", "error", err) + msg := err.Error() + resp.Message = &msg + return resp, status.Error(codes.Internal, "failed to unblock trust: "+msg) + } + + resp.Success = true + return resp, nil +} diff --git a/modules/grpc/oidcmapping/server_test.go b/modules/grpc/oidcmapping/server_test.go new file mode 100644 index 00000000..1b3ede59 --- /dev/null +++ b/modules/grpc/oidcmapping/server_test.go @@ -0,0 +1,300 @@ +package oidcmapping_test + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + oidcmappingv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/sessionmanager/oidcmapping/v1" + oidcv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/oidc/v1" + trustv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/v1" + + "github.com/openkcm/session-manager/modules/grpc/oidcmapping" + mocktrust "github.com/openkcm/session-manager/modules/oidctrust/mocks" + "github.com/openkcm/session-manager/pkg/serviceerr" +) + +func TestNewOIDCMappingServer(t *testing.T) { + repo := mocktrust.NewInMemRepository() + svc := newTrust(repo) + server := oidcmapping.NewServer(svc) + assert.NotNil(t, server) +} + +func TestApplyOIDCMapping(t *testing.T) { + ctx := t.Context() + + t.Run("forwards issuer, jwks_uri, audiences, client_id when set", func(t *testing.T) { + repo := mocktrust.NewInMemRepository() + svc := newTrust(repo) + server := oidcmapping.NewServer(svc) + + jwksURI := "https://issuer.example.com/.well-known/jwks.json" + clientID := "client-abc" + req := &oidcmappingv1.ApplyOIDCMappingRequest{ + TenantId: "tenant-123", + Issuer: "https://issuer.example.com", + JwksUri: &jwksURI, + Audiences: []string{"audience1", "audience2"}, + ClientId: &clientID, + } + + resp, err := server.ApplyOIDCMapping(ctx, req) + require.NoError(t, err) + assert.True(t, resp.GetSuccess()) + + stored := repo.TGet("tenant-123") + require.NotNil(t, stored) + assert.Equal(t, "tenant-123", stored.GetTenantId()) + require.NotNil(t, stored.GetOidc()) + assert.Equal(t, "https://issuer.example.com", stored.GetOidc().GetIssuer()) + assert.Equal(t, jwksURI, stored.GetOidc().GetJwksUri()) + assert.Equal(t, []string{"audience1", "audience2"}, stored.GetOidc().GetAudiences()) + assert.Equal(t, clientID, stored.GetOidc().GetClientId()) + assert.True(t, stored.GetOidc().HasClientId()) + }) + + t.Run("client_id omitted leaves new oidc.client_id unset", func(t *testing.T) { + repo := mocktrust.NewInMemRepository() + svc := newTrust(repo) + server := oidcmapping.NewServer(svc) + + req := &oidcmappingv1.ApplyOIDCMappingRequest{ + TenantId: "tenant-no-client", + Issuer: "https://issuer.example.com", + } + + resp, err := server.ApplyOIDCMapping(ctx, req) + require.NoError(t, err) + assert.True(t, resp.GetSuccess()) + + stored := repo.TGet("tenant-no-client") + require.NotNil(t, stored) + require.NotNil(t, stored.GetOidc()) + assert.False(t, stored.GetOidc().HasClientId(), "client_id should remain unset when request omits it") + }) + + t.Run("non-empty properties map is dropped", func(t *testing.T) { + repo := mocktrust.NewInMemRepository() + svc := newTrust(repo) + server := oidcmapping.NewServer(svc) + + clientID := "client-xyz" + reqWithProps := &oidcmappingv1.ApplyOIDCMappingRequest{ + TenantId: "tenant-with-props", + Issuer: "https://issuer.example.com", + ClientId: &clientID, + Properties: map[string]string{"foo": "bar", "baz": "qux"}, + } + + resp, err := server.ApplyOIDCMapping(ctx, reqWithProps) + require.NoError(t, err) + assert.True(t, resp.GetSuccess()) + + stored := repo.TGet("tenant-with-props") + require.NotNil(t, stored) + // The new oidc.OIDC has no properties field; verify the stored Trust matches what + // we'd get by building it from the same request without properties. + expected := trustv1.Trust_builder{ + TenantId: new("tenant-with-props"), + Oidc: oidcv1.OIDC_builder{ + Issuer: new("https://issuer.example.com"), + ClientId: new(clientID), + }.Build(), + }.Build() + assert.Equal(t, expected.GetTenantId(), stored.GetTenantId()) + assert.Equal(t, expected.GetOidc().GetIssuer(), stored.GetOidc().GetIssuer()) + assert.Equal(t, expected.GetOidc().GetClientId(), stored.GetOidc().GetClientId()) + }) + + t.Run("ErrNotFound from Apply yields non-success response with message and no gRPC error", func(t *testing.T) { + repo := mocktrust.NewInMemRepository( + mocktrust.WithCreateError(serviceerr.ErrNotFound), + ) + svc := newTrust(repo) + server := oidcmapping.NewServer(svc) + + req := &oidcmappingv1.ApplyOIDCMappingRequest{ + TenantId: "tenant-missing", + Issuer: "https://issuer.example.com", + } + + resp, err := server.ApplyOIDCMapping(ctx, req) + require.NoError(t, err) + require.NotNil(t, resp) + assert.False(t, resp.GetSuccess()) + assert.Equal(t, serviceerr.ErrNotFound.Error(), resp.GetMessage()) + }) + + t.Run("other errors map to codes.Internal", func(t *testing.T) { + internalErr := errors.New("database connection failed") + repo := mocktrust.NewInMemRepository( + mocktrust.WithCreateError(internalErr), + ) + svc := newTrust(repo) + server := oidcmapping.NewServer(svc) + + req := &oidcmappingv1.ApplyOIDCMappingRequest{ + TenantId: "tenant-boom", + Issuer: "https://issuer.example.com", + } + + resp, err := server.ApplyOIDCMapping(ctx, req) + assert.Nil(t, resp) + require.Error(t, err) + + st, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, codes.Internal, st.Code()) + assert.Contains(t, st.Message(), "failed to apply trust") + }) +} + +func TestRemoveOIDCMapping(t *testing.T) { + ctx := t.Context() + + t.Run("success removes existing trust", func(t *testing.T) { + existing := trustv1.Trust_builder{ + TenantId: new("tenant-123"), + Oidc: oidcv1.OIDC_builder{ + Issuer: new("https://issuer.example.com"), + }.Build(), + }.Build() + repo := mocktrust.NewInMemRepository(mocktrust.WithTrust(existing)) + svc := newTrust(repo) + server := oidcmapping.NewServer(svc) + + req := &oidcmappingv1.RemoveOIDCMappingRequest{TenantId: "tenant-123"} + resp, err := server.RemoveOIDCMapping(ctx, req) + require.NoError(t, err) + assert.True(t, resp.GetSuccess()) + assert.Empty(t, resp.GetMessage()) + }) + + t.Run("ErrNotFound is idempotent and returns success", func(t *testing.T) { + repo := mocktrust.NewInMemRepository( + mocktrust.WithDeleteError(serviceerr.ErrNotFound), + ) + svc := newTrust(repo) + server := oidcmapping.NewServer(svc) + + req := &oidcmappingv1.RemoveOIDCMappingRequest{TenantId: "tenant-gone"} + resp, err := server.RemoveOIDCMapping(ctx, req) + require.NoError(t, err) + assert.True(t, resp.GetSuccess()) + }) + + t.Run("other errors map to codes.Internal", func(t *testing.T) { + deleteErr := errors.New("delete failed") + repo := mocktrust.NewInMemRepository(mocktrust.WithDeleteError(deleteErr)) + svc := newTrust(repo) + server := oidcmapping.NewServer(svc) + + req := &oidcmappingv1.RemoveOIDCMappingRequest{TenantId: "tenant-boom"} + resp, err := server.RemoveOIDCMapping(ctx, req) + require.Error(t, err) + assert.NotNil(t, resp) + assert.Contains(t, resp.GetMessage(), "delete failed") + + st, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, codes.Internal, st.Code()) + assert.Contains(t, st.Message(), "failed to remove trust") + }) +} + +func TestBlockOIDCMapping(t *testing.T) { + ctx := t.Context() + + t.Run("success blocks existing trust", func(t *testing.T) { + existing := trustv1.Trust_builder{ + TenantId: new("tenant-123"), + Blocked: new(false), + Oidc: oidcv1.OIDC_builder{ + Issuer: new("https://issuer.example.com"), + }.Build(), + }.Build() + repo := mocktrust.NewInMemRepository(mocktrust.WithTrust(existing)) + svc := newTrust(repo) + server := oidcmapping.NewServer(svc) + + req := &oidcmappingv1.BlockOIDCMappingRequest{TenantId: "tenant-123"} + resp, err := server.BlockOIDCMapping(ctx, req) + require.NoError(t, err) + assert.True(t, resp.GetSuccess()) + assert.Empty(t, resp.GetMessage()) + }) + + t.Run("error maps to codes.Internal with message", func(t *testing.T) { + internalErr := errors.New("database error") + repo := mocktrust.NewInMemRepository(mocktrust.WithGetError(internalErr)) + svc := newTrust(repo) + server := oidcmapping.NewServer(svc) + + req := &oidcmappingv1.BlockOIDCMappingRequest{TenantId: "tenant-123"} + resp, err := server.BlockOIDCMapping(ctx, req) + require.Error(t, err) + assert.NotNil(t, resp) + assert.Contains(t, resp.GetMessage(), "database error") + + st, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, codes.Internal, st.Code()) + assert.Contains(t, st.Message(), "failed to block trust") + }) +} + +func TestUnblockOIDCMapping(t *testing.T) { + ctx := t.Context() + + t.Run("success unblocks blocked trust", func(t *testing.T) { + existing := trustv1.Trust_builder{ + TenantId: new("tenant-123"), + Blocked: new(true), + Oidc: oidcv1.OIDC_builder{ + Issuer: new("https://issuer.example.com"), + }.Build(), + }.Build() + repo := mocktrust.NewInMemRepository(mocktrust.WithTrust(existing)) + svc := newTrust(repo) + server := oidcmapping.NewServer(svc) + + req := &oidcmappingv1.UnblockOIDCMappingRequest{TenantId: "tenant-123"} + resp, err := server.UnblockOIDCMapping(ctx, req) + require.NoError(t, err) + assert.True(t, resp.GetSuccess()) + assert.Empty(t, resp.GetMessage()) + }) + + t.Run("error maps to codes.Internal with message", func(t *testing.T) { + internalErr := errors.New("update failed") + existing := trustv1.Trust_builder{ + TenantId: new("tenant-123"), + Blocked: new(true), + Oidc: oidcv1.OIDC_builder{ + Issuer: new("https://issuer.example.com"), + }.Build(), + }.Build() + repo := mocktrust.NewInMemRepository( + mocktrust.WithTrust(existing), + mocktrust.WithUpdateError(internalErr), + ) + svc := newTrust(repo) + server := oidcmapping.NewServer(svc) + + req := &oidcmappingv1.UnblockOIDCMappingRequest{TenantId: "tenant-123"} + resp, err := server.UnblockOIDCMapping(ctx, req) + require.Error(t, err) + assert.NotNil(t, resp) + assert.Contains(t, resp.GetMessage(), "update failed") + + st, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, codes.Internal, st.Code()) + assert.Contains(t, st.Message(), "failed to unblock trust") + }) +} diff --git a/modules/grpc/session/import_test.go b/modules/grpc/session/import_test.go new file mode 100644 index 00000000..cfe5ca6f --- /dev/null +++ b/modules/grpc/session/import_test.go @@ -0,0 +1,12 @@ +package session_test + +import ( + _ "unsafe" + + sessionmanager "github.com/openkcm/session-manager" + "github.com/openkcm/session-manager/modules/oidctrust" + _ "github.com/openkcm/session-manager/modules/standard" +) + +//go:linkname newTrust github.com/openkcm/session-manager/modules/oidctrust.newOIDCTrustModuleWithRepo +func newTrust(r oidctrust.TrustRepository) sessionmanager.Trust diff --git a/modules/grpc/session/module.go b/modules/grpc/session/module.go new file mode 100644 index 00000000..187abc01 --- /dev/null +++ b/modules/grpc/session/module.go @@ -0,0 +1,113 @@ +// Package session provides the service.module.grpc.session module: a gRPC +// service module that registers the kms.api.cmk.sessionmanager.session.v1.Service +// proto onto a grpc.ServiceRegistrar supplied by app.module.grpcserver. +package session + +import ( + "errors" + "fmt" + + "google.golang.org/grpc" + + sessionv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/sessionmanager/session/v1" + + sessionmanager "github.com/openkcm/session-manager" + "github.com/openkcm/session-manager/internal/config" + "github.com/openkcm/session-manager/internal/credentials" + internalsession "github.com/openkcm/session-manager/internal/session" +) + +const moduleID = "service.module.grpc.session" + +func init() { + sessionmanager.RegisterModule(new(Module)) +} + +func newModule() sessionmanager.Module { + return new(Module) +} + +// credentialsBuilder is the interface satisfied by a credentials module +// (e.g. credentials.module.oauth2). +type credentialsBuilder interface { + Builder() credentials.Builder +} + +// Module is the service.module.grpc.session module. It wires its three +// dependencies (trust, session store, credentials) by ID via ctx.GetModule +// and owns a *Server that implements the proto. +type Module struct { + Mod string `yaml:"module"` + Trust string `yaml:"trust" default:"trust.module.oidc"` + SessionStore string `yaml:"sessionStore" default:"sessionstore.module.valkey"` + Credentials string `yaml:"credentials" default:"credentials.module.oauth2"` + + AllowHttpScheme bool `yaml:"allowHttpScheme"` + QueryParametersIntrospect []string `yaml:"queryParametersIntrospect"` + + server *Server +} + +func (m *Module) Module() sessionmanager.ModuleInfo { + return sessionmanager.ModuleInfo{ + ID: moduleID, + New: newModule, + } +} + +func (m *Module) Provision(ctx *sessionmanager.Context) error { + cfg, ok := config.FromContext(ctx) + if !ok { + return errors.New("config not found in context") + } + + trustMod, err := ctx.GetModule(m.Trust) + if err != nil { + return fmt.Errorf("getting trust module %q: %w", m.Trust, err) + } + trust, ok := trustMod.(sessionmanager.Trust) + if !ok { + return fmt.Errorf("module %q does not implement sessionmanager.Trust", m.Trust) + } + + storeMod, err := ctx.GetModule(m.SessionStore) + if err != nil { + return fmt.Errorf("getting session-store module %q: %w", m.SessionStore, err) + } + repo, ok := storeMod.(internalsession.Repository) + if !ok { + return fmt.Errorf("module %q does not implement session.Repository", m.SessionStore) + } + + credsMod, err := ctx.GetModule(m.Credentials) + if err != nil { + return fmt.Errorf("getting credentials module %q: %w", m.Credentials, err) + } + creds, ok := credsMod.(credentialsBuilder) + if !ok { + return fmt.Errorf("module %q does not expose Builder()", m.Credentials) + } + + opts := []Option{ + WithTransportCredentials(creds.Builder()), + WithAllowHttpScheme(m.AllowHttpScheme), + } + if m.QueryParametersIntrospect != nil { + opts = append(opts, WithQueryParametersIntrospect(m.QueryParametersIntrospect)) + } + + m.server = NewServer( + ctx, + repo, + trust, + cfg.SessionManager.IdleSessionTimeout, + cfg.SessionManager.ClientAuth.ClientID, + opts..., + ) + + return nil +} + +func (m *Module) Register(s grpc.ServiceRegistrar) { + sessionv1.RegisterServiceServer(s, m.server) +} diff --git a/modules/grpc/session/options.go b/modules/grpc/session/options.go new file mode 100644 index 00000000..26731c2b --- /dev/null +++ b/modules/grpc/session/options.go @@ -0,0 +1,23 @@ +package session + +import "github.com/openkcm/session-manager/internal/credentials" + +type Option func(*Server) + +func WithQueryParametersIntrospect(params []string) Option { + return func(s *Server) { + s.queryParametersIntrospect = params + } +} + +func WithAllowHttpScheme(allow bool) Option { + return func(s *Server) { + s.allowHttpScheme = allow + } +} + +func WithTransportCredentials(b credentials.Builder) Option { + return func(s *Server) { + s.newCreds = b + } +} diff --git a/internal/grpc/session.go b/modules/grpc/session/server.go similarity index 75% rename from internal/grpc/session.go rename to modules/grpc/session/server.go index ab226a3f..312eb26d 100644 --- a/internal/grpc/session.go +++ b/modules/grpc/session/server.go @@ -1,4 +1,4 @@ -package grpc +package session import ( "context" @@ -17,25 +17,26 @@ import ( rpcv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/rpc/v1" sessionv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/sessionmanager/session/v1" + oidcv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/oidc/v1" typesv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/types/v1" slogctx "github.com/veqryn/slog-context" grpccodes "google.golang.org/grpc/codes" + sessionmanager "github.com/openkcm/session-manager" "github.com/openkcm/session-manager/internal/credentials" "github.com/openkcm/session-manager/internal/debugtools" - "github.com/openkcm/session-manager/internal/session" - "github.com/openkcm/session-manager/internal/trust" + internalsession "github.com/openkcm/session-manager/internal/session" ) const defaultIntrospectionCacheExpiration = 30 * time.Second var debugSettingSMDumpTransport = debugtools.NewSetting("smdumptransport") -type SessionServer struct { +type Server struct { sessionv1.UnimplementedServiceServer - sessionRepo session.Repository - trustRepo trust.OIDCMappingRepository + sessionRepo internalsession.Repository + trust sessionmanager.Trust newCreds credentials.Builder queryParametersIntrospect []string @@ -47,17 +48,17 @@ type SessionServer struct { introspectionCache *ttlcache.Cache[string, oidc.Introspection] } -func NewSessionServer( +func NewServer( ctx context.Context, - sessionRepo session.Repository, - trustRepo trust.OIDCMappingRepository, + sessionRepo internalsession.Repository, + trust sessionmanager.Trust, idleSessionTimeout time.Duration, clientID string, - opts ...SessionServerOption, -) *SessionServer { - s := &SessionServer{ + opts ...Option, +) *Server { + s := &Server{ sessionRepo: sessionRepo, - trustRepo: trustRepo, + trust: trust, idleSessionTimeout: idleSessionTimeout, newCreds: func(clientID string) credentials.TransportCredentials { return credentials.NewInsecure(clientID) }, clientID: clientID, @@ -78,7 +79,7 @@ func NewSessionServer( return s } -func (s *SessionServer) GetSession(ctx context.Context, req *sessionv1.GetSessionRequest) (*sessionv1.GetSessionResponse, error) { +func (s *Server) GetSession(ctx context.Context, req *sessionv1.GetSessionRequest) (*sessionv1.GetSessionResponse, error) { tracer := otel.GetTracerProvider() ctx, span := tracer.Tracer("").Start(ctx, "get_session") defer span.End() @@ -108,15 +109,15 @@ func (s *SessionServer) GetSession(ctx context.Context, req *sessionv1.GetSessio return &sessionv1.GetSessionResponse{Valid: false}, nil } - // Get trust mapping for the given tenant ID - mapping, err := s.trustRepo.Get(ctx, req.GetTenantId()) + // Get trust for the given tenant ID + trust, err := s.trust.Get(ctx, req.GetTenantId()) if err != nil { span.RecordError(err) - span.SetStatus(codes.Error, "failed to get an oidc mapping") - slogctx.Warn(ctx, "Is this an attack? Could not get trust mapping", "issuer", sess.Issuer, "error", err) + span.SetStatus(codes.Error, "failed to get trust") + slogctx.Warn(ctx, "Is this an attack? Could not get trust", "issuer", sess.Issuer, "error", err) return &sessionv1.GetSessionResponse{Valid: false}, nil } - if mapping.Blocked { + if trust.GetBlocked() { slogctx.Warn(ctx, "Tenant is blocked", "issuer", sess.Issuer) span.SetStatus(codes.Ok, "the tenant is blocked") st := status.New(grpccodes.FailedPrecondition, "the tenant is blocked") @@ -163,7 +164,7 @@ func (s *SessionServer) GetSession(ctx context.Context, req *sessionv1.GetSessio } // Introspect access token - result, err := s.introspectToken(ctx, sess.AccessToken, &mapping) + result, err := s.introspectToken(ctx, sess.AccessToken, trust.GetOidc()) if err != nil { span.RecordError(err) span.SetStatus(codes.Error, "failed to introspect an access token") @@ -193,38 +194,56 @@ func (s *SessionServer) GetSession(ctx context.Context, req *sessionv1.GetSessio return response, nil } -func (s *SessionServer) GetOIDCProvider(ctx context.Context, req *sessionv1.GetOIDCProviderRequest) (*sessionv1.GetOIDCProviderResponse, error) { +func (s *Server) GetOIDCProvider(ctx context.Context, req *sessionv1.GetOIDCProviderRequest) (*sessionv1.GetOIDCProviderResponse, error) { tracer := otel.GetTracerProvider() ctx, span := tracer.Tracer("").Start(ctx, "get_oidc_provider") defer span.End() - provider, err := s.trustRepo.Get(ctx, req.GetTenantId()) + provider, err := s.trust.Get(ctx, req.GetTenantId()) if err != nil { span.RecordError(err) span.SetStatus(codes.Error, "failed to get an oidc provider") return nil, fmt.Errorf("getting odic provider: %w", err) } + oidc := provider.GetOidc() + span.SetStatus(codes.Ok, "") return &sessionv1.GetOIDCProviderResponse{ Provider: &typesv1.OIDCProvider{ - IssuerUrl: provider.IssuerURL, - JwksUri: provider.JWKSURI, - Audiences: provider.Audiences, + IssuerUrl: oidc.GetIssuer(), + JwksUri: oidc.GetJwksUri(), + Audiences: oidc.GetAudiences(), }, }, nil } -func (s *SessionServer) getClientID(mapping *trust.OIDCMapping) string { - if mapping.ClientID != "" { - return mapping.ClientID +func (s *Server) GetTrust(ctx context.Context, req *sessionv1.GetTrustRequest) (*sessionv1.GetTrustResponse, error) { + tracer := otel.GetTracerProvider() + ctx, span := tracer.Tracer("").Start(ctx, "get_trust") + defer span.End() + + trust, err := s.trust.Get(ctx, req.GetTenantId()) + if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, "failed to get an oidc provider") + return nil, fmt.Errorf("getting odic provider: %w", err) + } + + span.SetStatus(codes.Ok, "") + return &sessionv1.GetTrustResponse{Trust: trust}, nil +} + +func (s *Server) getClientID(oidcTrust *oidcv1.OIDC) string { + if clientID := oidcTrust.GetClientId(); clientID != "" { + return clientID } return s.clientID } -func (s *SessionServer) httpClient(mapping *trust.OIDCMapping) *http.Client { - creds := s.newCreds(s.getClientID(mapping)) +func (s *Server) httpClient(oidcTrust *oidcv1.OIDC) *http.Client { + creds := s.newCreds(s.getClientID(oidcTrust)) transport := creds.Transport() if debugSettingSMDumpTransport.Value() == "1" { transport = debugtools.NewTransport(transport) @@ -235,7 +254,7 @@ func (s *SessionServer) httpClient(mapping *trust.OIDCMapping) *http.Client { } } -func (s *SessionServer) introspectToken(ctx context.Context, token string, oidcTrust *trust.OIDCMapping) (oidc.Introspection, error) { +func (s *Server) introspectToken(ctx context.Context, token string, oidcTrust *oidcv1.OIDC) (oidc.Introspection, error) { // first check the cache for a recent introspection result for this token hashedSuffix := sha256.Sum256([]byte(token)) cacheKey := base64.RawURLEncoding.EncodeToString(hashedSuffix[:]) @@ -246,13 +265,12 @@ func (s *SessionServer) introspectToken(ctx context.Context, token string, oidcT httpClient := s.httpClient(oidcTrust) // create the provider for the given issuer - provider, err := oidc.NewProvider(oidcTrust.IssuerURL, oidcTrust.Audiences, - oidc.WithIntrospectQueryParameters(oidcTrust.GetIntrospectParameters(s.queryParametersIntrospect)), + provider, err := oidc.NewProvider(oidcTrust.GetIssuer(), oidcTrust.GetAudiences(), oidc.WithAllowHttpScheme(s.allowHttpScheme), oidc.WithSecureHTTPClient(httpClient), ) if err != nil { - slogctx.Error(ctx, "Could not create OpenID provider", "issuer", oidcTrust.IssuerURL, "error", err) + slogctx.Error(ctx, "Could not create OpenID provider", "issuer", oidcTrust.GetIssuer(), "error", err) return oidc.Introspection{Active: false}, err } @@ -263,7 +281,7 @@ func (s *SessionServer) introspectToken(ctx context.Context, token string, oidcT slogctx.Debug(ctx, "No introspection endpoint configured", "issuer", provider.Issuer) return oidc.Introspection{Active: true}, nil } - slogctx.Error(ctx, "Could not introspect access token", "error", err) + slogctx.Error(ctx, "Could not introspect token", "error", err) return oidc.Introspection{Active: false}, err } diff --git a/internal/grpc/session_test.go b/modules/grpc/session/server_test.go similarity index 69% rename from internal/grpc/session_test.go rename to modules/grpc/session/server_test.go index 2dda10fe..de649b26 100644 --- a/internal/grpc/session_test.go +++ b/modules/grpc/session/server_test.go @@ -1,4 +1,4 @@ -package grpc_test +package session_test import ( "encoding/json" @@ -16,37 +16,40 @@ import ( rpcv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/rpc/v1" sessionv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/sessionmanager/session/v1" + oidcv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/oidc/v1" + trustv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/v1" - "github.com/openkcm/session-manager/internal/grpc" - "github.com/openkcm/session-manager/internal/session" + internalsession "github.com/openkcm/session-manager/internal/session" sessionmock "github.com/openkcm/session-manager/internal/session/mock" - "github.com/openkcm/session-manager/internal/trust" - "github.com/openkcm/session-manager/internal/trust/trustmock" + "github.com/openkcm/session-manager/modules/grpc/session" + mocktrust "github.com/openkcm/session-manager/modules/oidctrust/mocks" ) func TestNewSessionServer(t *testing.T) { ctx := t.Context() t.Run("creates server successfully", func(t *testing.T) { sessionRepo := sessionmock.NewInMemRepository() - trustRepo := trustmock.NewInMemRepository() + trustRepo := mocktrust.NewInMemRepository() + trust := newTrust(trustRepo) idleSessionTimeout := 90 * time.Minute - server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, idleSessionTimeout, "") + server := session.NewServer(ctx, sessionRepo, trust, idleSessionTimeout, "") assert.NotNil(t, server) }) t.Run("creates server with options", func(t *testing.T) { sessionRepo := sessionmock.NewInMemRepository() - trustRepo := trustmock.NewInMemRepository() + trustRepo := mocktrust.NewInMemRepository() + trust := newTrust(trustRepo) idleSessionTimeout := 90 * time.Minute - server := grpc.NewSessionServer(ctx, + server := session.NewServer(ctx, sessionRepo, - trustRepo, + trust, idleSessionTimeout, "", - grpc.WithQueryParametersIntrospect([]string{"param1", "param2"}), + session.WithQueryParametersIntrospect([]string{"param1", "param2"}), ) assert.NotNil(t, server) @@ -54,12 +57,13 @@ func TestNewSessionServer(t *testing.T) { t.Run("handles nil option gracefully", func(t *testing.T) { sessionRepo := sessionmock.NewInMemRepository() - trustRepo := trustmock.NewInMemRepository() + trustRepo := mocktrust.NewInMemRepository() + trust := newTrust(trustRepo) idleSessionTimeout := 90 * time.Minute - server := grpc.NewSessionServer(ctx, + server := session.NewServer(ctx, sessionRepo, - trustRepo, + trust, idleSessionTimeout, "", nil, @@ -92,13 +96,13 @@ func TestGetSession(t *testing.T) { })) defer testServer.Close() - sess := session.Session{ + sess := internalsession.Session{ ID: "session-123", TenantID: "tenant-123", Fingerprint: "fingerprint-123", Issuer: testServer.URL, AccessToken: "access-token-123", - Claims: session.Claims{ + Claims: internalsession.Claims{ Subject: "user-123", GivenName: "John", FamilyName: "Doe", @@ -108,10 +112,13 @@ func TestGetSession(t *testing.T) { AuthContext: map[string]string{"key": "value"}, } - mapping := trust.OIDCMapping{ - IssuerURL: testServer.URL, - Blocked: false, - } + trustData := trustv1.Trust_builder{ + TenantId: new(sess.TenantID), + Blocked: new(false), + Oidc: oidcv1.OIDC_builder{ + Issuer: new(testServer.URL), + }.Build(), + }.Build() sessionRepo := sessionmock.NewInMemRepository( sessionmock.WithSession(sess), @@ -119,14 +126,11 @@ func TestGetSession(t *testing.T) { // Mark session as active _ = sessionRepo.BumpActive(ctx, sess.ID, 1*time.Hour) - trustRepo := trustmock.NewInMemRepository( - trustmock.WithTrust(sess.TenantID, mapping), - ) - - server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute, "", - grpc.WithAllowHttpScheme(true), + trustRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(trustData)) + trust := newTrust(trustRepo) + server := session.NewServer(ctx, sessionRepo, trust, 90*time.Minute, "", + session.WithAllowHttpScheme(true), ) - req := &sessionv1.GetSessionRequest{ SessionId: "session-123", TenantId: "tenant-123", @@ -168,34 +172,35 @@ func TestGetSession(t *testing.T) { })) defer testServer.Close() - sess := session.Session{ + sess := internalsession.Session{ ID: "session-groups", TenantID: "tenant-groups", Fingerprint: "fingerprint-groups", Issuer: testServer.URL, AccessToken: "access-token-groups", - Claims: session.Claims{ + Claims: internalsession.Claims{ Subject: "user-groups", Groups: []string{"session-group1", "session-group2"}, }, } - mapping := trust.OIDCMapping{ - IssuerURL: testServer.URL, - Blocked: false, - } + trustData := trustv1.Trust_builder{ + TenantId: new(sess.TenantID), + Blocked: new(false), + Oidc: oidcv1.OIDC_builder{ + Issuer: new(testServer.URL), + }.Build(), + }.Build() sessionRepo := sessionmock.NewInMemRepository( sessionmock.WithSession(sess), ) _ = sessionRepo.BumpActive(ctx, sess.ID, 1*time.Hour) - trustRepo := trustmock.NewInMemRepository( - trustmock.WithTrust(sess.TenantID, mapping), - ) - - server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute, "", - grpc.WithAllowHttpScheme(true), + trustRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(trustData)) + trust := newTrust(trustRepo) + server := session.NewServer(ctx, sessionRepo, trust, 90*time.Minute, "", + session.WithAllowHttpScheme(true), ) req := &sessionv1.GetSessionRequest{ @@ -226,32 +231,33 @@ func TestGetSession(t *testing.T) { })) defer testServer.Close() - sess := session.Session{ + sess := internalsession.Session{ ID: "session-456", TenantID: "tenant-456", Fingerprint: "fingerprint-456", Issuer: testServer.URL, - Claims: session.Claims{ + Claims: internalsession.Claims{ Subject: "user-456", }, } - mapping := trust.OIDCMapping{ - IssuerURL: testServer.URL, - Blocked: false, - } + trustData := trustv1.Trust_builder{ + TenantId: new(sess.TenantID), + Blocked: new(false), + Oidc: oidcv1.OIDC_builder{ + Issuer: new(testServer.URL), + }.Build(), + }.Build() sessionRepo := sessionmock.NewInMemRepository( sessionmock.WithSession(sess), ) _ = sessionRepo.BumpActive(ctx, sess.ID, 1*time.Hour) - trustRepo := trustmock.NewInMemRepository( - trustmock.WithTrust(sess.TenantID, mapping), - ) - - server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute, "", - grpc.WithAllowHttpScheme(true), + trustRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(trustData)) + trust := newTrust(trustRepo) + server := session.NewServer(ctx, sessionRepo, trust, 90*time.Minute, "", + session.WithAllowHttpScheme(true), ) req := &sessionv1.GetSessionRequest{ @@ -272,9 +278,9 @@ func TestGetSession(t *testing.T) { sessionRepo := sessionmock.NewInMemRepository( sessionmock.WithIsActiveError(isActiveErr), ) - trustRepo := trustmock.NewInMemRepository() - - server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute, "") + trustRepo := mocktrust.NewInMemRepository() + trust := newTrust(trustRepo) + server := session.NewServer(ctx, sessionRepo, trust, 90*time.Minute, "") req := &sessionv1.GetSessionRequest{ SessionId: "session-123", @@ -290,7 +296,7 @@ func TestGetSession(t *testing.T) { }) t.Run("invalid - session not active", func(t *testing.T) { - sess := session.Session{ + sess := internalsession.Session{ ID: "session-789", TenantID: "tenant-789", Fingerprint: "fingerprint-789", @@ -301,9 +307,9 @@ func TestGetSession(t *testing.T) { ) // Don't bump active - session is not active - trustRepo := trustmock.NewInMemRepository() - - server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute, "") + trustRepo := mocktrust.NewInMemRepository() + trust := newTrust(trustRepo) + server := session.NewServer(ctx, sessionRepo, trust, 90*time.Minute, "") req := &sessionv1.GetSessionRequest{ SessionId: "session-789", @@ -323,14 +329,14 @@ func TestGetSession(t *testing.T) { sessionmock.WithLoadSessionError(errors.New("load error")), ) // Create a session and mark as active but LoadSession will error - sess := session.Session{ID: "session-fail"} + sess := internalsession.Session{ID: "session-fail"} err := sessionRepo.StoreSession(ctx, sess) assert.NoError(t, err) _ = sessionRepo.BumpActive(ctx, sess.ID, 1*time.Hour) - trustRepo := trustmock.NewInMemRepository() - - server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute, "") + trustRepo := mocktrust.NewInMemRepository() + trust := newTrust(trustRepo) + server := session.NewServer(ctx, sessionRepo, trust, 90*time.Minute, "") req := &sessionv1.GetSessionRequest{ SessionId: "session-fail", @@ -345,8 +351,8 @@ func TestGetSession(t *testing.T) { assert.False(t, resp.GetValid()) }) - t.Run("invalid - trust mapping not found", func(t *testing.T) { - sess := session.Session{ + t.Run("invalid - trust not found", func(t *testing.T) { + sess := internalsession.Session{ ID: "session-no-provider", TenantID: "tenant-no-provider", Fingerprint: "fingerprint-123", @@ -358,10 +364,10 @@ func TestGetSession(t *testing.T) { ) _ = sessionRepo.BumpActive(ctx, sess.ID, 1*time.Hour) - // No mapping added to repo - trustRepo := trustmock.NewInMemRepository() - - server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute, "") + // No trust added to repo + trustRepo := mocktrust.NewInMemRepository() + trust := newTrust(trustRepo) + server := session.NewServer(ctx, sessionRepo, trust, 90*time.Minute, "") req := &sessionv1.GetSessionRequest{ SessionId: "session-no-provider", @@ -376,8 +382,8 @@ func TestGetSession(t *testing.T) { assert.False(t, resp.GetValid()) }) - t.Run("invalid - trust mapping is blocked", func(t *testing.T) { - sess := session.Session{ + t.Run("invalid - trust is blocked", func(t *testing.T) { + sess := internalsession.Session{ ID: "session-blocked", TenantID: "tenant-blocked", Fingerprint: "fingerprint-123", @@ -389,15 +395,17 @@ func TestGetSession(t *testing.T) { ) _ = sessionRepo.BumpActive(ctx, sess.ID, 1*time.Hour) - mapping := trust.OIDCMapping{ - IssuerURL: "https://issuer.example.com", - Blocked: true, // Mapping is blocked - } - trustRepo := trustmock.NewInMemRepository( - trustmock.WithTrust(sess.TenantID, mapping), - ) + trustData := trustv1.Trust_builder{ + TenantId: new(sess.TenantID), + Blocked: new(true), + Oidc: oidcv1.OIDC_builder{ + Issuer: new("https://issuer.example.com"), + }.Build(), + }.Build() + trustRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(trustData)) + trust := newTrust(trustRepo) - server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute, "") + server := session.NewServer(ctx, sessionRepo, trust, 90*time.Minute, "") req := &sessionv1.GetSessionRequest{ SessionId: "session-blocked", @@ -425,7 +433,7 @@ func TestGetSession(t *testing.T) { }) t.Run("invalid - fingerprint mismatch", func(t *testing.T) { - sess := session.Session{ + sess := internalsession.Session{ ID: "session-fingerprint", TenantID: "tenant-fingerprint", Fingerprint: "correct-fingerprint", @@ -437,15 +445,17 @@ func TestGetSession(t *testing.T) { ) _ = sessionRepo.BumpActive(ctx, sess.ID, 1*time.Hour) - mapping := trust.OIDCMapping{ - IssuerURL: "https://issuer.example.com", - Blocked: false, - } - trustRepo := trustmock.NewInMemRepository( - trustmock.WithTrust(sess.TenantID, mapping), - ) + trustData := trustv1.Trust_builder{ + TenantId: new(sess.TenantID), + Blocked: new(false), + Oidc: oidcv1.OIDC_builder{ + Issuer: new("https://issuer.example.com"), + }.Build(), + }.Build() + trustRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(trustData)) + trust := newTrust(trustRepo) - server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute, "") + server := session.NewServer(ctx, sessionRepo, trust, 90*time.Minute, "") req := &sessionv1.GetSessionRequest{ SessionId: "session-fingerprint", @@ -461,7 +471,7 @@ func TestGetSession(t *testing.T) { }) t.Run("invalid - tenant ID mismatch", func(t *testing.T) { - sess := session.Session{ + sess := internalsession.Session{ ID: "session-tenant", TenantID: "correct-tenant", Fingerprint: "fingerprint-123", @@ -473,15 +483,17 @@ func TestGetSession(t *testing.T) { ) _ = sessionRepo.BumpActive(ctx, sess.ID, 1*time.Hour) - mapping := trust.OIDCMapping{ - IssuerURL: "https://issuer.example.com", - Blocked: false, - } - trustRepo := trustmock.NewInMemRepository( - trustmock.WithTrust("wrong-tenant", mapping), - ) + trustData := trustv1.Trust_builder{ + TenantId: new("wrong-tenant"), + Blocked: new(false), + Oidc: oidcv1.OIDC_builder{ + Issuer: new("https://issuer.example.com"), + }.Build(), + }.Build() + trustRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(trustData)) + trust := newTrust(trustRepo) - server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute, "") + server := session.NewServer(ctx, sessionRepo, trust, 90*time.Minute, "") req := &sessionv1.GetSessionRequest{ SessionId: "session-tenant", @@ -497,7 +509,7 @@ func TestGetSession(t *testing.T) { }) t.Run("error - GetOpenIDConfig fails", func(t *testing.T) { - sess := session.Session{ + sess := internalsession.Session{ ID: "session-config-fail", TenantID: "tenant-config-fail", Fingerprint: "fingerprint-123", @@ -509,15 +521,17 @@ func TestGetSession(t *testing.T) { ) _ = sessionRepo.BumpActive(ctx, sess.ID, 1*time.Hour) - mapping := trust.OIDCMapping{ - IssuerURL: "https://invalid-issuer-no-server.example.com", - Blocked: false, - } - trustRepo := trustmock.NewInMemRepository( - trustmock.WithTrust(sess.TenantID, mapping), - ) + trustData := trustv1.Trust_builder{ + TenantId: new(sess.TenantID), + Blocked: new(false), + Oidc: oidcv1.OIDC_builder{ + Issuer: new("https://invalid-issuer-no-server.example.com"), + }.Build(), + }.Build() + trustRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(trustData)) + trust := newTrust(trustRepo) - server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute, "") + server := session.NewServer(ctx, sessionRepo, trust, 90*time.Minute, "") req := &sessionv1.GetSessionRequest{ SessionId: "session-config-fail", @@ -548,7 +562,7 @@ func TestGetSession(t *testing.T) { })) defer testServer.Close() - sess := session.Session{ + sess := internalsession.Session{ ID: "session-introspect-fail", TenantID: "tenant-introspect-fail", Fingerprint: "fingerprint-123", @@ -561,16 +575,18 @@ func TestGetSession(t *testing.T) { ) _ = sessionRepo.BumpActive(ctx, sess.ID, 1*time.Hour) - mapping := trust.OIDCMapping{ - IssuerURL: testServer.URL, - Blocked: false, - } - trustRepo := trustmock.NewInMemRepository( - trustmock.WithTrust(sess.TenantID, mapping), - ) + trustData := trustv1.Trust_builder{ + TenantId: new(sess.TenantID), + Blocked: new(false), + Oidc: oidcv1.OIDC_builder{ + Issuer: new(testServer.URL), + }.Build(), + }.Build() + trustRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(trustData)) + trust := newTrust(trustRepo) - server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute, "", - grpc.WithAllowHttpScheme(true), + server := session.NewServer(ctx, sessionRepo, trust, 90*time.Minute, "", + session.WithAllowHttpScheme(true), ) req := &sessionv1.GetSessionRequest{ @@ -604,7 +620,7 @@ func TestGetSession(t *testing.T) { })) defer testServer.Close() - sess := session.Session{ + sess := internalsession.Session{ ID: "session-inactive-token", TenantID: "tenant-inactive-token", Fingerprint: "fingerprint-123", @@ -617,16 +633,18 @@ func TestGetSession(t *testing.T) { ) _ = sessionRepo.BumpActive(ctx, sess.ID, 1*time.Hour) - mapping := trust.OIDCMapping{ - IssuerURL: testServer.URL, - Blocked: false, - } - trustRepo := trustmock.NewInMemRepository( - trustmock.WithTrust(sess.TenantID, mapping), - ) + trustData := trustv1.Trust_builder{ + TenantId: new(sess.TenantID), + Blocked: new(false), + Oidc: oidcv1.OIDC_builder{ + Issuer: new(testServer.URL), + }.Build(), + }.Build() + trustRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(trustData)) + trust := newTrust(trustRepo) - server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute, "", - grpc.WithAllowHttpScheme(true), + server := session.NewServer(ctx, sessionRepo, trust, 90*time.Minute, "", + session.WithAllowHttpScheme(true), ) req := &sessionv1.GetSessionRequest{ @@ -653,7 +671,7 @@ func TestGetSession(t *testing.T) { })) defer testServer.Close() - sess := session.Session{ + sess := internalsession.Session{ ID: "session-bump-fail", TenantID: "tenant-bump-fail", Fingerprint: "fingerprint-123", @@ -666,16 +684,18 @@ func TestGetSession(t *testing.T) { ) _ = sessionRepo.BumpActive(ctx, sess.ID, 1*time.Hour) - mapping := trust.OIDCMapping{ - IssuerURL: testServer.URL, - Blocked: false, - } - trustRepo := trustmock.NewInMemRepository( - trustmock.WithTrust(sess.TenantID, mapping), - ) + trustData := trustv1.Trust_builder{ + TenantId: new(sess.TenantID), + Blocked: new(false), + Oidc: oidcv1.OIDC_builder{ + Issuer: new(testServer.URL), + }.Build(), + }.Build() + trustRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(trustData)) + trust := newTrust(trustRepo) - server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute, "", - grpc.WithAllowHttpScheme(true), + server := session.NewServer(ctx, sessionRepo, trust, 90*time.Minute, "", + session.WithAllowHttpScheme(true), ) req := &sessionv1.GetSessionRequest{ @@ -696,21 +716,16 @@ func TestWithQueryParametersIntrospect(t *testing.T) { ctx := t.Context() t.Run("sets query parameters correctly", func(t *testing.T) { params := []string{"param1", "param2", "param3"} - opt := grpc.WithQueryParametersIntrospect(params) + opt := session.WithQueryParametersIntrospect(params) assert.NotNil(t, opt) // Test that the option actually sets the parameters sessionRepo := sessionmock.NewInMemRepository() - trustRepo := trustmock.NewInMemRepository() + trustRepo := mocktrust.NewInMemRepository() + trust := newTrust(trustRepo) - server := grpc.NewSessionServer(ctx, - sessionRepo, - trustRepo, - 90*time.Minute, - "", - opt, - ) + server := session.NewServer(ctx, sessionRepo, trust, 90*time.Minute, "", opt) assert.NotNil(t, server) }) @@ -720,18 +735,20 @@ func TestGetOIDCProvider(t *testing.T) { ctx := t.Context() t.Run("success - returns OIDC provider", func(t *testing.T) { - mapping := trust.OIDCMapping{ - IssuerURL: "https://issuer.example.com", - JWKSURI: "https://issuer.example.com/.well-known/jwks.json", - Audiences: []string{"audience1", "audience2"}, - } - + trustData := trustv1.Trust_builder{ + TenantId: new("tenant-123"), + Blocked: new(false), + Oidc: oidcv1.OIDC_builder{ + Issuer: new("https://issuer.example.com"), + JwksUri: new("https://issuer.example.com/.well-known/jwks.json"), + Audiences: []string{"audience1", "audience2"}, + }.Build(), + }.Build() + trustRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(trustData)) + trust := newTrust(trustRepo) sessionRepo := sessionmock.NewInMemRepository() - trustRepo := trustmock.NewInMemRepository( - trustmock.WithTrust("tenant-123", mapping), - ) - server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute, "") + server := session.NewServer(ctx, sessionRepo, trust, 90*time.Minute, "") req := &sessionv1.GetOIDCProviderRequest{ TenantId: "tenant-123", @@ -749,10 +766,9 @@ func TestGetOIDCProvider(t *testing.T) { t.Run("error - provider not found", func(t *testing.T) { sessionRepo := sessionmock.NewInMemRepository() - trustRepo := trustmock.NewInMemRepository() - - server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute, "") - + trustRepo := mocktrust.NewInMemRepository() + trust := newTrust(trustRepo) + server := session.NewServer(ctx, sessionRepo, trust, 90*time.Minute, "") req := &sessionv1.GetOIDCProviderRequest{ TenantId: "non-existent-tenant", } @@ -766,12 +782,11 @@ func TestGetOIDCProvider(t *testing.T) { t.Run("error - repository returns error", func(t *testing.T) { sessionRepo := sessionmock.NewInMemRepository() - trustRepo := trustmock.NewInMemRepository( - trustmock.WithGetError(errors.New("database connection error")), + trustRepo := mocktrust.NewInMemRepository( + mocktrust.WithGetError(errors.New("database connection error")), ) - - server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute, "") - + trust := newTrust(trustRepo) + server := session.NewServer(ctx, sessionRepo, trust, 90*time.Minute, "") req := &sessionv1.GetOIDCProviderRequest{ TenantId: "tenant-123", } diff --git a/internal/grpc/violations.go b/modules/grpc/session/violations.go similarity index 77% rename from internal/grpc/violations.go rename to modules/grpc/session/violations.go index 5479ea89..b4313e38 100644 --- a/internal/grpc/violations.go +++ b/modules/grpc/session/violations.go @@ -1,4 +1,4 @@ -package grpc +package session const ( violationTenantBlocked = "tenant_blocked" diff --git a/modules/grpc/trustmapping/import_test.go b/modules/grpc/trustmapping/import_test.go new file mode 100644 index 00000000..958539cf --- /dev/null +++ b/modules/grpc/trustmapping/import_test.go @@ -0,0 +1,12 @@ +package trustmapping_test + +import ( + _ "unsafe" + + sessionmanager "github.com/openkcm/session-manager" + "github.com/openkcm/session-manager/modules/oidctrust" + _ "github.com/openkcm/session-manager/modules/standard" +) + +//go:linkname newTrust github.com/openkcm/session-manager/modules/oidctrust.newOIDCTrustModuleWithRepo +func newTrust(r oidctrust.TrustRepository) sessionmanager.Trust diff --git a/modules/grpc/trustmapping/module.go b/modules/grpc/trustmapping/module.go new file mode 100644 index 00000000..bc949a95 --- /dev/null +++ b/modules/grpc/trustmapping/module.go @@ -0,0 +1,60 @@ +// Package trustmapping provides the service.module.grpc.trustmapping module: +// a gRPC service module that registers the +// kms.api.cmk.sessionmanager.trustmapping.v1.Service proto onto a +// grpc.ServiceRegistrar supplied by app.module.grpcserver. +package trustmapping + +import ( + "fmt" + + "google.golang.org/grpc" + + trustmappingv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/sessionmanager/trustmapping/v1" + + sessionmanager "github.com/openkcm/session-manager" +) + +const moduleID = "service.module.grpc.trustmapping" + +func init() { + sessionmanager.RegisterModule(new(Module)) +} + +func newModule() sessionmanager.Module { + return new(Module) +} + +// Module is the service.module.grpc.trustmapping module. It owns a Server +// that implements the trustmapping proto and resolves its single dependency +// (a sessionmanager.Trust implementation) by ID via ctx.GetModule. +type Module struct { + Mod string `yaml:"module"` + Trust string `yaml:"trust" default:"trust.module.oidc"` + + server *Server +} + +func (m *Module) Module() sessionmanager.ModuleInfo { + return sessionmanager.ModuleInfo{ + ID: moduleID, + New: newModule, + } +} + +func (m *Module) Provision(ctx *sessionmanager.Context) error { + trustMod, err := ctx.GetModule(m.Trust) + if err != nil { + return fmt.Errorf("getting trust module %q: %w", m.Trust, err) + } + trust, ok := trustMod.(sessionmanager.Trust) + if !ok { + return fmt.Errorf("module %q does not implement sessionmanager.Trust", m.Trust) + } + + m.server = NewServer(trust) + return nil +} + +func (m *Module) Register(s grpc.ServiceRegistrar) { + trustmappingv1.RegisterServiceServer(s, m.server) +} diff --git a/modules/grpc/trustmapping/module_test.go b/modules/grpc/trustmapping/module_test.go new file mode 100644 index 00000000..33923884 --- /dev/null +++ b/modules/grpc/trustmapping/module_test.go @@ -0,0 +1,83 @@ +package trustmapping_test + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + trustv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/v1" + + sessionmanager "github.com/openkcm/session-manager" + tmmod "github.com/openkcm/session-manager/modules/grpc/trustmapping" + _ "github.com/openkcm/session-manager/modules/standard" +) + +// stubTrust satisfies sessionmanager.Trust enough for Provision to type-assert +// successfully. No method bodies need to be exercised in these tests. +type stubTrust struct{} + +func (stubTrust) Apply(context.Context, *trustv1.Trust) error { return nil } +func (stubTrust) Block(context.Context, string) error { return nil } +func (stubTrust) Remove(context.Context, string) error { return nil } +func (stubTrust) Unblock(context.Context, string) error { return nil } + +var errStubTrustGet = errors.New("stub get not implemented") + +func (stubTrust) Get(context.Context, string) (*trustv1.Trust, error) { + return nil, errStubTrustGet +} + +// stubTrustModule lets us register a fake trust module with the registry under +// a custom ID for tests. +type stubTrustModule struct { + stubTrust + + id string +} + +func (s *stubTrustModule) Module() sessionmanager.ModuleInfo { + return sessionmanager.ModuleInfo{ + ID: s.id, + New: func() sessionmanager.Module { return s }, + } +} + +func TestModule_Registration(t *testing.T) { + info, err := sessionmanager.GetModule("service.module.grpc.trustmapping") + require.NoError(t, err) + assert.Equal(t, "service.module.grpc.trustmapping", info.ID) +} + +func TestModule_ProvisionResolvesCustomTrust(t *testing.T) { + id := "trust.module.test." + t.Name() + sessionmanager.RegisterModule(&stubTrustModule{id: id}) + + ctx, cancel := sessionmanager.NewContext(t.Context()) + defer cancel(nil) + + _, err := ctx.LoadModule(&extConfig{moduleID: id}) + require.NoError(t, err) + + m := &tmmod.Module{Trust: id} + require.NoError(t, m.Provision(ctx)) +} + +func TestModule_ProvisionMissingTrustFails(t *testing.T) { + ctx, cancel := sessionmanager.NewContext(t.Context()) + defer cancel(nil) + + m := &tmmod.Module{Trust: "no.such.trust.module"} + err := m.Provision(ctx) + require.Error(t, err) + assert.Contains(t, err.Error(), "no.such.trust.module") +} + +// extConfig is a minimal sessionmanager.ExtensionConfig used by the test to +// load fake modules through the registry. +type extConfig struct{ moduleID string } + +func (c *extConfig) Module() string { return c.moduleID } +func (c *extConfig) UnmarshalExtension(_ sessionmanager.Module) error { return nil } diff --git a/modules/grpc/trustmapping/server.go b/modules/grpc/trustmapping/server.go new file mode 100644 index 00000000..0359312b --- /dev/null +++ b/modules/grpc/trustmapping/server.go @@ -0,0 +1,133 @@ +package trustmapping + +import ( + "context" + "errors" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + trustmappingv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/sessionmanager/trustmapping/v1" + oidcv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/oidc/v1" + trustv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/v1" + slogctx "github.com/veqryn/slog-context" + + sessionmanager "github.com/openkcm/session-manager" + "github.com/openkcm/session-manager/pkg/serviceerr" +) + +type Server struct { + trustmappingv1.UnimplementedServiceServer + + trust sessionmanager.Trust +} + +func NewServer(trust sessionmanager.Trust) *Server { + return &Server{trust: trust} +} + +func (srv *Server) ApplyTrustMapping(ctx context.Context, in *trustmappingv1.ApplyTrustMappingRequest) (*trustmappingv1.ApplyTrustMappingResponse, error) { + oidcIn := in.GetOidc() + oidc := oidcv1.OIDC_builder{ + Issuer: new(oidcIn.GetIssuer()), + JwksUri: new(oidcIn.GetJwksUri()), + Audiences: oidcIn.GetAudiences(), + ClientId: new(oidcIn.GetClientId()), + }.Build() + + trust := trustv1.Trust_builder{ + TenantId: new(in.GetTenantId()), + Oidc: oidc, + }.Build() + + ctx = slogctx.With(ctx, + "tenantId", trust.GetTenantId(), + "issuer", oidc.GetIssuer(), + "jwksUri", oidc.GetJwksUri(), + "audiences", oidc.GetAudiences(), + "client_id", oidc.GetClientId(), + ) + + slogctx.Debug(ctx, "ApplyTrustMapping called") + + response := trustmappingv1.ApplyTrustMappingResponse_builder{}.Build() + + if err := srv.trust.Apply(ctx, trust); err != nil { + slogctx.Error(ctx, "Could not apply trust", "error", err) + if errors.Is(err, serviceerr.ErrNotFound) { + msg := serviceerr.ErrNotFound.Error() + response.SetMessage(msg) + return response, nil + } + + return nil, status.Errorf(codes.Internal, "failed to apply trust: %v", err) + } + + response.SetSuccess(true) + + return response, nil +} + +// BlockTrustMapping blocks the trust for the specified tenant. +// It calls the underlying service to set the trust as blocked. +// Returns a response containing an optional error message if blocking fails. +func (srv *Server) BlockTrustMapping(ctx context.Context, req *trustmappingv1.BlockTrustMappingRequest) (*trustmappingv1.BlockTrustMappingResponse, error) { + ctx = slogctx.With(ctx, "tenantId", req.GetTenantId()) + slogctx.Debug(ctx, "BlockTrustMapping called") + + resp := trustmappingv1.BlockTrustMappingResponse_builder{}.Build() + err := srv.trust.Block(ctx, req.GetTenantId()) + if err != nil { + slogctx.Error(ctx, "Could not block trust", "error", err) + msg := err.Error() + + resp.SetMessage(msg) + return resp, status.Error(codes.Internal, "failed to block trust: "+msg) + } + + resp.SetSuccess(true) + return resp, nil +} + +// RemoveTrustMapping removes the trust configuration for the tenant. +// It calls the underlying service to remove the trust. +// Returns a respose containing an optional error message if removing fails. +func (srv *Server) RemoveTrustMapping(ctx context.Context, req *trustmappingv1.RemoveTrustMappingRequest) (*trustmappingv1.RemoveTrustMappingResponse, error) { + ctx = slogctx.With(ctx, "tenantId", req.GetTenantId()) + slogctx.Debug(ctx, "RemoveTrustMapping called") + + resp := &trustmappingv1.RemoveTrustMappingResponse{} + err := srv.trust.Remove(ctx, req.GetTenantId()) + if err != nil { + if !errors.Is(err, serviceerr.ErrNotFound) { + slogctx.Error(ctx, "Could not remove trust", "error", err) + msg := err.Error() + resp.SetMessage(msg) + return resp, status.Error(codes.Internal, "failed to remove trust: "+msg) + } + slogctx.Warn(ctx, "RemoveTrustMapping is called but the tenant does not exist", "error", err) + } + + resp.SetSuccess(true) + return resp, nil +} + +// UnblockTrustMapping unblocks the trust for the specified tenant. +// It calls the underlying service to set the trust as unblocked. +// Returns a response containing an optional error message if unblocking fails. +func (srv *Server) UnblockTrustMapping(ctx context.Context, req *trustmappingv1.UnblockTrustMappingRequest) (*trustmappingv1.UnblockTrustMappingResponse, error) { + ctx = slogctx.With(ctx, "tenantId", req.GetTenantId()) + slogctx.Debug(ctx, "UnblockTrustMapping called") + + resp := &trustmappingv1.UnblockTrustMappingResponse{} + err := srv.trust.Unblock(ctx, req.GetTenantId()) + if err != nil { + slogctx.Error(ctx, "Could not unblock trust", "error", err) + msg := err.Error() + resp.SetMessage(msg) + return resp, status.Error(codes.Internal, "failed to unblock trust: "+msg) + } + + resp.SetSuccess(true) + return resp, nil +} diff --git a/modules/grpc/trustmapping/server_test.go b/modules/grpc/trustmapping/server_test.go new file mode 100644 index 00000000..c15bb74e --- /dev/null +++ b/modules/grpc/trustmapping/server_test.go @@ -0,0 +1,452 @@ +package trustmapping_test + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + trustmappingv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/sessionmanager/trustmapping/v1" + oidcv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/oidc/v1" + trustv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/v1" + + "github.com/openkcm/session-manager/modules/grpc/trustmapping" + mocktrust "github.com/openkcm/session-manager/modules/oidctrust/mocks" + "github.com/openkcm/session-manager/pkg/serviceerr" +) + +func TestNewTrustMappingServer(t *testing.T) { + t.Run("creates server successfully", func(t *testing.T) { + repo := mocktrust.NewInMemRepository() + svc := newTrust(repo) + server := trustmapping.NewServer(svc) + + assert.NotNil(t, server) + }) +} + +func TestApplyTrustMapping(t *testing.T) { + ctx := t.Context() + + t.Run("success - creates new trust", func(t *testing.T) { + repo := mocktrust.NewInMemRepository() + svc := newTrust(repo) + server := trustmapping.NewServer(svc) + + jwksUri := "https://issuer.example.com/.well-known/jwks.json" + req := trustmappingv1.ApplyTrustMappingRequest_builder{ + TenantId: new("tenant-123"), + Oidc: trustmappingv1.ApplyTrustMappingRequest_ApplyOIDCTrust_builder{ + Issuer: new("https://issuer.example.com"), + JwksUri: &jwksUri, + Audiences: []string{"audience1", "audience2"}, + }.Build(), + }.Build() + + resp, err := server.ApplyTrustMapping(ctx, req) + + require.NoError(t, err) + assert.NotNil(t, resp) + assert.True(t, resp.GetSuccess()) + assert.Empty(t, resp.GetMessage()) + }) + + t.Run("success - updates existing trust", func(t *testing.T) { + existingTrust := trustv1.Trust_builder{ + TenantId: new("tenant-123"), + Oidc: oidcv1.OIDC_builder{ + Issuer: new("https://old-issuer.example.com"), + JwksUri: new("https://old-issuer.example.com/jwks.json"), + Audiences: []string{"old-audience"}, + }.Build(), + }.Build() + repo := mocktrust.NewInMemRepository( + mocktrust.WithTrust(existingTrust), + ) + svc := newTrust(repo) + server := trustmapping.NewServer(svc) + + jwksUri := "https://new-issuer.example.com/jwks.json" + req := trustmappingv1.ApplyTrustMappingRequest_builder{ + TenantId: new("tenant-123"), + Oidc: trustmappingv1.ApplyTrustMappingRequest_ApplyOIDCTrust_builder{ + Issuer: new("https://new-issuer.example.com"), + JwksUri: new(jwksUri), + Audiences: []string{"new-audience"}, + }.Build(), + }.Build() + + resp, err := server.ApplyTrustMapping(ctx, req) + + require.NoError(t, err) + assert.NotNil(t, resp) + assert.True(t, resp.GetSuccess()) + }) + + t.Run("not found error - returns response with message", func(t *testing.T) { + repo := mocktrust.NewInMemRepository( + mocktrust.WithCreateError(serviceerr.ErrNotFound), + ) + svc := newTrust(repo) + server := trustmapping.NewServer(svc) + + jwksUri := "https://issuer.example.com/jwks.json" + req := trustmappingv1.ApplyTrustMappingRequest_builder{ + TenantId: new("tenant-123"), + Oidc: trustmappingv1.ApplyTrustMappingRequest_ApplyOIDCTrust_builder{ + Issuer: new("https://issuer.example.com"), + JwksUri: new(jwksUri), + }.Build(), + }.Build() + + resp, err := server.ApplyTrustMapping(ctx, req) + + require.NoError(t, err) + assert.NotNil(t, resp) + assert.False(t, resp.GetSuccess()) + require.NotNil(t, resp.GetMessage()) + assert.Equal(t, serviceerr.ErrNotFound.Error(), resp.GetMessage()) + }) + + t.Run("internal error - returns grpc error", func(t *testing.T) { + internalErr := errors.New("database connection failed") + repo := mocktrust.NewInMemRepository( + mocktrust.WithCreateError(internalErr), + ) + svc := newTrust(repo) + server := trustmapping.NewServer(svc) + + jwksUri := "https://issuer.example.com/jwks.json" + req := trustmappingv1.ApplyTrustMappingRequest_builder{ + TenantId: new("tenant-123"), + Oidc: trustmappingv1.ApplyTrustMappingRequest_ApplyOIDCTrust_builder{ + Issuer: new("https://issuer.example.com"), + JwksUri: new(jwksUri), + }.Build(), + }.Build() + + resp, err := server.ApplyTrustMapping(ctx, req) + + assert.Nil(t, resp) + require.Error(t, err) + + st, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, codes.Internal, st.Code()) + assert.Contains(t, st.Message(), "failed to apply trust") + }) + + t.Run("update error - returns grpc error", func(t *testing.T) { + existingTrust := trustv1.Trust_builder{ + TenantId: new("tenant-123"), + Oidc: oidcv1.OIDC_builder{ + Issuer: new("https://issuer.example.com"), + }.Build(), + }.Build() + updateErr := errors.New("update failed") + repo := mocktrust.NewInMemRepository( + mocktrust.WithTrust(existingTrust), + mocktrust.WithUpdateError(updateErr), + ) + svc := newTrust(repo) + server := trustmapping.NewServer(svc) + + jwksUri := "https://new-issuer.example.com/jwks.json" + req := trustmappingv1.ApplyTrustMappingRequest_builder{ + TenantId: new("tenant-123"), + Oidc: trustmappingv1.ApplyTrustMappingRequest_ApplyOIDCTrust_builder{ + Issuer: new("https://new-issuer.example.com"), + JwksUri: new(jwksUri), + }.Build(), + }.Build() + + resp, err := server.ApplyTrustMapping(ctx, req) + + assert.Nil(t, resp) + require.Error(t, err) + + st, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, codes.Internal, st.Code()) + }) +} + +func TestBlockTrustMapping(t *testing.T) { + ctx := t.Context() + + t.Run("success - blocks existing trust", func(t *testing.T) { + existingTrust := trustv1.Trust_builder{ + TenantId: new("tenant-123"), + Blocked: new(false), + Oidc: oidcv1.OIDC_builder{ + Issuer: new("https://issuer.example.com"), + }.Build(), + }.Build() + repo := mocktrust.NewInMemRepository( + mocktrust.WithTrust(existingTrust), + ) + svc := newTrust(repo) + server := trustmapping.NewServer(svc) + + req := trustmappingv1.BlockTrustMappingRequest_builder{ + TenantId: new("tenant-123"), + }.Build() + + resp, err := server.BlockTrustMapping(ctx, req) + + require.NoError(t, err) + assert.NotNil(t, resp) + assert.True(t, resp.GetSuccess()) + assert.Empty(t, resp.GetMessage()) + }) + + t.Run("success - already blocked", func(t *testing.T) { + existingTrust := trustv1.Trust_builder{ + TenantId: new("tenant-123"), + Blocked: new(true), + Oidc: oidcv1.OIDC_builder{ + Issuer: new("https://issuer.example.com"), + }.Build(), + }.Build() + repo := mocktrust.NewInMemRepository( + mocktrust.WithTrust(existingTrust), + ) + svc := newTrust(repo) + server := trustmapping.NewServer(svc) + + req := trustmappingv1.BlockTrustMappingRequest_builder{ + TenantId: new("tenant-123"), + }.Build() + + resp, err := server.BlockTrustMapping(ctx, req) + + require.NoError(t, err) + assert.NotNil(t, resp) + assert.True(t, resp.GetSuccess()) + }) + + t.Run("not found - returns success", func(t *testing.T) { + repo := mocktrust.NewInMemRepository( + mocktrust.WithGetError(serviceerr.ErrNotFound), + ) + svc := newTrust(repo) + server := trustmapping.NewServer(svc) + + req := trustmappingv1.BlockTrustMappingRequest_builder{ + TenantId: new("tenant-123"), + }.Build() + + resp, err := server.BlockTrustMapping(ctx, req) + + require.NoError(t, err) + assert.NotNil(t, resp) + assert.True(t, resp.GetSuccess()) + }) + + t.Run("error - returns grpc error with message", func(t *testing.T) { + internalErr := errors.New("database error") + repo := mocktrust.NewInMemRepository( + mocktrust.WithGetError(internalErr), + ) + svc := newTrust(repo) + server := trustmapping.NewServer(svc) + + req := trustmappingv1.BlockTrustMappingRequest_builder{ + TenantId: new("tenant-123"), + }.Build() + + resp, err := server.BlockTrustMapping(ctx, req) + + require.Error(t, err) + assert.NotNil(t, resp) + require.NotNil(t, resp.GetMessage()) + assert.Contains(t, resp.GetMessage(), "database error") + + st, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, codes.Internal, st.Code()) + assert.Contains(t, st.Message(), "failed to block trust") + }) +} + +func TestRemoveTrustMapping(t *testing.T) { + ctx := t.Context() + + t.Run("success - removes existing trust", func(t *testing.T) { + existingTrust := trustv1.Trust_builder{ + TenantId: new("tenant-123"), + Oidc: oidcv1.OIDC_builder{ + Issuer: new("https://issuer.example.com"), + }.Build(), + }.Build() + repo := mocktrust.NewInMemRepository( + mocktrust.WithTrust(existingTrust), + ) + svc := newTrust(repo) + server := trustmapping.NewServer(svc) + + req := trustmappingv1.RemoveTrustMappingRequest_builder{ + TenantId: new("tenant-123"), + }.Build() + + resp, err := server.RemoveTrustMapping(ctx, req) + + require.NoError(t, err) + assert.NotNil(t, resp) + assert.True(t, resp.GetSuccess()) + assert.Empty(t, resp.GetMessage()) + }) + + t.Run("error - returns grpc error with message", func(t *testing.T) { + deleteErr := errors.New("delete failed") + repo := mocktrust.NewInMemRepository( + mocktrust.WithDeleteError(deleteErr), + ) + svc := newTrust(repo) + server := trustmapping.NewServer(svc) + + req := trustmappingv1.RemoveTrustMappingRequest_builder{ + TenantId: new("tenant-123"), + }.Build() + + resp, err := server.RemoveTrustMapping(ctx, req) + + require.Error(t, err) + assert.NotNil(t, resp) + require.NotNil(t, resp.GetMessage()) + assert.Contains(t, resp.GetMessage(), "delete failed") + + st, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, codes.Internal, st.Code()) + assert.Contains(t, st.Message(), "failed to remove trust") + }) + + t.Run("error - delete is indempotent", func(t *testing.T) { + repo := mocktrust.NewInMemRepository( + mocktrust.WithDeleteError(serviceerr.ErrNotFound), + ) + svc := newTrust(repo) + server := trustmapping.NewServer(svc) + + req := trustmappingv1.RemoveTrustMappingRequest_builder{ + TenantId: new("tenant-123"), + }.Build() + + resp, err := server.RemoveTrustMapping(ctx, req) + + require.NoError(t, err) + assert.NotNil(t, resp) + assert.True(t, resp.GetSuccess()) + assert.Empty(t, resp.GetMessage()) + }) +} + +func TestUnblockTrustMapping(t *testing.T) { + ctx := t.Context() + + t.Run("success - unblocks blocked trust", func(t *testing.T) { + existingTrust := trustv1.Trust_builder{ + TenantId: new("tenant-123"), + Blocked: new(true), + Oidc: oidcv1.OIDC_builder{ + Issuer: new("https://issuer.example.com"), + }.Build(), + }.Build() + repo := mocktrust.NewInMemRepository( + mocktrust.WithTrust(existingTrust), + ) + svc := newTrust(repo) + server := trustmapping.NewServer(svc) + + req := trustmappingv1.UnblockTrustMappingRequest_builder{ + TenantId: new("tenant-123"), + }.Build() + + resp, err := server.UnblockTrustMapping(ctx, req) + + require.NoError(t, err) + assert.NotNil(t, resp) + assert.True(t, resp.GetSuccess()) + assert.Empty(t, resp.GetMessage()) + }) + + t.Run("success - already unblocked", func(t *testing.T) { + existingTrust := trustv1.Trust_builder{ + TenantId: new("tenant-123"), + Blocked: new(false), + Oidc: oidcv1.OIDC_builder{ + Issuer: new("https://issuer.example.com"), + }.Build(), + }.Build() + repo := mocktrust.NewInMemRepository( + mocktrust.WithTrust(existingTrust), + ) + svc := newTrust(repo) + server := trustmapping.NewServer(svc) + + req := trustmappingv1.UnblockTrustMappingRequest_builder{ + TenantId: new("tenant-123"), + }.Build() + + resp, err := server.UnblockTrustMapping(ctx, req) + + require.NoError(t, err) + assert.NotNil(t, resp) + assert.True(t, resp.GetSuccess()) + }) + + t.Run("not found - returns success", func(t *testing.T) { + repo := mocktrust.NewInMemRepository( + mocktrust.WithGetError(serviceerr.ErrNotFound), + ) + svc := newTrust(repo) + server := trustmapping.NewServer(svc) + + req := trustmappingv1.UnblockTrustMappingRequest_builder{ + TenantId: new("tenant-123"), + }.Build() + + resp, err := server.UnblockTrustMapping(ctx, req) + + require.NoError(t, err) + assert.NotNil(t, resp) + assert.True(t, resp.GetSuccess()) + }) + + t.Run("error - returns grpc error with message", func(t *testing.T) { + internalErr := errors.New("update failed") + existingTrust := trustv1.Trust_builder{ + TenantId: new("tenant-123"), + Blocked: new(true), + Oidc: oidcv1.OIDC_builder{ + Issuer: new("https://issuer.example.com"), + }.Build(), + }.Build() + repo := mocktrust.NewInMemRepository( + mocktrust.WithTrust(existingTrust), + mocktrust.WithUpdateError(internalErr), + ) + svc := newTrust(repo) + server := trustmapping.NewServer(svc) + + req := trustmappingv1.UnblockTrustMappingRequest_builder{ + TenantId: new("tenant-123"), + }.Build() + + resp, err := server.UnblockTrustMapping(ctx, req) + + require.Error(t, err) + assert.NotNil(t, resp) + require.NotNil(t, resp.GetMessage()) + assert.Contains(t, resp.GetMessage(), "update failed") + + st, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, codes.Internal, st.Code()) + assert.Contains(t, st.Message(), "failed to unblock trust") + }) +} diff --git a/modules/oidctrust/export.go b/modules/oidctrust/export.go new file mode 100644 index 00000000..cb30b5cb --- /dev/null +++ b/modules/oidctrust/export.go @@ -0,0 +1,15 @@ +package oidctrust + +import ( + _ "unsafe" + + sessionmanager "github.com/openkcm/session-manager" +) + +//nolint:unused +//go:linkname newOIDCTrustModuleWithRepo +func newOIDCTrustModuleWithRepo(r TrustRepository) sessionmanager.Trust { + return &TrustModule{ + repository: r, + } +} diff --git a/modules/oidctrust/export_test.go b/modules/oidctrust/export_test.go new file mode 100644 index 00000000..c4cb26ac --- /dev/null +++ b/modules/oidctrust/export_test.go @@ -0,0 +1,7 @@ +package oidctrust + +func NewModule(repo TrustRepository) *TrustModule { + return &TrustModule{ + repository: repo, + } +} diff --git a/modules/oidctrust/internal/sql/export_test.go b/modules/oidctrust/internal/sql/export_test.go new file mode 100644 index 00000000..1ec027ca --- /dev/null +++ b/modules/oidctrust/internal/sql/export_test.go @@ -0,0 +1,15 @@ +package sqltrust + +import ( + "github.com/jackc/pgx/v5/pgtype" +) + +// PgTextOrNull exposes pgTextOrNull for testing. +func PgTextOrNull(s string) pgtype.Text { + return pgTextOrNull(s) +} + +// HandlePgError exposes handlePgError for testing. +func HandlePgError(err error) (error, bool) { + return handlePgError(err) +} diff --git a/internal/trust/trustsql/queries.sql b/modules/oidctrust/internal/sql/queries.sql similarity index 76% rename from internal/trust/trustsql/queries.sql rename to modules/oidctrust/internal/sql/queries.sql index 93e07ea6..a50560fe 100644 --- a/internal/trust/trustsql/queries.sql +++ b/modules/oidctrust/internal/sql/queries.sql @@ -1,22 +1,20 @@ --- name: GetOIDCMapping :one +-- name: GetTrust :one SELECT issuer, blocked, jwks_uri, audiences, - properties, client_id FROM trust WHERE tenant_id = sqlc.arg(tenant_id); --- name: CreateOIDCMapping :exec +-- name: CreateTrust :exec INSERT INTO trust ( tenant_id, blocked, issuer, jwks_uri, audiences, - properties, client_id) VALUES ( sqlc.arg(tenant_id), @@ -24,21 +22,19 @@ VALUES ( sqlc.arg(issuer), sqlc.arg(jwks_uri), COALESCE(sqlc.arg(audiences)::text[], '{}'::text[]), - sqlc.arg(properties), sqlc.arg(client_id)); --- name: DeleteOIDCMapping :execrows +-- name: DeleteTrust :execrows DELETE FROM trust WHERE tenant_id = sqlc.arg(tenant_id); --- name: UpdateOIDCMapping :execrows +-- name: UpdateTrust :execrows UPDATE trust SET blocked = sqlc.arg(blocked), issuer = sqlc.arg(issuer), jwks_uri = sqlc.arg(jwks_uri), audiences = COALESCE(sqlc.arg(audiences)::text[], '{}'::text[]), - properties = sqlc.arg(properties), client_id = sqlc.arg(client_id) WHERE tenant_id = sqlc.arg(tenant_id); diff --git a/internal/trust/trustsql/internal/queries/db.go b/modules/oidctrust/internal/sql/queries/db.go similarity index 96% rename from internal/trust/trustsql/internal/queries/db.go rename to modules/oidctrust/internal/sql/queries/db.go index c69f0c53..2b5c1c72 100644 --- a/internal/trust/trustsql/internal/queries/db.go +++ b/modules/oidctrust/internal/sql/queries/db.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.31.1 +// sqlc v1.30.0 package queries diff --git a/modules/oidctrust/internal/sql/queries/models.go b/modules/oidctrust/internal/sql/queries/models.go new file mode 100644 index 00000000..77a9f77e --- /dev/null +++ b/modules/oidctrust/internal/sql/queries/models.go @@ -0,0 +1,19 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 + +package queries + +import ( + "github.com/jackc/pgx/v5/pgtype" +) + +type Trust struct { + TenantID string `db:"tenant_id"` + Blocked bool `db:"blocked"` + Issuer string `db:"issuer"` + JwksUri string `db:"jwks_uri"` + Audiences []string `db:"audiences"` + CreatedAt pgtype.Timestamp `db:"created_at"` + ClientID pgtype.Text `db:"client_id"` +} diff --git a/modules/oidctrust/internal/sql/queries/queries.sql.go b/modules/oidctrust/internal/sql/queries/queries.sql.go new file mode 100644 index 00000000..ec9a845c --- /dev/null +++ b/modules/oidctrust/internal/sql/queries/queries.sql.go @@ -0,0 +1,131 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 +// source: queries.sql + +package queries + +import ( + "context" + + "github.com/jackc/pgx/v5/pgtype" +) + +const createTrust = `-- name: CreateTrust :exec +INSERT INTO trust ( + tenant_id, + blocked, + issuer, + jwks_uri, + audiences, + client_id) +VALUES ( + $1, + $2, + $3, + $4, + COALESCE($5::text[], '{}'::text[]), + $6) +` + +type CreateTrustParams struct { + TenantID string `db:"tenant_id"` + Blocked bool `db:"blocked"` + Issuer string `db:"issuer"` + JwksUri string `db:"jwks_uri"` + Audiences []string `db:"audiences"` + ClientID pgtype.Text `db:"client_id"` +} + +func (q *Queries) CreateTrust(ctx context.Context, arg CreateTrustParams) error { + _, err := q.db.Exec(ctx, createTrust, + arg.TenantID, + arg.Blocked, + arg.Issuer, + arg.JwksUri, + arg.Audiences, + arg.ClientID, + ) + return err +} + +const deleteTrust = `-- name: DeleteTrust :execrows +DELETE FROM trust +WHERE tenant_id = $1 +` + +func (q *Queries) DeleteTrust(ctx context.Context, tenantID string) (int64, error) { + result, err := q.db.Exec(ctx, deleteTrust, tenantID) + if err != nil { + return 0, err + } + return result.RowsAffected(), nil +} + +const getTrust = `-- name: GetTrust :one +SELECT + issuer, + blocked, + jwks_uri, + audiences, + client_id +FROM trust +WHERE tenant_id = $1 +` + +type GetTrustRow struct { + Issuer string `db:"issuer"` + Blocked bool `db:"blocked"` + JwksUri string `db:"jwks_uri"` + Audiences []string `db:"audiences"` + ClientID pgtype.Text `db:"client_id"` +} + +func (q *Queries) GetTrust(ctx context.Context, tenantID string) (GetTrustRow, error) { + row := q.db.QueryRow(ctx, getTrust, tenantID) + var i GetTrustRow + err := row.Scan( + &i.Issuer, + &i.Blocked, + &i.JwksUri, + &i.Audiences, + &i.ClientID, + ) + return i, err +} + +const updateTrust = `-- name: UpdateTrust :execrows +UPDATE trust +SET + blocked = $1, + issuer = $2, + jwks_uri = $3, + audiences = COALESCE($4::text[], '{}'::text[]), + client_id = $5 +WHERE + tenant_id = $6 +` + +type UpdateTrustParams struct { + Blocked bool `db:"blocked"` + Issuer string `db:"issuer"` + JwksUri string `db:"jwks_uri"` + Audiences []string `db:"audiences"` + ClientID pgtype.Text `db:"client_id"` + TenantID string `db:"tenant_id"` +} + +func (q *Queries) UpdateTrust(ctx context.Context, arg UpdateTrustParams) (int64, error) { + result, err := q.db.Exec(ctx, updateTrust, + arg.Blocked, + arg.Issuer, + arg.JwksUri, + arg.Audiences, + arg.ClientID, + arg.TenantID, + ) + if err != nil { + return 0, err + } + return result.RowsAffected(), nil +} diff --git a/modules/oidctrust/internal/sql/sql.go b/modules/oidctrust/internal/sql/sql.go new file mode 100644 index 00000000..1e98da60 --- /dev/null +++ b/modules/oidctrust/internal/sql/sql.go @@ -0,0 +1,154 @@ +package sqltrust + +import ( + "context" + "errors" + "fmt" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgtype" + "go.opentelemetry.io/otel" + + oidcv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/oidc/v1" + trustv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/v1" + + sessionmanager "github.com/openkcm/session-manager" + "github.com/openkcm/session-manager/modules/oidctrust/internal/sql/queries" + "github.com/openkcm/session-manager/pkg/serviceerr" +) + +type Repository struct { + queries *queries.Queries +} + +func NewRepository(db sessionmanager.Database) *Repository { + return &Repository{ + queries: queries.New(db), + } +} + +func (r *Repository) Get(ctx context.Context, tenantID string) (*trustv1.Trust, error) { + tracer := otel.GetTracerProvider() + ctx, span := tracer.Tracer("").Start(ctx, "get_trust_sql") + defer span.End() + + row, err := r.queries.GetTrust(ctx, tenantID) + if err != nil { + span.RecordError(err) + if errors.Is(err, pgx.ErrNoRows) { + return nil, serviceerr.ErrNotFound + } + + return nil, err + } + + trust := trustv1.Trust_builder{ + TenantId: &tenantID, + Blocked: &row.Blocked, + Oidc: oidcv1.OIDC_builder{ + Audiences: row.Audiences, + }.Build(), + }.Build() + + if row.Issuer != "" { + trust.GetOidc().SetIssuer(row.Issuer) + } + + if row.JwksUri != "" { + trust.GetOidc().SetJwksUri(row.JwksUri) + } + + if row.ClientID.Valid { + trust.GetOidc().SetClientId(row.ClientID.String) + } + + return trust, nil +} + +func (r *Repository) Create(ctx context.Context, trust *trustv1.Trust) error { + tracer := otel.GetTracerProvider() + ctx, span := tracer.Tracer("").Start(ctx, "create_trust_sql") + defer span.End() + + oidc := trust.GetOidc() + + if err := r.queries.CreateTrust(ctx, queries.CreateTrustParams{ + TenantID: trust.GetTenantId(), + Blocked: trust.GetBlocked(), + Issuer: oidc.GetIssuer(), + JwksUri: oidc.GetJwksUri(), + Audiences: oidc.GetAudiences(), + ClientID: pgTextOrNull(trust.GetOidc().GetClientId()), + }); err != nil { + span.RecordError(err) + if err, ok := handlePgError(err); ok { + return err + } + + return fmt.Errorf("inserting into trust: %w", err) + } + + return nil +} + +func (r *Repository) Delete(ctx context.Context, tenantID string) error { + tracer := otel.GetTracerProvider() + ctx, span := tracer.Tracer("").Start(ctx, "delete_trust_sql") + defer span.End() + + affected, err := r.queries.DeleteTrust(ctx, tenantID) + if err != nil { + span.RecordError(err) + return fmt.Errorf("executing sql query: %w", err) + } + + if affected == 0 { + return serviceerr.ErrNotFound + } + + return nil +} + +func (r *Repository) Update(ctx context.Context, trust *trustv1.Trust) error { + tracer := otel.GetTracerProvider() + ctx, span := tracer.Tracer("").Start(ctx, "update_trust_sql") + defer span.End() + + oidc := trust.GetOidc() + + affected, err := r.queries.UpdateTrust(ctx, queries.UpdateTrustParams{ + Blocked: trust.GetBlocked(), + Issuer: oidc.GetIssuer(), + JwksUri: oidc.GetJwksUri(), + Audiences: oidc.GetAudiences(), + ClientID: pgTextOrNull(oidc.GetClientId()), + TenantID: trust.GetTenantId(), + }) + if err != nil { + span.RecordError(err) + return fmt.Errorf("updating trust: %w", err) + } + + if affected == 0 { + return serviceerr.ErrNotFound + } + + return nil +} + +func pgTextOrNull(s string) pgtype.Text { + return pgtype.Text{ + String: s, + Valid: s != "", + } +} + +func handlePgError(err error) (error, bool) { + var pgErr *pgconn.PgError + if errors.As(err, &pgErr) && pgErr.Code == "23505" { + return serviceerr.ErrConflict, true + } + + return err, false +} diff --git a/modules/oidctrust/internal/sql/sql_test.go b/modules/oidctrust/internal/sql/sql_test.go new file mode 100644 index 00000000..45650388 --- /dev/null +++ b/modules/oidctrust/internal/sql/sql_test.go @@ -0,0 +1,305 @@ +package sqltrust_test + +import ( + "context" + "database/sql" + "errors" + "fmt" + "os" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxpool" + "github.com/jackc/pgx/v5/stdlib" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/testing/protocmp" + + oidcv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/oidc/v1" + trustv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/v1" + + sessionmanager "github.com/openkcm/session-manager" + "github.com/openkcm/session-manager/internal/dbtest/postgrestest" + sqltrust "github.com/openkcm/session-manager/modules/oidctrust/internal/sql" + "github.com/openkcm/session-manager/pkg/serviceerr" +) + +var dbPool sessionmanager.Database + +type pooldb struct { + *pgxpool.Pool +} + +func (p *pooldb) STDAdapter() *sql.DB { + return stdlib.OpenDBFromPool(p.Pool) +} + +func TestMain(m *testing.M) { + ctx := context.Background() + + pool, _, terminate := postgrestest.Start(ctx) + defer terminate(ctx) + + dbPool = &pooldb{pool} + + code := m.Run() + os.Exit(code) +} + +func TestRepository_Get(t *testing.T) { + tests := []struct { + name string + tenantID string + wantTrust *trustv1.Trust + assertErr assert.ErrorAssertionFunc + }{ + { + name: "Success", + tenantID: "tenant1-id", + wantTrust: trustv1.Trust_builder{TenantId: new("tenant1-id"), Blocked: new(false), Oidc: oidcv1.OIDC_builder{Issuer: new("url-one"), Audiences: make([]string, 0)}.Build()}.Build(), + assertErr: assert.NoError, + }, + { + name: "Error does not exist", + tenantID: "does-not-exist", + assertErr: assert.Error, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := sqltrust.NewRepository(dbPool) + + gotTrust, err := r.Get(t.Context(), tt.tenantID) + if !tt.assertErr(t, err, fmt.Sprintf("Repository.Get() error %v", err)) || err != nil { + assert.Zerof(t, gotTrust, "Repository.Get() extected zero value if an error is returned, got %v", gotTrust) + return + } + + if diff := cmp.Diff(tt.wantTrust, gotTrust, protocmp.Transform()); diff != "" { + t.Fatalf("trust not equal:\n%s", diff) + } + }) + } +} + +func TestRepository_Create(t *testing.T) { + tests := []struct { + name string + trust *trustv1.Trust + assertErr assert.ErrorAssertionFunc + }{ + { + name: "Create succeeds", + trust: trustv1.Trust_builder{TenantId: new("tenant-id-create-success"), Blocked: new(false), Oidc: oidcv1.OIDC_builder{Issuer: new("http://oidc-success.example.com"), JwksUri: new("jwks.example.com"), Audiences: []string{"cmk.example.com"}}.Build()}.Build(), + assertErr: assert.NoError, + }, + { + name: "Duplicate", + trust: trustv1.Trust_builder{TenantId: new("tenant1-id"), Blocked: new(false), Oidc: oidcv1.OIDC_builder{Issuer: new("url-one"), JwksUri: new("jwks.example.com"), Audiences: []string{"cmk.example.com"}}.Build()}.Build(), + assertErr: assert.Error, + }, + { + name: "Create without JWKSURI and Audiences succeeds", + trust: trustv1.Trust_builder{TenantId: new("tenant-id-create-without-jwks-aud-success"), Blocked: new(false), Oidc: oidcv1.OIDC_builder{Issuer: new("http://oidc-success-2.example.com"), Audiences: []string{}}.Build()}.Build(), + assertErr: assert.NoError, + }, + { + name: "Create without JWKSURI succeeds", + trust: trustv1.Trust_builder{TenantId: new("tenant-id-create-without-jwks-success"), Blocked: new(false), Oidc: oidcv1.OIDC_builder{Issuer: new("http://oidc-success-3.example.com"), Audiences: []string{"cmk.example.com"}}.Build()}.Build(), + assertErr: assert.NoError, + }, + { + name: "Create without Audiences succeeds", + trust: trustv1.Trust_builder{TenantId: new("tenant-id-create-without-aud-success"), Blocked: new(false), Oidc: oidcv1.OIDC_builder{Issuer: new("http://oidc-success-4.example.com"), JwksUri: new("jwks.example.com"), Audiences: []string{}}.Build()}.Build(), + assertErr: assert.NoError, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Given + r := sqltrust.NewRepository(dbPool) + + // When + err := r.Create(t.Context(), tt.trust) + if !tt.assertErr(t, err, fmt.Sprintf("Repository.Create() error %v", err)) || err != nil { + return + } + + // Then + gotCreated, err := r.Get(t.Context(), tt.trust.GetTenantId()) + require.NoError(t, err) + + if diff := cmp.Diff(tt.trust, gotCreated, protocmp.Transform()); diff != "" { + t.Fatalf("Unexpected trust in the database (-want, +got):\n%s", diff) + } + }) + } +} + +func TestRepository_Delete(t *testing.T) { + const tenantID = "tenant-id-delete-success" + trust := trustv1.Trust_builder{TenantId: new(tenantID), Blocked: new(false), Oidc: oidcv1.OIDC_builder{Issuer: new("http://oidc-to-delete.example.com"), JwksUri: new("jwks.example.com"), Audiences: []string{"cmk.example.com"}}.Build()}.Build() + r := sqltrust.NewRepository(dbPool) + err := r.Create(t.Context(), trust) + require.NoError(t, err, "Inserting test data") + + tests := []struct { + name string + tenantID string + assertErr assert.ErrorAssertionFunc + }{ + { + name: "Delete tenant", + tenantID: tenantID, + assertErr: assert.NoError, + }, + { + name: "Error does not exist", + tenantID: "does-not-exist", + assertErr: assert.Error, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := r.Delete(t.Context(), tt.tenantID) + if !tt.assertErr(t, err, fmt.Sprintf("Repository.Delete() error %v", err)) || err != nil { + return + } + + gotTrust, err := r.Get(t.Context(), tt.tenantID) + if !errors.Is(err, serviceerr.ErrNotFound) { + t.Error("The trust is expected to be deleted") + } + assert.Zero(t, gotTrust, "The trust is expected to be deleted, instead a value is returned") + }) + } +} + +func TestRepository_Update(t *testing.T) { + const tenantID = "tenant-id-update-success" + trust := trustv1.Trust_builder{TenantId: new(tenantID), Blocked: new(false), Oidc: oidcv1.OIDC_builder{Issuer: new("http://oidc-to-update.example.com"), JwksUri: new("jwks.example.com"), Audiences: []string{"cmk.example.com"}}.Build()}.Build() + r := sqltrust.NewRepository(dbPool) + err := r.Create(t.Context(), trust) + require.NoError(t, err, "Inserting test data") + + tests := []struct { + name string + trust *trustv1.Trust + assertErr assert.ErrorAssertionFunc + }{ + { + name: "Update succeeds", + trust: trustv1.Trust_builder{TenantId: new(tenantID), Blocked: new(true), Oidc: oidcv1.OIDC_builder{Issuer: new(trust.GetOidc().GetIssuer()), JwksUri: new("jwks-updated.example.com"), Audiences: trust.GetOidc().GetAudiences()}.Build()}.Build(), + assertErr: assert.NoError, + }, + { + name: "Does not exist", + trust: trustv1.Trust_builder{TenantId: new("does-not-exist"), Blocked: new(true), Oidc: oidcv1.OIDC_builder{Issuer: new("does-not-exist"), JwksUri: new("jwks-updated.example.com"), Audiences: trust.GetOidc().GetAudiences()}.Build()}.Build(), + assertErr: assert.Error, + }, + { + name: "Update without JWKSURI and Audiences succeeds", + trust: trustv1.Trust_builder{TenantId: new(tenantID), Blocked: new(true), Oidc: oidcv1.OIDC_builder{Issuer: new(trust.GetOidc().GetIssuer()), Audiences: []string{}}.Build()}.Build(), + assertErr: assert.NoError, + }, + { + name: "Update without JWKSURI succeeds", + trust: trustv1.Trust_builder{TenantId: new(tenantID), Blocked: new(true), Oidc: oidcv1.OIDC_builder{Issuer: new(trust.GetOidc().GetIssuer()), Audiences: trust.GetOidc().GetAudiences()}.Build()}.Build(), + assertErr: assert.NoError, + }, + { + name: "Update without Audiences succeeds", + trust: trustv1.Trust_builder{TenantId: new(tenantID), Blocked: new(true), Oidc: oidcv1.OIDC_builder{Issuer: new(trust.GetOidc().GetIssuer()), JwksUri: new("jwks-updated.example.com"), Audiences: []string{}}.Build()}.Build(), + assertErr: assert.NoError, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := r.Update(t.Context(), tt.trust) + if !tt.assertErr(t, err, fmt.Sprintf("Repository.Update() error %v", err)) || err != nil { + return + } + + gotTrust, err := r.Get(t.Context(), tt.trust.GetTenantId()) + require.NoError(t, err) + + if diff := cmp.Diff(tt.trust, gotTrust, protocmp.Transform()); diff != "" { + t.Fatalf("Unexpected trust in the database (-want, +got):\n%s", diff) + } + }) + } +} + +func TestPgTextOrNull(t *testing.T) { + tests := []struct { + name string + input string + want pgtype.Text + }{ + { + name: "empty string returns invalid (null)", + input: "", + want: pgtype.Text{String: "", Valid: false}, + }, + { + name: "non-empty string returns valid text", + input: "hello", + want: pgtype.Text{String: "hello", Valid: true}, + }, + { + name: "whitespace-only string returns valid text", + input: " ", + want: pgtype.Text{String: " ", Valid: true}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := sqltrust.PgTextOrNull(tt.input) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestHandlePgError(t *testing.T) { + otherPgErr := &pgconn.PgError{Code: "42P01"} // undefined_table + sentinel := errors.New("some other error") + + tests := []struct { + name string + err error + wantErr error + wantHandled bool + }{ + { + name: "duplicate key violation (23505) returns ErrConflict", + err: &pgconn.PgError{Code: "23505"}, + wantErr: serviceerr.ErrConflict, + wantHandled: true, + }, + { + name: "other pg error code returns original error", + err: otherPgErr, + wantErr: otherPgErr, + wantHandled: false, + }, + { + name: "non-pg error returns original error", + err: sentinel, + wantErr: sentinel, + wantHandled: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, handled := sqltrust.HandlePgError(tt.err) + if handled != tt.wantHandled { + t.Errorf("handled = %v, want %v", handled, tt.wantHandled) + } + if !errors.Is(got, tt.wantErr) { + t.Errorf("err = %v, want %v", got, tt.wantErr) + } + }) + } +} diff --git a/sql/00001_init.sql b/modules/oidctrust/migrations/00001_init.sql similarity index 100% rename from sql/00001_init.sql rename to modules/oidctrust/migrations/00001_init.sql diff --git a/sql/00002_add_properties_to_providers.sql b/modules/oidctrust/migrations/00002_add_properties_to_providers.sql similarity index 100% rename from sql/00002_add_properties_to_providers.sql rename to modules/oidctrust/migrations/00002_add_properties_to_providers.sql diff --git a/sql/00003_tenant_trust.sql b/modules/oidctrust/migrations/00003_tenant_trust.sql similarity index 100% rename from sql/00003_tenant_trust.sql rename to modules/oidctrust/migrations/00003_tenant_trust.sql diff --git a/sql/00004_single_tenant.sql b/modules/oidctrust/migrations/00004_single_tenant.sql similarity index 100% rename from sql/00004_single_tenant.sql rename to modules/oidctrust/migrations/00004_single_tenant.sql diff --git a/modules/oidctrust/migrations/00005_remove_properties.sql b/modules/oidctrust/migrations/00005_remove_properties.sql new file mode 100644 index 00000000..76d2041f --- /dev/null +++ b/modules/oidctrust/migrations/00005_remove_properties.sql @@ -0,0 +1,11 @@ +-- +goose Up +-- +goose StatementBegin +ALTER TABLE trust + DROP COLUMN properties; +-- +goose StatementEnd + +-- +goose Down +-- +goose StatementBegin +ALTER TABLE trust + ADD COLUMN properties JSONB NOT NULL; +-- +goose StatementEnd diff --git a/modules/oidctrust/migrations/migration.go b/modules/oidctrust/migrations/migration.go new file mode 100644 index 00000000..e378c782 --- /dev/null +++ b/modules/oidctrust/migrations/migration.go @@ -0,0 +1,62 @@ +package migrations + +import ( + "context" + "embed" + "fmt" + + "github.com/pressly/goose/v3" + + sessionmanager "github.com/openkcm/session-manager" +) + +//go:embed *.sql +var FS embed.FS + +const moduleID = "trust.migration.module.oidc" + +func newModule() sessionmanager.Module { + return new(MigrationModule) +} + +func init() { + sessionmanager.RegisterModule(new(MigrationModule)) +} + +type MigrationModule struct { + DBModule string `yaml:"dbModule" default:"database.module.pgxpool"` + + db sessionmanager.Database +} + +func (m *MigrationModule) Migrate(ctx context.Context) error { + goose.SetBaseFS(FS) + + if err := goose.SetDialect("postgres"); err != nil { + return fmt.Errorf("setting goose dialect: %w", err) + } + + if err := goose.UpContext(ctx, m.db.STDAdapter(), "."); err != nil { + return fmt.Errorf("applying migrations: %w", err) + } + + return nil +} + +func (m *MigrationModule) Module() sessionmanager.ModuleInfo { + return sessionmanager.ModuleInfo{ + ID: moduleID, + New: newModule, + } +} + +func (m *MigrationModule) Provision(ctx *sessionmanager.Context) error { + mod, err := ctx.GetModule(m.DBModule) + if err != nil { + return fmt.Errorf("getting postgres module: %w", err) + } + + //nolint:forcetypeassert + m.db = mod.(sessionmanager.Database) + return nil +} diff --git a/internal/trust/trustmock/repository.go b/modules/oidctrust/mocks/repository.go similarity index 53% rename from internal/trust/trustmock/repository.go rename to modules/oidctrust/mocks/repository.go index d64d6cca..e9774385 100644 --- a/internal/trust/trustmock/repository.go +++ b/modules/oidctrust/mocks/repository.go @@ -1,22 +1,24 @@ -package trustmock +package mocktrust import ( "context" - "github.com/openkcm/session-manager/internal/serviceerr" - "github.com/openkcm/session-manager/internal/trust" + trustv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/v1" + + "github.com/openkcm/session-manager/modules/oidctrust" + "github.com/openkcm/session-manager/pkg/serviceerr" ) type RepositoryOption func(*Repository) type Repository struct { - tenantTrust map[string]trust.OIDCMapping + tenantTrust map[string]*trustv1.Trust getErr, createErr, deleteErr, updateErr error } -func WithTrust(tenantID string, mapping trust.OIDCMapping) RepositoryOption { - return func(r *Repository) { r.tenantTrust[tenantID] = mapping } +func WithTrust(trust *trustv1.Trust) RepositoryOption { + return func(r *Repository) { r.tenantTrust[trust.GetTenantId()] = trust } } func WithGetError(err error) RepositoryOption { return func(r *Repository) { r.getErr = err } @@ -31,11 +33,11 @@ func WithUpdateError(err error) RepositoryOption { return func(r *Repository) { r.updateErr = err } } -var _ = trust.OIDCMappingRepository(&Repository{}) +var _ oidctrust.TrustRepository = (*Repository)(nil) func NewInMemRepository(opts ...RepositoryOption) *Repository { r := &Repository{ - tenantTrust: make(map[string]trust.OIDCMapping), + tenantTrust: make(map[string]*trustv1.Trust), } for _, opt := range opts { if opt != nil { @@ -46,30 +48,30 @@ func NewInMemRepository(opts ...RepositoryOption) *Repository { } // TAdd is a helper method for tests to add a trust relationship. -func (r *Repository) TAdd(tenantID string, mapping trust.OIDCMapping) { - r.tenantTrust[tenantID] = mapping +func (r *Repository) TAdd(trust *trustv1.Trust) { + r.tenantTrust[trust.GetTenantId()] = trust } // TGet is a helper method for tests to get a trust relationship. -func (r *Repository) TGet(tenantID string) trust.OIDCMapping { +func (r *Repository) TGet(tenantID string) *trustv1.Trust { return r.tenantTrust[tenantID] } -func (r *Repository) Get(_ context.Context, tenantID string) (trust.OIDCMapping, error) { +func (r *Repository) Get(_ context.Context, tenantID string) (*trustv1.Trust, error) { if r.getErr != nil { - return trust.OIDCMapping{}, r.getErr + return nil, r.getErr } - if mapping, ok := r.tenantTrust[tenantID]; ok { - return mapping, nil + if trust, ok := r.tenantTrust[tenantID]; ok { + return trust, nil } - return trust.OIDCMapping{}, serviceerr.ErrNotFound + return nil, serviceerr.ErrNotFound } -func (r *Repository) Create(_ context.Context, tenantID string, mapping trust.OIDCMapping) error { +func (r *Repository) Create(_ context.Context, trust *trustv1.Trust) error { if r.createErr != nil { return r.createErr } - r.tenantTrust[tenantID] = mapping + r.tenantTrust[trust.GetTenantId()] = trust return nil } @@ -84,10 +86,10 @@ func (r *Repository) Delete(_ context.Context, tenantID string) error { return nil } -func (r *Repository) Update(_ context.Context, tenantID string, mapping trust.OIDCMapping) error { +func (r *Repository) Update(_ context.Context, trust *trustv1.Trust) error { if r.updateErr != nil { return r.updateErr } - r.tenantTrust[tenantID] = mapping + r.tenantTrust[trust.GetTenantId()] = trust return nil } diff --git a/modules/oidctrust/module.go b/modules/oidctrust/module.go new file mode 100644 index 00000000..cc209b13 --- /dev/null +++ b/modules/oidctrust/module.go @@ -0,0 +1,48 @@ +package oidctrust + +import ( + "fmt" + + sessionmanager "github.com/openkcm/session-manager" + sqltrust "github.com/openkcm/session-manager/modules/oidctrust/internal/sql" +) + +const moduleID = "trust.module.oidc" + +func newModule() sessionmanager.Module { + return new(TrustModule) +} + +func init() { + sessionmanager.RegisterModule(new(TrustModule)) +} + +// TrustModule is a module that implements sessionmanager.Trust interface. It's using a database provided by the +// [dbModule] module which implements sessionmanager.DBModule. +type TrustModule struct { + DBModule string `yaml:"dbModule" default:"database.module.pgxpool"` + + repository TrustRepository +} + +func (m *TrustModule) Module() sessionmanager.ModuleInfo { + return sessionmanager.ModuleInfo{ + ID: moduleID, + New: newModule, + } +} + +func (m *TrustModule) Provision(ctx *sessionmanager.Context) error { + dbMod, err := ctx.GetModule(m.DBModule) + if err != nil { + return fmt.Errorf("getting db module: %w", err) + } + + //nolint:forcetypeassert + db := dbMod.(sessionmanager.Database) + m.repository = sqltrust.NewRepository(db) + + return nil +} + +var _ sessionmanager.Trust = (*TrustModule)(nil) diff --git a/modules/oidctrust/repository.go b/modules/oidctrust/repository.go new file mode 100644 index 00000000..ae03a668 --- /dev/null +++ b/modules/oidctrust/repository.go @@ -0,0 +1,15 @@ +package oidctrust + +import ( + "context" + + trustv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/v1" +) + +// TrustRepository allows to read OIDC trust data for a tenant stored in the context. +type TrustRepository interface { + Get(ctx context.Context, tenantID string) (*trustv1.Trust, error) + Create(ctx context.Context, trust *trustv1.Trust) error + Delete(ctx context.Context, tenantID string) error + Update(ctx context.Context, trust *trustv1.Trust) error +} diff --git a/internal/trust/repository_test.go b/modules/oidctrust/repository_test.go similarity index 57% rename from internal/trust/repository_test.go rename to modules/oidctrust/repository_test.go index 549e42d3..0155db24 100644 --- a/internal/trust/repository_test.go +++ b/modules/oidctrust/repository_test.go @@ -1,4 +1,4 @@ -package trust_test +package oidctrust_test import ( "context" @@ -6,14 +6,15 @@ import ( "fmt" "github.com/jackc/pgx/v5/pgxpool" + "github.com/jackc/pgx/v5/stdlib" "github.com/pressly/goose/v3" "github.com/testcontainers/testcontainers-go/modules/postgres" - _ "github.com/jackc/pgx/v5/stdlib" + trustv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/v1" - "github.com/openkcm/session-manager/internal/trust" - "github.com/openkcm/session-manager/internal/trust/trustsql" - migrations "github.com/openkcm/session-manager/sql" + "github.com/openkcm/session-manager/modules/oidctrust" + sqltrust "github.com/openkcm/session-manager/modules/oidctrust/internal/sql" + migrations "github.com/openkcm/session-manager/modules/oidctrust/migrations" ) const ( @@ -25,28 +26,28 @@ const ( ) type RepoWrapper struct { - Repo trust.OIDCMappingRepository - MockGet func(ctx context.Context, tenantID string) (trust.OIDCMapping, error) - MockCreate func(ctx context.Context, tenantID string, mapping trust.OIDCMapping) error - MockUpdate func(ctx context.Context, tenantID string, mapping trust.OIDCMapping) error + Repo oidctrust.TrustRepository + MockGet func(ctx context.Context, tenantID string) (*trustv1.Trust, error) + MockCreate func(ctx context.Context, trust *trustv1.Trust) error MockDelete func(ctx context.Context, tenantID string) error + MockUpdate func(ctx context.Context, trust *trustv1.Trust) error } -var _ trust.OIDCMappingRepository = &RepoWrapper{} +var _ oidctrust.TrustRepository = &RepoWrapper{} -// Create implements oidc.OIDCMappingRepository. -func (m *RepoWrapper) Create(ctx context.Context, tenantID string, mapping trust.OIDCMapping) error { +// Create implements oidc.OIDCTrustRepository. +func (m *RepoWrapper) Create(ctx context.Context, trust *trustv1.Trust) error { if m.MockCreate != nil { - err := m.MockCreate(ctx, tenantID, mapping) + err := m.MockCreate(ctx, trust) if err != nil { return err } } - return m.Repo.Create(ctx, tenantID, mapping) + return m.Repo.Create(ctx, trust) } -// Delete implements oidc.OIDCMappingRepository. +// Delete implements oidc.OIDCTrustRepository. func (m *RepoWrapper) Delete(ctx context.Context, tenantID string) error { if m.MockDelete != nil { err := m.MockDelete(ctx, tenantID) @@ -58,29 +59,29 @@ func (m *RepoWrapper) Delete(ctx context.Context, tenantID string) error { return m.Repo.Delete(ctx, tenantID) } -// Get implements oidc.OIDCMappingRepository. -func (m *RepoWrapper) Get(ctx context.Context, tenantID string) (trust.OIDCMapping, error) { +// Get implements oidc.OIDCTrustRepository. +func (m *RepoWrapper) Get(ctx context.Context, tenantID string) (*trustv1.Trust, error) { if m.MockGet != nil { _, err := m.MockGet(ctx, tenantID) if err != nil { - return trust.OIDCMapping{}, err + return nil, err } } return m.Repo.Get(ctx, tenantID) } -// Update implements oidc.OIDCMappingRepository. -func (m *RepoWrapper) Update(ctx context.Context, tenantID string, mapping trust.OIDCMapping) error { +// Update implements oidc.OIDCTrustRepository. +func (m *RepoWrapper) Update(ctx context.Context, trust *trustv1.Trust) error { if m.MockUpdate != nil { - err := m.MockUpdate(ctx, tenantID, mapping) + err := m.MockUpdate(ctx, trust) if err != nil { return err } } - return m.Repo.Update(ctx, tenantID, mapping) + return m.Repo.Update(ctx, trust) } -func createRepo(ctx context.Context) (trust.OIDCMappingRepository, error) { +func createRepo(ctx context.Context) (oidctrust.TrustRepository, error) { pgContainer, err := postgres.Run( ctx, "postgres:17-alpine", @@ -110,7 +111,7 @@ func createRepo(ctx context.Context) (trust.OIDCMappingRepository, error) { return nil, err } - return trustsql.NewRepository(dbPool), nil + return sqltrust.NewRepository(&dbWrapper{dbPool}), nil } func migrateDB(ctx context.Context, connStr string) error { @@ -133,3 +134,11 @@ func migrateDB(ctx context.Context, connStr string) error { } return nil } + +type dbWrapper struct { + *pgxpool.Pool +} + +func (w *dbWrapper) STDAdapter() *sql.DB { + return stdlib.OpenDBFromPool(w.Pool) +} diff --git a/modules/oidctrust/trust.go b/modules/oidctrust/trust.go new file mode 100644 index 00000000..fa0d30de --- /dev/null +++ b/modules/oidctrust/trust.go @@ -0,0 +1,107 @@ +package oidctrust + +import ( + "context" + "errors" + "fmt" + + oidcv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/oidc/v1" + trustv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/v1" + + "github.com/openkcm/session-manager/pkg/serviceerr" +) + +// Apply implements [sessionmanager.Trust]. +func (m *TrustModule) Apply(ctx context.Context, trust *trustv1.Trust) error { + if _, err := m.repository.Get(ctx, trust.GetTenantId()); err != nil { + err = m.repository.Create(ctx, trust) + if err != nil { + return fmt.Errorf("creating trust for tenant: %w", err) + } + } else { + err = m.repository.Update(ctx, trust) + if err != nil { + return fmt.Errorf("updating trust for tenant: %w", err) + } + } + + return nil +} + +// Block implements [sessionmanager.Trust]. +func (m *TrustModule) Block(ctx context.Context, tenantID string) error { + trust, err := m.repository.Get(ctx, tenantID) + if err != nil { + if errors.Is(err, serviceerr.ErrNotFound) { + return nil + } + return fmt.Errorf("getting trust for tenant: %w", err) + } + if trust.GetBlocked() { + return nil + } + + trust.SetBlocked(true) + if err = m.repository.Update(ctx, trust); err != nil { + if errors.Is(err, serviceerr.ErrNotFound) { + return nil + } + return fmt.Errorf("updating trust for blocking tenant: %w", err) + } + return nil +} + +// Remove implements [sessionmanager.Trust]. +func (m *TrustModule) Remove(ctx context.Context, tenantID string) error { + err := m.repository.Delete(ctx, tenantID) + if err != nil { + return fmt.Errorf("deleting trust for tenant: %w", err) + } + + return nil +} + +// Unblock implements [sessionmanager.Trust]. +func (m *TrustModule) Unblock(ctx context.Context, tenantID string) error { + trust, err := m.repository.Get(ctx, tenantID) + if err != nil { + if errors.Is(err, serviceerr.ErrNotFound) { + return nil + } + return fmt.Errorf("getting trust for tenant: %w", err) + } + if !trust.GetBlocked() { + return nil + } + trust.SetBlocked(false) + if err = m.repository.Update(ctx, trust); err != nil { + if errors.Is(err, serviceerr.ErrNotFound) { + return nil + } + return fmt.Errorf("updating trust for unblocking tenant: %w", err) + } + return nil +} + +// Get implements [sessionmanager.Trust]. +func (m *TrustModule) Get(ctx context.Context, tenantID string) (*trustv1.Trust, error) { + trust, err := m.repository.Get(ctx, tenantID) + if err != nil { + return nil, fmt.Errorf("getting trust from repository: %w", err) + } + + m.resolveExtensions(trust) + return trust, nil +} + +// resolveExtensions sets optional extensions to the Trust message and its details if configured. +func (m *TrustModule) resolveExtensions(trust *trustv1.Trust) { + switch trust.WhichDetails() { + case trustv1.Trust_Oidc_case: + m.resolveOIDCExtensions(trust.GetOidc()) + } +} + +// resolveOIDCExtensions sets optional extensions to the ODIC message if configured. +func (m *TrustModule) resolveOIDCExtensions(oidc *oidcv1.OIDC) { +} diff --git a/modules/oidctrust/trust_test.go b/modules/oidctrust/trust_test.go new file mode 100644 index 00000000..2c67c69f --- /dev/null +++ b/modules/oidctrust/trust_test.go @@ -0,0 +1,677 @@ +package oidctrust_test + +import ( + "context" + "errors" + "os" + "testing" + + "github.com/gofrs/uuid/v5" + "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/testing/protocmp" + + oidcv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/oidc/v1" + trustv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/v1" + slogctx "github.com/veqryn/slog-context" + + "github.com/openkcm/session-manager/modules/oidctrust" + mocktrust "github.com/openkcm/session-manager/modules/oidctrust/mocks" +) + +var repo oidctrust.TrustRepository + +const ( + requestURI = "http://cmk.example.com/ui" + jwksURI = "http://jwks.example.com" +) + +func TestMain(m *testing.M) { + ctx := context.Background() + r, err := createRepo(ctx) + if err != nil { + slogctx.Error(ctx, "error while creating repo", "error", err) + } + + repo = r + + code := m.Run() + os.Exit(code) +} + +func TestService_Apply(t *testing.T) { + ctx := t.Context() + + t.Run("success if", func(t *testing.T) { + t.Run("the trust does not exist", func(t *testing.T) { + expTenantID := uuid.Must(uuid.NewV4()).String() + expTrust := trustv1.Trust_builder{ + TenantId: new(expTenantID), + Blocked: new(false), + Oidc: oidcv1.OIDC_builder{ + Issuer: new(uuid.Must(uuid.NewV4()).String()), + JwksUri: new(jwksURI), + Audiences: []string{requestURI}, + }.Build(), + }.Build() + + wrapper := &RepoWrapper{Repo: repo} + subj := oidctrust.NewModule(wrapper) + + err := subj.Apply(ctx, expTrust) + assert.NoError(t, err) + + actTrust, err := wrapper.Repo.Get(ctx, expTenantID) + assert.NoError(t, err) + if diff := cmp.Diff(expTrust, actTrust, protocmp.Transform()); diff != "" { + t.Fatalf("trust not equal:\n%s", diff) + } + }) + + t.Run("the trust exists", func(t *testing.T) { + expTenantID := uuid.Must(uuid.NewV4()).String() + expTrust := trustv1.Trust_builder{ + TenantId: new(expTenantID), + Blocked: new(false), + Oidc: oidcv1.OIDC_builder{ + Issuer: new(uuid.Must(uuid.NewV4()).String()), + JwksUri: new(jwksURI), + Audiences: []string{requestURI}, + }.Build(), + }.Build() + + wrapper := &RepoWrapper{Repo: repo} + subj := oidctrust.NewModule(wrapper) + + err := subj.Apply(ctx, expTrust) + assert.NoError(t, err) + + expUpdatedTrust := trustv1.Trust_builder{ + TenantId: new(expTenantID), + Blocked: new(false), + Oidc: oidcv1.OIDC_builder{ + Issuer: new(expTrust.GetOidc().GetIssuer()), + JwksUri: new("http://updated-jwks.example.com"), + Audiences: []string{requestURI, "http://new-aud.example.com"}, + }.Build(), + }.Build() + + err = subj.Apply(ctx, expUpdatedTrust) + assert.NoError(t, err) + + actTrust, err := wrapper.Repo.Get(ctx, expTenantID) + assert.NoError(t, err) + if diff := cmp.Diff(expUpdatedTrust, actTrust, protocmp.Transform()); diff != "" { + t.Fatalf("trust not equal:\n%s", diff) + } + }) + }) + + t.Run("should return error if", func(t *testing.T) { + t.Run("Create returns an error", func(t *testing.T) { + expTenantID := uuid.Must(uuid.NewV4()).String() + expTrust := trustv1.Trust_builder{ + TenantId: new(expTenantID), + Oidc: oidcv1.OIDC_builder{ + Issuer: new(uuid.Must(uuid.NewV4()).String()), + JwksUri: new(jwksURI), + Audiences: []string{requestURI}, + }.Build(), + }.Build() + + wrapper := &RepoWrapper{Repo: repo} + noOfCalls := 0 + wrapper.MockCreate = func(ctx context.Context, trust *trustv1.Trust) error { + assert.Equal(t, expTenantID, trust.GetTenantId()) + assert.Equal(t, expTrust, trust) + noOfCalls++ + return assert.AnError + } + + subj := oidctrust.NewModule(wrapper) + err := subj.Apply(ctx, expTrust) + + assert.ErrorIs(t, err, assert.AnError) + assert.Equal(t, 1, noOfCalls) + }) + + t.Run("Update returns an error", func(t *testing.T) { + expTenantID := uuid.Must(uuid.NewV4()).String() + expTrust := trustv1.Trust_builder{ + TenantId: new(expTenantID), + Oidc: oidcv1.OIDC_builder{ + Issuer: new(uuid.Must(uuid.NewV4()).String()), + JwksUri: new(jwksURI), + Audiences: []string{requestURI}, + }.Build(), + }.Build() + + wrapper := &RepoWrapper{Repo: repo} + noOfCalls := 0 + wrapper.MockUpdate = func(ctx context.Context, trust *trustv1.Trust) error { + assert.Equal(t, expTenantID, trust.GetTenantId()) + assert.Equal(t, expTrust, trust) + noOfCalls++ + return assert.AnError + } + subj := oidctrust.NewModule(wrapper) + + err := subj.Apply(ctx, expTrust) + assert.NoError(t, err) + err = subj.Apply(ctx, expTrust) + + assert.ErrorIs(t, err, assert.AnError) + assert.Equal(t, 1, noOfCalls) + }) + }) +} + +func TestService_Block(t *testing.T) { + ctx := t.Context() + + t.Run("success if ", func(t *testing.T) { + t.Run("the trust is unblocked", func(t *testing.T) { + // given + expTenantID := uuid.Must(uuid.NewV4()).String() + expUnblockedTrust := trustv1.Trust_builder{ + TenantId: new(expTenantID), + Blocked: new(false), + Oidc: oidcv1.OIDC_builder{ + Issuer: new(uuid.Must(uuid.NewV4()).String()), + JwksUri: new(jwksURI), + Audiences: []string{requestURI}, + }.Build(), + }.Build() + + wrapper := &RepoWrapper{Repo: repo} + err := wrapper.Repo.Create(ctx, expUnblockedTrust) + require.NoError(t, err) + subj := oidctrust.NewModule(wrapper) + + // when + err = subj.Block(ctx, expTenantID) + + // then + assert.NoError(t, err) + + actTrust, err := wrapper.Repo.Get(ctx, expTenantID) + assert.NoError(t, err) + assert.True(t, actTrust.GetBlocked()) + assert.Equal(t, expUnblockedTrust.GetOidc().GetIssuer(), actTrust.GetOidc().GetIssuer()) + assert.Equal(t, expUnblockedTrust.GetOidc().GetAudiences(), actTrust.GetOidc().GetAudiences()) + assert.Equal(t, expUnblockedTrust.GetOidc().GetJwksUri(), actTrust.GetOidc().GetJwksUri()) + }) + + t.Run("the trust is blocked then it should not call Update", func(t *testing.T) { + // given + expTenantID := uuid.Must(uuid.NewV4()).String() + expBlockedTrust := trustv1.Trust_builder{ + TenantId: new(expTenantID), + Blocked: new(true), + Oidc: oidcv1.OIDC_builder{ + Issuer: new(uuid.Must(uuid.NewV4()).String()), + JwksUri: new(jwksURI), + Audiences: []string{requestURI}, + }.Build(), + }.Build() + repoWrapper := &RepoWrapper{Repo: repo} + err := repoWrapper.Repo.Create(ctx, expBlockedTrust) + require.NoError(t, err) + + noOfUpdateCalls := 0 + repoWrapper.MockUpdate = func(ctx context.Context, trust *trustv1.Trust) error { + noOfUpdateCalls++ + return assert.AnError + } + subj := oidctrust.NewModule(repoWrapper) + + // when + err = subj.Block(t.Context(), expTenantID) + + // then + assert.NoError(t, err) + assert.Equal(t, 0, noOfUpdateCalls) + + actTrust, err := repoWrapper.Repo.Get(ctx, expTenantID) + assert.NoError(t, err) + assert.Equal(t, expBlockedTrust, actTrust) + }) + t.Run("the trust is not found during the Update", func(t *testing.T) { + // given + expTenantID := uuid.Must(uuid.NewV4()).String() + expBlockedTrust := trustv1.Trust_builder{ + TenantId: new(expTenantID), + Blocked: new(false), + Oidc: oidcv1.OIDC_builder{ + Issuer: new(uuid.Must(uuid.NewV4()).String()), + JwksUri: new(jwksURI), + Audiences: []string{requestURI}, + }.Build(), + }.Build() + repoWrapper := &RepoWrapper{Repo: repo} + err := repoWrapper.Repo.Create(ctx, expBlockedTrust) + require.NoError(t, err) + + noOfUpdateCalls := 0 + repoWrapper.MockUpdate = func(ctx context.Context, trust *trustv1.Trust) error { + noOfUpdateCalls++ + // delete the trust before updating to return an error + err := repoWrapper.Repo.Delete(ctx, expTenantID) + assert.NoError(t, err) + return nil + } + subj := oidctrust.NewModule(repoWrapper) + + // when + err = subj.Block(t.Context(), expTenantID) + + // then + assert.NoError(t, err) + assert.Equal(t, 1, noOfUpdateCalls) + }) + t.Run("the trust is not found", func(t *testing.T) { + // given + expTenantID := uuid.Must(uuid.NewV4()).String() + repoWrapper := &RepoWrapper{Repo: repo} + + subj := oidctrust.NewModule(repoWrapper) + + // when + err := subj.Block(t.Context(), expTenantID) + + // then + assert.NoError(t, err) + }) + }) + + t.Run("should return error", func(t *testing.T) { + t.Run("if Get returns an error", func(t *testing.T) { + // given + expTenantID := uuid.Must(uuid.NewV4()).String() + repoWrapper := &RepoWrapper{Repo: repo} + + noOfGetCalls := 0 + repoWrapper.MockGet = func(ctx context.Context, tenantID string) (*trustv1.Trust, error) { + assert.Equal(t, expTenantID, tenantID) + noOfGetCalls++ + return trustv1.Trust_builder{ + Oidc: oidcv1.OIDC_builder{}.Build(), + }.Build(), assert.AnError + } + subj := oidctrust.NewModule(repoWrapper) + + // when + err := subj.Block(t.Context(), expTenantID) + + // then + assert.ErrorIs(t, err, assert.AnError) + assert.Equal(t, 1, noOfGetCalls) + }) + + t.Run("if Update returns an error", func(t *testing.T) { + // given + expTenantID := uuid.Must(uuid.NewV4()).String() + expTrust := trustv1.Trust_builder{ + TenantId: new(expTenantID), + Blocked: new(false), + Oidc: oidcv1.OIDC_builder{ + Issuer: new(uuid.Must(uuid.NewV4()).String()), + JwksUri: new(jwksURI), + Audiences: []string{requestURI}, + }.Build(), + }.Build() + repoWrapper := &RepoWrapper{Repo: repo} + err := repoWrapper.Repo.Create(ctx, expTrust) + require.NoError(t, err) + + noOfUpdateCalls := 0 + repoWrapper.MockUpdate = func(ctx context.Context, trust *trustv1.Trust) error { + assert.Equal(t, expTenantID, trust.GetTenantId()) + noOfUpdateCalls++ + return assert.AnError + } + subj := oidctrust.NewModule(repoWrapper) + + // when + err = subj.Block(t.Context(), expTenantID) + + // then + assert.ErrorIs(t, err, assert.AnError) + assert.Equal(t, 1, noOfUpdateCalls) + + actTrust, err := repoWrapper.Repo.Get(ctx, expTenantID) + assert.NoError(t, err) + assert.Equal(t, expTrust, actTrust) + }) + }) +} + +func TestService_Unblock(t *testing.T) { + ctx := t.Context() + + t.Run("success if ", func(t *testing.T) { + t.Run("the trust is blocked", func(t *testing.T) { + // given + expTenantID := uuid.Must(uuid.NewV4()).String() + expBlockedTrust := trustv1.Trust_builder{ + TenantId: new(expTenantID), + Blocked: new(true), + Oidc: oidcv1.OIDC_builder{ + Issuer: new(uuid.Must(uuid.NewV4()).String()), + JwksUri: new(jwksURI), + Audiences: []string{requestURI}, + }.Build(), + }.Build() + + wrapper := &RepoWrapper{Repo: repo} + err := wrapper.Repo.Create(ctx, expBlockedTrust) + require.NoError(t, err) + subj := oidctrust.NewModule(wrapper) + + // when + err = subj.Unblock(t.Context(), expTenantID) + + // then + assert.NoError(t, err) + + actTrust, err := wrapper.Repo.Get(ctx, expTenantID) + assert.NoError(t, err) + assert.False(t, actTrust.GetBlocked()) + assert.Equal(t, expBlockedTrust.GetOidc().GetIssuer(), actTrust.GetOidc().GetIssuer()) + assert.Equal(t, expBlockedTrust.GetOidc().GetAudiences(), actTrust.GetOidc().GetAudiences()) + assert.Equal(t, expBlockedTrust.GetOidc().GetJwksUri(), actTrust.GetOidc().GetJwksUri()) + }) + + t.Run("the trust is unblocked then it should not call Update", func(t *testing.T) { + // given + expTenantID := uuid.Must(uuid.NewV4()).String() + expUnblockedTrust := trustv1.Trust_builder{ + TenantId: new(expTenantID), + Blocked: new(false), + Oidc: oidcv1.OIDC_builder{ + Issuer: new(uuid.Must(uuid.NewV4()).String()), + JwksUri: new(jwksURI), + Audiences: []string{requestURI}, + }.Build(), + }.Build() + repoWrapper := &RepoWrapper{Repo: repo} + err := repoWrapper.Repo.Create(ctx, expUnblockedTrust) + require.NoError(t, err) + + noOfUpdateCalls := 0 + repoWrapper.MockUpdate = func(ctx context.Context, trust *trustv1.Trust) error { + noOfUpdateCalls++ + return assert.AnError + } + subj := oidctrust.NewModule(repoWrapper) + + // when + err = subj.Unblock(t.Context(), expTenantID) + + // then + assert.NoError(t, err) + assert.Equal(t, 0, noOfUpdateCalls) + + actTrust, err := repoWrapper.Repo.Get(ctx, expTenantID) + assert.NoError(t, err) + assert.False(t, actTrust.GetBlocked()) + }) + t.Run("the trust is not found during the Update", func(t *testing.T) { + // given + expTenantID := uuid.Must(uuid.NewV4()).String() + expUnblockedTrust := trustv1.Trust_builder{ + TenantId: new(expTenantID), + Blocked: new(true), + Oidc: oidcv1.OIDC_builder{ + Issuer: new(uuid.Must(uuid.NewV4()).String()), + JwksUri: new(jwksURI), + Audiences: []string{requestURI}, + }.Build(), + }.Build() + repoWrapper := &RepoWrapper{Repo: repo} + err := repoWrapper.Repo.Create(ctx, expUnblockedTrust) + require.NoError(t, err) + + noOfUpdateCalls := 0 + repoWrapper.MockUpdate = func(ctx context.Context, trust *trustv1.Trust) error { + noOfUpdateCalls++ + // delete the trust before updating to return an error + err := repoWrapper.Repo.Delete(ctx, expTenantID) + assert.NoError(t, err) + return nil + } + subj := oidctrust.NewModule(repoWrapper) + + // when + err = subj.Unblock(t.Context(), expTenantID) + + // then + assert.NoError(t, err) + assert.Equal(t, 1, noOfUpdateCalls) + }) + t.Run("the trust is not found", func(t *testing.T) { + // given + expTenantID := uuid.Must(uuid.NewV4()).String() + repoWrapper := &RepoWrapper{Repo: repo} + + subj := oidctrust.NewModule(repoWrapper) + + // when + err := subj.Unblock(t.Context(), expTenantID) + + // then + assert.NoError(t, err) + }) + }) + t.Run("should return error", func(t *testing.T) { + t.Run("if Get returns an error", func(t *testing.T) { + // given + expTenantID := uuid.Must(uuid.NewV4()).String() + mockRepo := &RepoWrapper{Repo: repo} + + noOfGetTenantCalls := 0 + mockRepo.MockGet = func(ctx context.Context, tenantID string) (*trustv1.Trust, error) { + assert.Equal(t, expTenantID, tenantID) + noOfGetTenantCalls++ + return trustv1.Trust_builder{ + Oidc: oidcv1.OIDC_builder{}.Build(), + }.Build(), assert.AnError + } + subj := oidctrust.NewModule(mockRepo) + + // when + err := subj.Unblock(t.Context(), expTenantID) + + // then + assert.ErrorIs(t, err, assert.AnError) + assert.Equal(t, 1, noOfGetTenantCalls) + }) + + t.Run("if Update returns an error", func(t *testing.T) { + // given + expTenantIDtoUpdate := uuid.Must(uuid.NewV4()).String() + expBlockedTrust := trustv1.Trust_builder{ + TenantId: new(expTenantIDtoUpdate), + Blocked: new(true), + Oidc: oidcv1.OIDC_builder{ + Issuer: new(uuid.Must(uuid.NewV4()).String()), + JwksUri: new(jwksURI), + Audiences: []string{requestURI}, + }.Build(), + }.Build() + repoWrapper := &RepoWrapper{Repo: repo} + err := repoWrapper.Repo.Create(ctx, expBlockedTrust) + require.NoError(t, err) + + noOfUpdateCalls := 0 + repoWrapper.MockUpdate = func(ctx context.Context, trust *trustv1.Trust) error { + assert.Equal(t, expTenantIDtoUpdate, trust.GetTenantId()) + noOfUpdateCalls++ + return assert.AnError + } + subj := oidctrust.NewModule(repoWrapper) + + // when + err = subj.Unblock(t.Context(), expTenantIDtoUpdate) + + // then + assert.ErrorIs(t, err, assert.AnError) + assert.Equal(t, 1, noOfUpdateCalls) + + actTrust, err := repoWrapper.Repo.Get(ctx, expTenantIDtoUpdate) + assert.NoError(t, err) + assert.Equal(t, expBlockedTrust, actTrust) + }) + }) +} + +func TestService_Remove(t *testing.T) { + ctx := t.Context() + + t.Run("success if", func(t *testing.T) { + t.Run("the trust exists", func(t *testing.T) { + // given + expTenantID := uuid.Must(uuid.NewV4()).String() + expTrust := trustv1.Trust_builder{ + TenantId: new(expTenantID), + Oidc: oidcv1.OIDC_builder{ + Issuer: new(uuid.Must(uuid.NewV4()).String()), + JwksUri: new(jwksURI), + Audiences: []string{requestURI}, + }.Build(), + }.Build() + + wrapper := &RepoWrapper{Repo: repo} + err := wrapper.Repo.Create(ctx, expTrust) + require.NoError(t, err) + + subj := oidctrust.NewModule(wrapper) + + // when + err = subj.Remove(ctx, expTenantID) + + // then + assert.NoError(t, err) + + // verify the trust was deleted + _, err = wrapper.Repo.Get(ctx, expTenantID) + assert.Error(t, err) + }) + }) + + t.Run("should return error if", func(t *testing.T) { + t.Run("the trust does not exist", func(t *testing.T) { + // given + expTenantID := uuid.Must(uuid.NewV4()).String() + wrapper := &RepoWrapper{Repo: repo} + subj := oidctrust.NewModule(wrapper) + + // when + err := subj.Remove(ctx, expTenantID) + + // then + assert.Error(t, err) + }) + + t.Run("Delete returns an error", func(t *testing.T) { + // given + expTenantID := uuid.Must(uuid.NewV4()).String() + wrapper := &RepoWrapper{Repo: repo} + + noOfDeleteCalls := 0 + wrapper.MockDelete = func(ctx context.Context, tenantID string) error { + assert.Equal(t, expTenantID, tenantID) + noOfDeleteCalls++ + return assert.AnError + } + + subj := oidctrust.NewModule(wrapper) + + // when + err := subj.Remove(ctx, expTenantID) + + // then + assert.ErrorIs(t, err, assert.AnError) + assert.Equal(t, 1, noOfDeleteCalls) + }) + }) +} + +func TestService_Get(t *testing.T) { + ctx := t.Context() + + repoErr := errors.New("repository error") + + tests := []struct { + name string + trust *trustv1.Trust + repoErr error + wantErr bool + wantErrIs error + }{ + { + name: "returns trust", + trust: trustv1.Trust_builder{ + TenantId: new(uuid.Must(uuid.NewV4()).String()), + Blocked: new(false), + Oidc: oidcv1.OIDC_builder{ + Issuer: new(uuid.Must(uuid.NewV4()).String()), + JwksUri: new(jwksURI), + Audiences: []string{requestURI}, + }.Build(), + }.Build(), + }, + { + name: "returns error when trust does not exist", + wantErr: true, + }, + { + name: "wraps repository error", + repoErr: repoErr, + wantErr: true, + wantErrIs: repoErr, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var opts []mocktrust.RepositoryOption + if tt.trust != nil { + opts = append(opts, mocktrust.WithTrust(tt.trust)) + } + if tt.repoErr != nil { + opts = append(opts, mocktrust.WithGetError(tt.repoErr)) + } + + mockRepo := mocktrust.NewInMemRepository(opts...) + subj := oidctrust.NewModule(mockRepo) + + var tenantID string + if tt.trust != nil { + tenantID = tt.trust.GetTenantId() + } else { + tenantID = uuid.Must(uuid.NewV4()).String() + } + + got, err := subj.Get(ctx, tenantID) + + if tt.wantErr { + if err == nil { + t.Fatal("expected error, got nil") + } + if tt.wantErrIs != nil && !errors.Is(err, tt.wantErrIs) { + t.Fatalf("error = %v, want to wrap %v", err, tt.wantErrIs) + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if diff := cmp.Diff(tt.trust, got, protocmp.Transform()); diff != "" { + t.Fatalf("trust not equal:\n%s", diff) + } + }) + } +} diff --git a/modules/sessionstore/valkey/module.go b/modules/sessionstore/valkey/module.go new file mode 100644 index 00000000..be27c026 --- /dev/null +++ b/modules/sessionstore/valkey/module.go @@ -0,0 +1,98 @@ +// Package valkey provides the sessionstore.module.valkey module: a +// session.Repository backed by Valkey, configured by the top-level valkey: +// config block. It registers itself with the sessionmanager module registry +// at init time and is loaded by business.Main as a top-level dependency. +package valkey + +import ( + "fmt" + + "github.com/openkcm/common-sdk/pkg/commoncfg" + "github.com/valkey-io/valkey-go" + + sessionmanager "github.com/openkcm/session-manager" + "github.com/openkcm/session-manager/internal/session" + sessionvalkey "github.com/openkcm/session-manager/internal/session/valkey" +) + +const moduleID = "sessionstore.module.valkey" + +func init() { + sessionmanager.RegisterModule(new(Module)) +} + +func newModule() sessionmanager.Module { + return new(Module) +} + +// Module is the sessionstore.module.valkey module. It owns a Valkey client +// and exposes a session.Repository backed by it. +type Module struct { + *sessionvalkey.Repository + + Mod string `yaml:"module"` + Host commoncfg.SourceRef `yaml:"host"` + User commoncfg.SourceRef `yaml:"user"` + Password commoncfg.SourceRef `yaml:"password"` + Prefix string `yaml:"prefix"` + SecretRef commoncfg.SecretRef `yaml:"secretRef"` + + client valkey.Client +} + +func (m *Module) Module() sessionmanager.ModuleInfo { + return sessionmanager.ModuleInfo{ + ID: moduleID, + New: newModule, + } +} + +func (m *Module) Provision(_ *sessionmanager.Context) error { + host, err := commoncfg.LoadValueFromSourceRef(m.Host) + if err != nil { + return fmt.Errorf("loading valkey host: %w", err) + } + user, err := commoncfg.LoadValueFromSourceRef(m.User) + if err != nil { + return fmt.Errorf("loading valkey user: %w", err) + } + password, err := commoncfg.LoadValueFromSourceRef(m.Password) + if err != nil { + return fmt.Errorf("loading valkey password: %w", err) + } + + opts := valkey.ClientOption{ + InitAddress: []string{string(host)}, + Username: string(user), + Password: string(password), + } + + if m.SecretRef.Type == commoncfg.MTLSSecretType { + tlsConfig, err := commoncfg.LoadMTLSConfig(&m.SecretRef.MTLS) + if err != nil { + return fmt.Errorf("loading valkey mTLS config: %w", err) + } + opts.TLSConfig = tlsConfig + } + + client, err := valkey.NewClient(opts) + if err != nil { + return fmt.Errorf("creating valkey client: %w", err) + } + m.client = client + m.Repository = sessionvalkey.NewRepository(client, m.Prefix) + + return nil +} + +func (m *Module) Close() error { + if m.client == nil { + return nil + } + m.client.Close() + return nil +} + +// Compile-time guarantee that the module satisfies session.Repository via the +// embedded *sessionvalkey.Repository. +var _ session.Repository = (*Module)(nil) diff --git a/modules/sessionstore/valkey/module_test.go b/modules/sessionstore/valkey/module_test.go new file mode 100644 index 00000000..68e118dc --- /dev/null +++ b/modules/sessionstore/valkey/module_test.go @@ -0,0 +1,27 @@ +package valkey_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + sessionmanager "github.com/openkcm/session-manager" + sessionstorevalkey "github.com/openkcm/session-manager/modules/sessionstore/valkey" +) + +func TestModule_RegistrationAndID(t *testing.T) { + info, err := sessionmanager.GetModule("sessionstore.module.valkey") + require.NoError(t, err) + assert.Equal(t, "sessionstore.module.valkey", info.ID) + + mod := info.New() + require.NotNil(t, mod) + _, ok := mod.(*sessionstorevalkey.Module) + assert.True(t, ok, "New() must return *Module") +} + +func TestModule_CloseBeforeProvisionIsSafe(t *testing.T) { + m := new(sessionstorevalkey.Module) + require.NoError(t, m.Close(), "Close before Provision must not error") +} diff --git a/modules/standard/imports.go b/modules/standard/imports.go new file mode 100644 index 00000000..e6d3c06c --- /dev/null +++ b/modules/standard/imports.go @@ -0,0 +1,13 @@ +package standard + +import ( + _ "github.com/openkcm/session-manager/modules/app/grpcserver" + _ "github.com/openkcm/session-manager/modules/credentials/oauth2" + _ "github.com/openkcm/session-manager/modules/database/pgxpool" + _ "github.com/openkcm/session-manager/modules/grpc/oidcmapping" + _ "github.com/openkcm/session-manager/modules/grpc/session" + _ "github.com/openkcm/session-manager/modules/grpc/trustmapping" + _ "github.com/openkcm/session-manager/modules/oidctrust" + _ "github.com/openkcm/session-manager/modules/oidctrust/migrations" + _ "github.com/openkcm/session-manager/modules/sessionstore/valkey" +) diff --git a/modules_test.go b/modules_test.go new file mode 100644 index 00000000..c5f2b4cc --- /dev/null +++ b/modules_test.go @@ -0,0 +1,82 @@ +package sessionmanager_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + sessionmanager "github.com/openkcm/session-manager" +) + +// stubModule is a minimal Module used across tests in this file. +type stubModule struct{ id string } + +func (s *stubModule) Module() sessionmanager.ModuleInfo { + return sessionmanager.ModuleInfo{ + ID: s.id, + New: func() sessionmanager.Module { return &stubModule{id: s.id} }, + } +} + +// newRegistry resets the global module registry by registering a fresh set of +// modules under unique IDs so parallel/serial tests don't interfere with each other. +// It returns a cleanup function that can be deferred. +// +// Because the global map is package-level state, each test that touches the +// registry must use IDs that are unique within the whole test binary run. +func uniqueID(t *testing.T, suffix string) string { + t.Helper() + return t.Name() + "/" + suffix +} + +func TestRegisterModule_Success(t *testing.T) { + id := uniqueID(t, "mod") + sessionmanager.RegisterModule(&stubModule{id: id}) + + info, err := sessionmanager.GetModule(id) + require.NoError(t, err) + assert.Equal(t, id, info.ID) +} + +func TestRegisterModule_DuplicatePanics(t *testing.T) { + id := uniqueID(t, "mod") + sessionmanager.RegisterModule(&stubModule{id: id}) + + assert.Panics(t, func() { + sessionmanager.RegisterModule(&stubModule{id: id}) + }) +} + +func TestGetModule_NotRegistered(t *testing.T) { + _, err := sessionmanager.GetModule("module-that-does-not-exist") + require.Error(t, err) + assert.Contains(t, err.Error(), "not registered") +} + +func TestModules_ContainsRegistered(t *testing.T) { + id := uniqueID(t, "mod") + sessionmanager.RegisterModule(&stubModule{id: id}) + + found := false + for info := range sessionmanager.Modules() { + if info.ID == id { + found = true + break + } + } + assert.True(t, found, "registered module should appear in Modules()") +} + +func TestModuleInfo_New(t *testing.T) { + id := uniqueID(t, "mod") + sessionmanager.RegisterModule(&stubModule{id: id}) + + info, err := sessionmanager.GetModule(id) + require.NoError(t, err) + require.NotNil(t, info.New) + + instance := info.New() + require.NotNil(t, instance) + assert.Equal(t, id, instance.Module().ID) +} diff --git a/internal/serviceerr/errors.go b/pkg/serviceerr/errors.go similarity index 100% rename from internal/serviceerr/errors.go rename to pkg/serviceerr/errors.go diff --git a/internal/serviceerr/errors_test.go b/pkg/serviceerr/errors_test.go similarity index 99% rename from internal/serviceerr/errors_test.go rename to pkg/serviceerr/errors_test.go index 72c055c9..92371e7a 100644 --- a/internal/serviceerr/errors_test.go +++ b/pkg/serviceerr/errors_test.go @@ -6,7 +6,7 @@ import ( "github.com/stretchr/testify/assert" - "github.com/openkcm/session-manager/internal/serviceerr" + "github.com/openkcm/session-manager/pkg/serviceerr" ) func TestError_Error(t *testing.T) { diff --git a/sql/fs.go b/sql/fs.go deleted file mode 100644 index 91cca1c3..00000000 --- a/sql/fs.go +++ /dev/null @@ -1,6 +0,0 @@ -package migrations - -import "embed" - -//go:embed *.sql -var FS embed.FS diff --git a/sqlc.yaml b/sqlc.yaml index f1f46504..78319835 100644 --- a/sqlc.yaml +++ b/sqlc.yaml @@ -1,12 +1,12 @@ version: "2" sql: - engine: postgresql - queries: ./internal/trust/trustsql/queries.sql - schema: ./sql + queries: ./modules/oidctrust/internal/sql/queries.sql + schema: ./modules/oidctrust/migrations gen: go: package: queries - out: ./internal/trust/trustsql/internal/queries + out: ./modules/oidctrust/internal/sql/queries sql_package: pgx/v5 sql_driver: github.com/jackc/pgx/v5 emit_db_tags: true diff --git a/trust.go b/trust.go new file mode 100644 index 00000000..4f8eb0fc --- /dev/null +++ b/trust.go @@ -0,0 +1,24 @@ +package sessionmanager + +import ( + "context" + + trustv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/v1" +) + +type Trust interface { + // Apply applies and stores the provided Trust. + Apply(ctx context.Context, trust *trustv1.Trust) error + // Block sets the Blocked flag to true for the trust associated with the given tenantID. + // If the trust is already blocked, it does nothing. + // Returns an error if the trust cannot be retrieved or updated. + Block(ctx context.Context, tenantID string) error + // Remove removes the trust for the given tenantID. + Remove(ctx context.Context, tenantID string) error + // Unblock sets the Blocked flag to false for the trust associated with the given tenantID. + // If the trust is not blocked, it does nothing. + // Returns an error if the trust cannot be retrieved or updated. + Unblock(ctx context.Context, tenantID string) error + // Get returns a trust message with optional extensions set. + Get(ctx context.Context, tenantID string) (*trustv1.Trust, error) +}