From 6189aa5bcb53be77418fc405e49ee39f17d4604c Mon Sep 17 00:00:00 2001 From: Danylo Shevchenko Date: Thu, 7 May 2026 13:20:40 +0200 Subject: [PATCH 1/5] feat: module extensions Signed-off-by: Danylo Shevchenko --- .golangci.yaml | 1 + Makefile | 26 +- Taskfile.yaml | 2 +- charts/session-manager/values-dev.yaml | 2 +- charts/session-manager/values.yaml | 2 +- cmd/session-manager/main.go | 85 +-- cmd/session-manager/maincmd/main.go | 87 +++ config.yaml | 2 +- context.go | 92 +++ context_test.go | 244 +++++++ database.go | 20 + go.mod | 57 +- go.sum | 98 ++- integration/api-server-status_test.go | 33 +- integration/grpc_test.go | 411 ++++++----- integration/infra_test.go | 43 +- integration/migrate_test.go | 102 +-- integration/session_grpc_test.go | 118 +-- internal/business/business.go | 106 ++- internal/business/business_test.go | 158 +--- internal/business/housekeeper.go | 27 +- internal/business/housekeeper_test.go | 46 +- internal/business/migrate.go | 50 +- internal/business/migrate_test.go | 58 -- internal/business/server/grpc_server.go | 6 +- internal/business/server/grpc_server_test.go | 2 +- internal/business/server/openapi.go | 2 +- internal/business/server/openapi_test.go | 2 +- internal/cmdutils/cmdutils.go | 46 +- internal/cmdutils/cmdutils_test.go | 24 - internal/config/config.go | 97 ++- internal/config/connstr.go | 27 - internal/config/connstr_test.go | 113 --- internal/config/load.go | 106 +++ internal/config/parser.go | 23 + internal/dbtest/postgrestest/postgres.go | 8 +- internal/grpc/import_test.go | 12 + internal/grpc/oidcmapping.go | 129 ---- internal/grpc/oidcmapping_test.go | 417 ++++++----- internal/grpc/session.go | 46 +- internal/grpc/session_test.go | 283 ++++---- internal/grpc/trustmapping.go | 139 ++++ internal/session/housekeeper.go | 14 +- internal/session/housekeeper_test.go | 197 ++++- internal/session/import_test.go | 12 + internal/session/manager.go | 128 ++-- internal/session/manager_cookie_test.go | 44 +- internal/session/manager_test.go | 140 ++-- internal/session/mock/repository.go | 2 +- internal/session/valkey/repository.go | 2 +- internal/session/valkey/store.go | 2 +- internal/trust/mapping.go | 25 - internal/trust/mapping_test.go | 95 --- internal/trust/repository.go | 11 - internal/trust/service.go | 95 --- internal/trust/service_test.go | 540 -------------- internal/trust/trustsql/errors.go | 18 - internal/trust/trustsql/errors_test.go | 47 -- .../trust/trustsql/internal/queries/models.go | 20 - internal/trust/trustsql/repository.go | 157 ---- internal/trust/trustsql/repository_test.go | 302 -------- modules.go | 57 ++ modules/database/pgxpool/module.go | 107 +++ modules/oidctrust/export.go | 15 + modules/oidctrust/export_test.go | 7 + modules/oidctrust/internal/sql/export_test.go | 15 + .../oidctrust/internal/sql}/queries.sql | 4 - .../oidctrust/internal/sql}/queries/db.go | 0 .../oidctrust/internal/sql/queries/models.go | 19 + .../internal/sql}/queries/queries.sql.go | 50 +- modules/oidctrust/internal/sql/sql.go | 154 ++++ modules/oidctrust/internal/sql/sql_test.go | 305 ++++++++ modules/oidctrust/mapping.go | 110 +++ modules/oidctrust/mapping_test.go | 677 ++++++++++++++++++ .../oidctrust/migrations}/00001_init.sql | 0 .../00002_add_properties_to_providers.sql | 0 .../migrations}/00003_tenant_trust.sql | 0 .../migrations}/00004_single_tenant.sql | 0 .../migrations/00005_remove_properties.sql | 11 + modules/oidctrust/migrations/migration.go | 62 ++ .../oidctrust/mocks}/repository.go | 38 +- modules/oidctrust/module.go | 48 ++ modules/oidctrust/repository.go | 15 + .../oidctrust}/repository_test.go | 57 +- modules/standard/imports.go | 7 + modules_test.go | 82 +++ {internal => pkg}/serviceerr/errors.go | 0 {internal => pkg}/serviceerr/errors_test.go | 2 +- sql/fs.go | 6 - sqlc.yaml | 6 +- trust.go | 15 + 91 files changed, 3968 insertions(+), 3104 deletions(-) create mode 100644 cmd/session-manager/maincmd/main.go create mode 100644 context.go create mode 100644 context_test.go create mode 100644 database.go delete mode 100644 internal/business/migrate_test.go delete mode 100644 internal/config/connstr.go delete mode 100644 internal/config/connstr_test.go create mode 100644 internal/config/load.go create mode 100644 internal/config/parser.go create mode 100644 internal/grpc/import_test.go delete mode 100644 internal/grpc/oidcmapping.go create mode 100644 internal/grpc/trustmapping.go create mode 100644 internal/session/import_test.go delete mode 100644 internal/trust/mapping.go delete mode 100644 internal/trust/mapping_test.go delete mode 100644 internal/trust/repository.go delete mode 100644 internal/trust/service.go delete mode 100644 internal/trust/service_test.go delete mode 100644 internal/trust/trustsql/errors.go delete mode 100644 internal/trust/trustsql/errors_test.go delete mode 100644 internal/trust/trustsql/internal/queries/models.go delete mode 100644 internal/trust/trustsql/repository.go delete mode 100644 internal/trust/trustsql/repository_test.go create mode 100644 modules.go create mode 100644 modules/database/pgxpool/module.go create mode 100644 modules/oidctrust/export.go create mode 100644 modules/oidctrust/export_test.go create mode 100644 modules/oidctrust/internal/sql/export_test.go rename {internal/trust/trustsql => modules/oidctrust/internal/sql}/queries.sql (90%) rename {internal/trust/trustsql/internal => modules/oidctrust/internal/sql}/queries/db.go (100%) create mode 100644 modules/oidctrust/internal/sql/queries/models.go rename {internal/trust/trustsql/internal => modules/oidctrust/internal/sql}/queries/queries.sql.go (68%) create mode 100644 modules/oidctrust/internal/sql/sql.go create mode 100644 modules/oidctrust/internal/sql/sql_test.go create mode 100644 modules/oidctrust/mapping.go create mode 100644 modules/oidctrust/mapping_test.go rename {sql => modules/oidctrust/migrations}/00001_init.sql (100%) rename {sql => modules/oidctrust/migrations}/00002_add_properties_to_providers.sql (100%) rename {sql => modules/oidctrust/migrations}/00003_tenant_trust.sql (100%) rename {sql => modules/oidctrust/migrations}/00004_single_tenant.sql (100%) create mode 100644 modules/oidctrust/migrations/00005_remove_properties.sql create mode 100644 modules/oidctrust/migrations/migration.go rename {internal/trust/trustmock => modules/oidctrust/mocks}/repository.go (56%) create mode 100644 modules/oidctrust/module.go create mode 100644 modules/oidctrust/repository.go rename {internal/trust => modules/oidctrust}/repository_test.go (57%) create mode 100644 modules/standard/imports.go create mode 100644 modules_test.go rename {internal => pkg}/serviceerr/errors.go (100%) rename {internal => pkg}/serviceerr/errors_test.go (99%) delete mode 100644 sql/fs.go create mode 100644 trust.go diff --git a/.golangci.yaml b/.golangci.yaml index 89e0440e..99fc67a9 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -7,6 +7,7 @@ run: linters: default: all disable: + - gomodguard # Replaced by gomodguard_v2, enabled by default: all - nilerr - tagliatelle - bodyclose diff --git a/Makefile b/Makefile index a05add0c..e8b7b491 100644 --- a/Makefile +++ b/Makefile @@ -106,37 +106,41 @@ docker-dev-build: .PHONY: codegen codegen: go generate ./... - go run github.com/sqlc-dev/sqlc/cmd/sqlc@latest generate + go tool github.com/sqlc-dev/sqlc/cmd/sqlc generate .PHONY: clean clean: - rm -f cover.out cover.html session-manager - rm -rf cover/ + @rm -f cover.out cover.html session-manager + @rm -rf cover/ + +.PHONY: fix-lint +fix-lint: + golangci-lint run --fix --build-tags=integration ./... .PHONY: lint lint: - golangci-lint run ./... + golangci-lint run --build-tags=integration ./... .PHONY: build build: go build ./cmd/session-manager .PHONY: test -test: clean install-gotestsum +test: clean @mkdir -p cover/integration cover/unit @go clean -testcache - gotestsum --junitfile="${CURDIR}/junit-unit.xml" --format=testname -- -count=1 -race -cover ./... -args -test.gocoverdir="${CURDIR}/cover/unit" - GOCOVERDIR="${CURDIR}/cover/integration" gotestsum --junitfile="${CURDIR}/junit-integration.xml" --format=testname -- -v -count=1 -race -tags=integration ./integration + @go tool gotest.tools/gotestsum --junitfile="${CURDIR}/junit-unit.xml" --format=dots-v2 -- -count=1 -race -cover ./... -args -test.gocoverdir="${CURDIR}/cover/unit" + @GOCOVERDIR=${CURDIR}/cover/integration go tool gotest.tools/gotestsum --junitfile="${CURDIR}/junit-integration.xml" --format=dots-v2 -- -v -count=1 -race -tags=integration ./integration @go tool covdata textfmt -i=./cover/unit,./cover/integration -o cover.out @grep -v 'github.com/openkcm/session-manager/internal/openapi/' cover.out > cover.tmp && mv cover.tmp cover.out @grep -v 'github.com/openkcm/session-manager/internal/dbtest/' cover.out > cover.tmp && mv cover.tmp cover.out - @grep -v 'github.com/openkcm/session-manager/internal/trust/trustmock/' cover.out > cover.tmp && mv cover.tmp cover.out + @grep -v 'github.com/openkcm/session-manager/modules/oidctrust/mocks/' cover.out > cover.tmp && mv cover.tmp cover.out @grep -v 'github.com/openkcm/session-manager/internal/session/mock/' cover.out > cover.tmp && mv cover.tmp cover.out @go tool cover -func=cover.out - @echo "On a Mac, you can use the following command to open the coverage report in the browser\ngo tool cover -html=cover.out -o cover.html && open cover.html" + @echo "On a Mac, you can use the following command to open the coverage report in the browser\ngo tool cover -html=cover.out" .PHONY: helm-test helm-test: @@ -172,10 +176,6 @@ helm-integration-test-run: k3d-teardown: k3d cluster delete $(K3D_CLUSTER_NAME) -.PHONY: install-gotestsum -install-gotestsum: - (cd /tmp && go install gotest.tools/gotestsum@latest) - .PHONY: image image: docker build -t ${IMG} . diff --git a/Taskfile.yaml b/Taskfile.yaml index 5bafdc5f..bfaf9b3a 100644 --- a/Taskfile.yaml +++ b/Taskfile.yaml @@ -6,6 +6,6 @@ includes: flatten: true excludes: [] # put task names in here which are overwritten in this file vars: - CODE_DIRS: "{{.ROOT_DIR}}/cmd/... {{.ROOT_DIR}}/internal/... {{.ROOT_DIR}}/integration/... {{.ROOT_DIR}}/sql/..." + CODE_DIRS: "{{.ROOT_DIR}}/cmd/... {{.ROOT_DIR}}/internal/... {{.ROOT_DIR}}/integration/... {{.ROOT_DIR}} {{.ROOT_DIR}}/modules/..." COMPONENTS: session-manager REPO_URL: https://github.com/openkcm/session-manager diff --git a/charts/session-manager/values-dev.yaml b/charts/session-manager/values-dev.yaml index d9d30d39..6d80faec 100644 --- a/charts/session-manager/values-dev.yaml +++ b/charts/session-manager/values-dev.yaml @@ -296,7 +296,7 @@ config: value: secret prefix: session-manager secretRef: - type: insecure # Supports "mtls" or "insecure" + type: insecure # Supported values: "mtls", "insecure", "client_secret_post" # mtls: # cert: # source: embedded diff --git a/charts/session-manager/values.yaml b/charts/session-manager/values.yaml index 94666994..69fd0013 100644 --- a/charts/session-manager/values.yaml +++ b/charts/session-manager/values.yaml @@ -304,7 +304,7 @@ config: value: secret prefix: session-manager secretRef: - type: insecure # Supports "mtls" or "insecure" + type: insecure # Supported values: "mtls", "insecure", "client_secret_post" # mtls: # cert: # source: embedded diff --git a/cmd/session-manager/main.go b/cmd/session-manager/main.go index e23cc5be..045cf631 100644 --- a/cmd/session-manager/main.go +++ b/cmd/session-manager/main.go @@ -1,87 +1,14 @@ package main import ( - "context" - "log/slog" - "os" - "os/signal" - "time" - - "github.com/openkcm/common-sdk/pkg/utils" - "github.com/spf13/cobra" - - slogctx "github.com/veqryn/slog-context" - - "github.com/openkcm/session-manager/cmd/session-manager/apiserver" - "github.com/openkcm/session-manager/cmd/session-manager/housekeeper" - "github.com/openkcm/session-manager/cmd/session-manager/migrate" -) - -var ( - // BuildInfo will be set by the build system - BuildInfo = "{}" - - isVersionCmd bool - gracefulShutdown time.Duration + "github.com/openkcm/session-manager/cmd/session-manager/maincmd" + _ "github.com/openkcm/session-manager/modules/standard" ) -var versionCmd = &cobra.Command{ - Use: "version", - Short: "Session Manager Version", - RunE: func(cmd *cobra.Command, _ []string) error { - isVersionCmd = true - - value, err := utils.ExtractFromComplexValue(BuildInfo) - if err != nil { - return err - } - - slog.InfoContext(cmd.Context(), value) - - return nil - }, -} - -func rootCmd() *cobra.Command { - cmd := &cobra.Command{ - Use: "session-manager", - Short: "Session Manager", - Long: "KCM Session Manager, implementing the OIDC authorization code flow.", - } - - cmd.PersistentFlags().DurationVar(&gracefulShutdown, "graceful-shutdown", 1*time.Second, "graceful shutdown") - - cmd.AddCommand( - versionCmd, - apiserver.Cmd(BuildInfo), - housekeeper.Cmd(BuildInfo), - migrate.Cmd(BuildInfo), - ) - - return cmd -} - -func execute() error { - ctx, cancelOnSignal := signal.NotifyContext(context.Background(), os.Interrupt, os.Kill) - defer cancelOnSignal() - - err := rootCmd().ExecuteContext(ctx) - if err != nil { - slogctx.Error(ctx, "failed to start the application", "error", err) - return err - } - - if !isVersionCmd { - slogctx.Info(ctx, "Graceful shutdown", "duration", gracefulShutdown) - time.Sleep(gracefulShutdown) - } - - return nil -} +// BuildInfo will be set by the build system +var BuildInfo = "{}" func main() { - err := execute() - if err != nil { - os.Exit(1) - } + maincmd.BuildInfo = BuildInfo + maincmd.Main() } diff --git a/cmd/session-manager/maincmd/main.go b/cmd/session-manager/maincmd/main.go new file mode 100644 index 00000000..4b7d197b --- /dev/null +++ b/cmd/session-manager/maincmd/main.go @@ -0,0 +1,87 @@ +package maincmd + +import ( + "context" + "log/slog" + "os" + "os/signal" + "time" + + "github.com/openkcm/common-sdk/pkg/utils" + "github.com/spf13/cobra" + + slogctx "github.com/veqryn/slog-context" + + "github.com/openkcm/session-manager/cmd/session-manager/apiserver" + "github.com/openkcm/session-manager/cmd/session-manager/housekeeper" + "github.com/openkcm/session-manager/cmd/session-manager/migrate" +) + +var ( + // BuildInfo will be set by the build system + BuildInfo = "{}" + + isVersionCmd bool + gracefulShutdown time.Duration +) + +var versionCmd = &cobra.Command{ + Use: "version", + Short: "Session Manager Version", + RunE: func(cmd *cobra.Command, _ []string) error { + isVersionCmd = true + + value, err := utils.ExtractFromComplexValue(BuildInfo) + if err != nil { + return err + } + + slog.InfoContext(cmd.Context(), value) + + return nil + }, +} + +func rootCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "session-manager", + Short: "Session Manager", + Long: "KCM Session Manager, implementing the OIDC authorization code flow.", + } + + cmd.PersistentFlags().DurationVar(&gracefulShutdown, "graceful-shutdown", 1*time.Second, "graceful shutdown") + + cmd.AddCommand( + versionCmd, + apiserver.Cmd(BuildInfo), + housekeeper.Cmd(BuildInfo), + migrate.Cmd(BuildInfo), + ) + + return cmd +} + +func execute() error { + ctx, cancelOnSignal := signal.NotifyContext(context.Background(), os.Interrupt, os.Kill) + defer cancelOnSignal() + + err := rootCmd().ExecuteContext(ctx) + if err != nil { + slogctx.Error(ctx, "failed to start the application", "error", err) + return err + } + + if !isVersionCmd { + slogctx.Info(ctx, "Graceful shutdown", "duration", gracefulShutdown) + time.Sleep(gracefulShutdown) + } + + return nil +} + +func Main() { + err := execute() + if err != nil { + os.Exit(1) + } +} diff --git a/config.yaml b/config.yaml index b289f853..a724449f 100644 --- a/config.yaml +++ b/config.yaml @@ -205,7 +205,7 @@ valkey: value: secret prefix: session-manager secretRef: - type: insecure # Supports "mtls" or "insecure" + type: insecure # Supported values: "mtls", "insecure", "client_secret_post" # mtls: # cert: # source: embedded diff --git a/context.go b/context.go new file mode 100644 index 00000000..fe70fa32 --- /dev/null +++ b/context.go @@ -0,0 +1,92 @@ +package sessionmanager + +import ( + "context" + "errors" + "fmt" + "io" + "reflect" + + slogctx "github.com/veqryn/slog-context" +) + +type Context struct { + //nolint:containedctx + context.Context + + mods map[string]Module +} + +func (c *Context) cloneWithParent(parent context.Context) *Context { + return &Context{ + Context: parent, + mods: c.mods, + } +} + +func (c *Context) WithValue(key, val any) *Context { + return c.cloneWithParent(context.WithValue(c.Context, key, val)) +} + +func NewContext(ctx context.Context) (*Context, context.CancelCauseFunc) { + ctx, cancelCause := context.WithCancelCause(ctx) + c := &Context{Context: ctx, mods: make(map[string]Module)} + return c, func(cause error) { + cancelCause(cause) + for name, mod := range c.mods { + if closer, ok := mod.(io.Closer); ok { + if err := closer.Close(); err != nil { + slogctx.Error(c, "failed to close a module", "module", name, "error", err) + } + } + } + } +} + +type ExtensionConfig interface { + Module() string + UnmarshalExtension(into Module) error +} + +func (c *Context) GetModule(id string) (Module, error) { + if mod, ok := c.mods[id]; ok { + return mod, nil + } + + return nil, errors.New("module is not loaded") +} + +func (c *Context) LoadModule(cfg ExtensionConfig) (Module, error) { + modInfo, err := GetModule(cfg.Module()) + if err != nil { + return nil, fmt.Errorf("getting module %q: %w", reflect.TypeOf(cfg), err) + } + + if _, ok := c.mods[modInfo.ID]; ok { + return nil, errors.New("module has already been loaded") + } + + slogctx.Debug(c, "loading module", "module", modInfo.ID) + + mod := modInfo.New() + rv := reflect.ValueOf(mod) + if rv.Kind() == reflect.Pointer && rv.Elem().Kind() == reflect.Struct { + if err := cfg.UnmarshalExtension(mod); err != nil { + return nil, fmt.Errorf("unmarshaling extension %s: %w", modInfo.ID, err) + } + } + + slogctx.Debug(c, "instantinated module", "module", modInfo.ID) + + if provisioner, ok := mod.(Provisioner); ok { + if err := provisioner.Provision(c); err != nil { + return nil, fmt.Errorf("provisioning module: %w", err) + } + + slogctx.Debug(c, "provisioned module", "module", modInfo.ID) + } + + c.mods[modInfo.ID] = mod + + return mod, nil +} diff --git a/context_test.go b/context_test.go new file mode 100644 index 00000000..745d25e7 --- /dev/null +++ b/context_test.go @@ -0,0 +1,244 @@ +package sessionmanager_test + +import ( + "context" + "errors" + "io" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + sessionmanager "github.com/openkcm/session-manager" +) + +// provisionableModule records whether Provision was called. +type provisionableModule struct { + stubModule + + provisioned bool +} + +func (m *provisionableModule) Provision(_ *sessionmanager.Context) error { + m.provisioned = true + return nil +} + +// failingProvisionerModule always returns an error from Provision. +type failingProvisionerModule struct{ stubModule } + +func (m *failingProvisionerModule) Provision(_ *sessionmanager.Context) error { + return errors.New("provision failed") +} + +// closableModule records whether Close was called. +type closableModule struct { + stubModule + + closed bool +} + +func (m *closableModule) Close() error { + m.closed = true + return nil +} + +// closeErrModule returns an error from Close (exercises the error-log path). +type closeErrModule struct{ stubModule } + +func (m *closeErrModule) Close() error { return errors.New("close error") } + +// simpleExtensionConfig is a minimal ExtensionConfig that references a registered module. +type simpleExtensionConfig struct{ moduleID string } + +func (c *simpleExtensionConfig) Module() string { return c.moduleID } +func (c *simpleExtensionConfig) UnmarshalExtension(_ sessionmanager.Module) error { return nil } + +// failingUnmarshalConfig returns an error from UnmarshalExtension. +type failingUnmarshalConfig struct{ moduleID string } + +func (c *failingUnmarshalConfig) Module() string { return c.moduleID } +func (c *failingUnmarshalConfig) UnmarshalExtension(_ sessionmanager.Module) error { + return errors.New("unmarshal failed") +} + +// customNewModule registers a module whose New() function delegates to newFn. +type customNewModule struct { + id string + newFn func() sessionmanager.Module +} + +func (m *customNewModule) Module() sessionmanager.ModuleInfo { + return sessionmanager.ModuleInfo{ID: m.id, New: m.newFn} +} + +func TestNewContext_CancelCloseModules(t *testing.T) { + id := uniqueID(t, "closable") + cm := &closableModule{stubModule: stubModule{id: id}} + + sessionmanager.RegisterModule(&customNewModule{ + id: id, + newFn: func() sessionmanager.Module { return cm }, + }) + + ctx, cancel := sessionmanager.NewContext(t.Context()) + + _, err := ctx.LoadModule(&simpleExtensionConfig{moduleID: id}) + require.NoError(t, err) + + cancel(nil) + assert.True(t, cm.closed, "Close() should be called when context is cancelled") +} + +func TestNewContext_CancelWithCause(t *testing.T) { + ctx, cancel := sessionmanager.NewContext(t.Context()) + + cause := errors.New("test cause") + cancel(cause) + + assert.ErrorIs(t, context.Cause(ctx), cause) +} + +func TestNewContext_CloseErrorIsHandled(t *testing.T) { + id := uniqueID(t, "closeerr") + sessionmanager.RegisterModule(&customNewModule{ + id: id, + newFn: func() sessionmanager.Module { return &closeErrModule{stubModule: stubModule{id: id}} }, + }) + + ctx, cancel := sessionmanager.NewContext(t.Context()) + _, err := ctx.LoadModule(&simpleExtensionConfig{moduleID: id}) + require.NoError(t, err) + + // Should not panic even though Close() returns an error. + assert.NotPanics(t, func() { cancel(nil) }) +} + +func TestContext_WithValue(t *testing.T) { + type ctxKey struct{} + ctx, cancel := sessionmanager.NewContext(t.Context()) + defer cancel(nil) + + ctx2 := ctx.WithValue(ctxKey{}, "hello") + assert.Equal(t, "hello", ctx2.Value(ctxKey{})) + // Original context should not carry the value. + assert.Nil(t, ctx.Value(ctxKey{})) +} + +func TestLoadModule_Success(t *testing.T) { + id := uniqueID(t, "prov") + pm := &provisionableModule{stubModule: stubModule{id: id}} + + sessionmanager.RegisterModule(&customNewModule{ + id: id, + newFn: func() sessionmanager.Module { return pm }, + }) + + ctx, cancel := sessionmanager.NewContext(t.Context()) + defer cancel(nil) + + mod, err := ctx.LoadModule(&simpleExtensionConfig{moduleID: id}) + require.NoError(t, err) + require.NotNil(t, mod) + assert.True(t, pm.provisioned) +} + +func TestLoadModule_UnknownModule(t *testing.T) { + ctx, cancel := sessionmanager.NewContext(t.Context()) + defer cancel(nil) + + _, err := ctx.LoadModule(&simpleExtensionConfig{moduleID: "no-such-module"}) + require.Error(t, err) +} + +func TestLoadModule_DuplicateReturnsError(t *testing.T) { + id := uniqueID(t, "dup") + sessionmanager.RegisterModule(&customNewModule{ + id: id, + newFn: func() sessionmanager.Module { return &stubModule{id: id} }, + }) + + ctx, cancel := sessionmanager.NewContext(t.Context()) + defer cancel(nil) + + _, err := ctx.LoadModule(&simpleExtensionConfig{moduleID: id}) + require.NoError(t, err) + + _, err = ctx.LoadModule(&simpleExtensionConfig{moduleID: id}) + require.Error(t, err) + assert.Contains(t, err.Error(), "already been loaded") +} + +func TestLoadModule_ProvisionError(t *testing.T) { + id := uniqueID(t, "failprov") + sessionmanager.RegisterModule(&customNewModule{ + id: id, + newFn: func() sessionmanager.Module { return &failingProvisionerModule{stubModule: stubModule{id: id}} }, + }) + + ctx, cancel := sessionmanager.NewContext(t.Context()) + defer cancel(nil) + + _, err := ctx.LoadModule(&simpleExtensionConfig{moduleID: id}) + require.Error(t, err) + assert.Contains(t, err.Error(), "provision failed") +} + +func TestLoadModule_UnmarshalError(t *testing.T) { + id := uniqueID(t, "unmarshalerr") + // Use a pointer-to-struct module so the unmarshal branch is reached. + sessionmanager.RegisterModule(&customNewModule{ + id: id, + newFn: func() sessionmanager.Module { return &stubModule{id: id} }, + }) + + ctx, cancel := sessionmanager.NewContext(t.Context()) + defer cancel(nil) + + _, err := ctx.LoadModule(&failingUnmarshalConfig{moduleID: id}) + require.Error(t, err) + assert.Contains(t, err.Error(), "unmarshal failed") +} + +func TestGetModule_AfterLoad(t *testing.T) { + id := uniqueID(t, "get") + sessionmanager.RegisterModule(&customNewModule{ + id: id, + newFn: func() sessionmanager.Module { return &stubModule{id: id} }, + }) + + ctx, cancel := sessionmanager.NewContext(t.Context()) + defer cancel(nil) + + _, err := ctx.LoadModule(&simpleExtensionConfig{moduleID: id}) + require.NoError(t, err) + + mod, err := ctx.GetModule(id) + require.NoError(t, err) + assert.NotNil(t, mod) +} + +func TestGetModule_NotLoaded(t *testing.T) { + id := uniqueID(t, "notloaded") + sessionmanager.RegisterModule(&customNewModule{ + id: id, + newFn: func() sessionmanager.Module { return &stubModule{id: id} }, + }) + + ctx, cancel := sessionmanager.NewContext(t.Context()) + defer cancel(nil) + + // Never call LoadModule — GetModule should return an error. + _, err := ctx.GetModule(id) + require.Error(t, err) + assert.Contains(t, err.Error(), "not loaded") +} + +// Ensure stubModule satisfies the Module interface at compile time. +var _ sessionmanager.Module = (*stubModule)(nil) + +// Ensure provisionableModule satisfies Provisioner at compile time. +var _ sessionmanager.Provisioner = (*provisionableModule)(nil) + +// Ensure closableModule satisfies io.Closer at compile time. +var _ io.Closer = (*closableModule)(nil) diff --git a/database.go b/database.go new file mode 100644 index 00000000..6ce95787 --- /dev/null +++ b/database.go @@ -0,0 +1,20 @@ +package sessionmanager + +import ( + "context" + "database/sql" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" +) + +type Database interface { + Exec(ctx context.Context, sql string, args ...any) (pgconn.CommandTag, error) + Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) + QueryRow(ctx context.Context, sql string, args ...any) pgx.Row + STDAdapter() *sql.DB +} + +type Migrate interface { + Migrate(ctx context.Context) error +} diff --git a/go.mod b/go.mod index 3cf7c127..ec1e57be 100644 --- a/go.mod +++ b/go.mod @@ -1,16 +1,15 @@ module github.com/openkcm/session-manager -go 1.26.0 - -toolchain go1.26.2 +go 1.26.3 tool ( github.com/oapi-codegen/oapi-codegen/v2/cmd/oapi-codegen github.com/sqlc-dev/sqlc/cmd/sqlc + gotest.tools/gotestsum ) require ( - github.com/XSAM/otelsql v0.42.0 + github.com/creasty/defaults v1.8.0 github.com/exaring/otelpgx v0.10.0 github.com/go-jose/go-jose/v4 v4.1.4 github.com/go-viper/mapstructure/v2 v2.5.0 @@ -19,9 +18,11 @@ require ( github.com/google/go-cmp v0.7.0 github.com/jackc/pgx/v5 v5.9.2 github.com/jellydator/ttlcache/v3 v3.4.0 + github.com/knadh/koanf/providers/file v1.2.1 + github.com/knadh/koanf/v2 v2.3.4 github.com/moby/moby/api v1.54.2 github.com/oapi-codegen/runtime v1.4.0 - github.com/openkcm/api-sdk v0.17.0 + github.com/openkcm/api-sdk v0.17.1-0.20260518093831-a872a7e182ca github.com/openkcm/common-sdk v1.16.0 github.com/pressly/goose/v3 v3.27.1 github.com/samber/oops v1.21.0 @@ -35,6 +36,7 @@ require ( go.opentelemetry.io/otel/metric v1.43.0 go.opentelemetry.io/otel/trace v1.43.0 google.golang.org/grpc v1.81.1 + google.golang.org/protobuf v1.36.11 ) require ( @@ -44,9 +46,11 @@ require ( github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c // indirect github.com/Dynatrace/OneAgent-SDK-for-Go v1.1.0 // indirect github.com/Microsoft/go-winio v0.6.2 // indirect + github.com/XSAM/otelsql v0.42.0 // indirect github.com/antlr4-go/antlr/v4 v4.13.1 // indirect github.com/apapsch/go-jsonmerge/v2 v2.0.0 // indirect github.com/beorn7/perks v1.0.1 // indirect + github.com/bitfield/gotestdox v0.2.2 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect github.com/cenkalti/backoff/v5 v5.0.3 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect @@ -56,10 +60,10 @@ require ( github.com/containerd/log v0.1.0 // indirect github.com/containerd/platforms v0.2.1 // indirect github.com/cpuguy83/dockercfg v0.3.2 // indirect - github.com/creasty/defaults v1.8.0 // indirect github.com/cubicdaiya/gonp v1.0.4 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/distribution/reference v0.6.0 // indirect + github.com/dnephin/pflag v1.0.7 // indirect github.com/docker/go-connections v0.7.0 // indirect github.com/docker/go-units v0.5.0 // indirect github.com/dprotaso/go-yit v0.0.0-20220510233725-9ba8df137936 // indirect @@ -67,6 +71,7 @@ require ( github.com/ebitengine/purego v0.10.0 // indirect github.com/envoyproxy/go-control-plane/envoy v1.37.0 // indirect github.com/envoyproxy/protoc-gen-validate v1.3.3 // indirect + github.com/fatih/color v1.18.0 // indirect github.com/fatih/structtag v1.2.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/fsnotify/fsnotify v1.10.1 // indirect @@ -79,8 +84,9 @@ require ( github.com/go-sql-driver/mysql v1.9.3 // indirect github.com/golang-jwt/jwt/v5 v5.3.1 // indirect github.com/google/cel-go v0.26.1 // indirect + github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect github.com/google/uuid v1.6.0 // indirect - github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 // indirect + github.com/grpc-ecosystem/grpc-gateway/v2 v2.29.0 // indirect github.com/hashicorp/go-version v1.9.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect @@ -89,13 +95,17 @@ require ( github.com/jinzhu/inflection v1.0.0 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/json-iterator/go v1.1.12 // indirect - github.com/klauspost/compress v1.18.5 // indirect - github.com/lufia/plan9stats v0.0.0-20251013123823-9fd1530e3ec3 // indirect + github.com/klauspost/compress v1.18.6 // indirect + github.com/knadh/koanf/maps v0.1.2 // indirect + github.com/lufia/plan9stats v0.0.0-20260330125221-c963978e514e // indirect github.com/magiconair/properties v1.8.10 // indirect github.com/mailru/easyjson v0.7.7 // indirect + github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.21 // indirect github.com/mdelapenya/tlscert v0.2.0 // indirect github.com/mfridman/interpolate v0.0.2 // indirect + github.com/mitchellh/copystructure v1.2.0 // indirect + github.com/mitchellh/reflectwalk v1.0.2 // indirect github.com/moby/docker-image-spec v1.3.1 // indirect github.com/moby/go-archive v0.2.0 // indirect github.com/moby/moby/client v0.4.1 // indirect @@ -116,7 +126,7 @@ require ( github.com/oliveagle/jsonpath v0.1.4 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.1.1 // indirect - github.com/pelletier/go-toml/v2 v2.2.4 // indirect + github.com/pelletier/go-toml/v2 v2.3.1 // indirect github.com/perimeterx/marshmallow v1.1.5 // indirect github.com/pganalyze/pg_query_go/v6 v6.1.0 // indirect github.com/pingcap/errors v0.11.5-0.20240311024730-e056997136bb // indirect @@ -135,11 +145,11 @@ require ( github.com/riza-io/grpc-go v0.2.0 // indirect github.com/sagikazarmark/locafero v0.12.0 // indirect github.com/samber/lo v1.53.0 // indirect - github.com/samber/slog-common v0.21.0 // indirect + github.com/samber/slog-common v0.22.0 // indirect github.com/samber/slog-formatter v1.3.0 // indirect github.com/samber/slog-multi v1.8.0 // indirect github.com/sethvargo/go-retry v0.3.0 // indirect - github.com/shirou/gopsutil/v4 v4.26.3 // indirect + github.com/shirou/gopsutil/v4 v4.26.4 // indirect github.com/sirupsen/logrus v1.9.4 // indirect github.com/speakeasy-api/jsonpath v0.6.0 // indirect github.com/speakeasy-api/openapi-overlay v0.10.2 // indirect @@ -152,16 +162,16 @@ require ( github.com/subosito/gotenv v1.6.0 // indirect github.com/testcontainers/testcontainers-go v0.42.0 // indirect github.com/tetratelabs/wazero v1.9.0 // indirect - github.com/tklauser/go-sysconf v0.3.16 // indirect - github.com/tklauser/numcpus v0.11.0 // indirect + github.com/tklauser/go-sysconf v0.4.0 // indirect + github.com/tklauser/numcpus v0.12.0 // indirect github.com/veqryn/slog-context/otel v0.9.0 // indirect github.com/vmware-labs/yaml-jsonpath v0.3.2 // indirect github.com/wasilibs/go-pgquery v0.0.0-20250409022910-10ac41983c07 // indirect github.com/wasilibs/wazero-helpers v0.0.0-20240620070341-3dff1577cd52 // indirect github.com/yusufpapurcu/wmi v1.2.4 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect - go.opentelemetry.io/collector/featuregate v1.57.0 // indirect - go.opentelemetry.io/collector/pdata v1.57.0 // indirect + go.opentelemetry.io/collector/featuregate v1.58.0 // indirect + go.opentelemetry.io/collector/pdata v1.58.0 // indirect go.opentelemetry.io/contrib/bridges/otelslog v0.18.0 // indirect go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.68.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.68.0 // indirect @@ -184,21 +194,22 @@ require ( go.uber.org/zap v1.27.0 // indirect go.yaml.in/yaml/v2 v2.4.4 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect - golang.org/x/crypto v0.50.0 // indirect + golang.org/x/crypto v0.51.0 // indirect golang.org/x/exp v0.0.0-20260410095643-746e56fc9e2f // indirect golang.org/x/mod v0.35.0 // indirect - golang.org/x/net v0.53.0 // indirect + golang.org/x/net v0.54.0 // indirect golang.org/x/sync v0.20.0 // indirect - golang.org/x/sys v0.43.0 // indirect - golang.org/x/text v0.36.0 // indirect + golang.org/x/sys v0.44.0 // indirect + golang.org/x/term v0.43.0 // indirect + golang.org/x/text v0.37.0 // indirect golang.org/x/time v0.15.0 // indirect golang.org/x/tools v0.44.0 // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20260420184626-e10c466a9529 // indirect - google.golang.org/protobuf v1.36.11 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20260511170946-3700d4141b60 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20260511170946-3700d4141b60 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect + gotest.tools/gotestsum v1.13.0 // indirect modernc.org/libc v1.72.1 // indirect modernc.org/mathutil v1.7.1 // indirect modernc.org/memory v1.11.0 // indirect diff --git a/go.sum b/go.sum index 61eac509..2fcf5957 100644 --- a/go.sum +++ b/go.sum @@ -23,6 +23,8 @@ github.com/apapsch/go-jsonmerge/v2 v2.0.0/go.mod h1:lvDnEdqiQrp0O42VQGgmlKpxL1AP github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/bitfield/gotestdox v0.2.2 h1:x6RcPAbBbErKLnapz1QeAlf3ospg8efBsedU93CDsnE= +github.com/bitfield/gotestdox v0.2.2/go.mod h1:D+gwtS0urjBrzguAkTM2wodsTQYFHdpx8eqRJ3N+9pY= github.com/bmatcuk/doublestar v1.1.1/go.mod h1:UD6OnuiIn0yFxxA2le/rnRU1G4RaI4UvFv1sNto9p6w= github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= @@ -57,6 +59,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk= github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E= +github.com/dnephin/pflag v1.0.7 h1:oxONGlWxhmUct0YzKTgrpQv9AUA1wtPBn7zuSjJqptk= +github.com/dnephin/pflag v1.0.7/go.mod h1:uxE91IoWURlOiTUIA8Mq5ZZkAv3dPUfZNaT80Zm7OQE= github.com/docker/go-connections v0.7.0 h1:6SsRfJddP22WMrCkj19x9WKjEDTB+ahsdiGYf0mN39c= github.com/docker/go-connections v0.7.0/go.mod h1:no1qkHdjq7kLMGUXYAduOhYPSJxxvgWBh7ogVvptn3Q= github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4= @@ -74,6 +78,8 @@ github.com/envoyproxy/protoc-gen-validate v1.3.3 h1:MVQghNeW+LZcmXe7SY1V36Z+WFMD github.com/envoyproxy/protoc-gen-validate v1.3.3/go.mod h1:TsndJ/ngyIdQRhMcVVGDDHINPLWB7C82oDArY51KfB0= github.com/exaring/otelpgx v0.10.0 h1:NGGegdoBQM3jNZDKG8ENhigUcgBN7d7943L0YlcIpZc= github.com/exaring/otelpgx v0.10.0/go.mod h1:R5/M5LWsPPBZc1SrRE5e0DiU48bI78C1/GPTWs6I66U= +github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= +github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= github.com/fatih/structtag v1.2.0 h1:/OdNE99OxoI/PqaW/SuSK9uxxT3f/tcSZgon/ssNSx4= github.com/fatih/structtag v1.2.0/go.mod h1:mBJUNpUnHmRKrKlQQlmCrh5PuhftFbNv8Ys4/aAZl94= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= @@ -136,10 +142,12 @@ github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/ github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/pprof v0.0.0-20260402051712-545e8a4df936 h1:EwtI+Al+DeppwYX2oXJCETMO23COyaKGP6fHVpkpWpg= github.com/google/pprof v0.0.0-20260402051712-545e8a4df936/go.mod h1:MxpfABSjhmINe3F1It9d+8exIHFvUqtLIRCdOGNXqiI= +github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4= +github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 h1:HWRh5R2+9EifMyIHV7ZV+MIZqgz+PMpZ14Jynv3O2Zs= -github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0/go.mod h1:JfhWUomR1baixubs02l85lZYYOm7LV6om4ceouMv45c= +github.com/grpc-ecosystem/grpc-gateway/v2 v2.29.0 h1:5VipnvEpbqr2gA2VbM+nYVbkIF28c5ZQfqCBQ5g2xfk= +github.com/grpc-ecosystem/grpc-gateway/v2 v2.29.0/go.mod h1:Hyl3n6Twe1hvtd9XUXDec4pTvgMSEixRuQKPTMH2bNs= github.com/hashicorp/go-version v1.9.0 h1:CeOIz6k+LoN3qX9Z0tyQrPtiB1DFYRPfCIBtaXPSCnA= github.com/hashicorp/go-version v1.9.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= @@ -165,8 +173,14 @@ github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFF github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/juju/gnuflag v0.0.0-20171113085948-2ce1bb71843d/go.mod h1:2PavIy+JPciBPrBUjwbNvtwB6RQlve+hkpll6QSNmOE= -github.com/klauspost/compress v1.18.5 h1:/h1gH5Ce+VWNLSWqPzOVn6XBO+vJbCNGvjoaGBFW2IE= -github.com/klauspost/compress v1.18.5/go.mod h1:cwPg85FWrGar70rWktvGQj8/hthj3wpl0PGDogxkrSQ= +github.com/klauspost/compress v1.18.6 h1:2jupLlAwFm95+YDR+NwD2MEfFO9d4z4Prjl1XXDjuao= +github.com/klauspost/compress v1.18.6/go.mod h1:cwPg85FWrGar70rWktvGQj8/hthj3wpl0PGDogxkrSQ= +github.com/knadh/koanf/maps v0.1.2 h1:RBfmAW5CnZT+PJ1CVc1QSJKf4Xu9kxfQgYVQSu8hpbo= +github.com/knadh/koanf/maps v0.1.2/go.mod h1:npD/QZY3V6ghQDdcQzl1W4ICNVTkohC8E73eI2xW4yI= +github.com/knadh/koanf/providers/file v1.2.1 h1:bEWbtQwYrA+W2DtdBrQWyXqJaJSG3KrP3AESOJYp9wM= +github.com/knadh/koanf/providers/file v1.2.1/go.mod h1:bp1PM5f83Q+TOUu10J/0ApLBd9uIzg+n9UgthfY+nRA= +github.com/knadh/koanf/v2 v2.3.4 h1:fnynNSDlujWE+v83hAp8wKr/cdoxHLO0629SN+U8Urc= +github.com/knadh/koanf/v2 v2.3.4/go.mod h1:gRb40VRAbd4iJMYYD5IxZ6hfuopFcXBpc9bbQpZwo28= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= @@ -178,18 +192,24 @@ github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0 github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= -github.com/lufia/plan9stats v0.0.0-20251013123823-9fd1530e3ec3 h1:PwQumkgq4/acIiZhtifTV5OUqqiP82UAl0h87xj/l9k= -github.com/lufia/plan9stats v0.0.0-20251013123823-9fd1530e3ec3/go.mod h1:autxFIvghDt3jPTLoqZ9OZ7s9qTGNAWmYCjVFWPX/zg= +github.com/lufia/plan9stats v0.0.0-20260330125221-c963978e514e h1:Q6MvJtQK/iRcRtzAscm/zF23XxJlbECiGPyRicsX+Ak= +github.com/lufia/plan9stats v0.0.0-20260330125221-c963978e514e/go.mod h1:autxFIvghDt3jPTLoqZ9OZ7s9qTGNAWmYCjVFWPX/zg= github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE= github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= +github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= +github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= github.com/mattn/go-isatty v0.0.21 h1:xYae+lCNBP7QuW4PUnNG61ffM4hVIfm+zUzDuSzYLGs= github.com/mattn/go-isatty v0.0.21/go.mod h1:ZXfXG4SQHsB/w3ZeOYbR0PrPwLy+n6xiMrJlRFqopa4= github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI= github.com/mdelapenya/tlscert v0.2.0/go.mod h1:O4njj3ELLnJjGdkN7M/vIVCpZ+Cf0L6muqOG4tLSl8o= github.com/mfridman/interpolate v0.0.2 h1:pnuTK7MQIxxFz1Gr+rjSIx9u7qVjf5VOoM/u6BbAxPY= github.com/mfridman/interpolate v0.0.2/go.mod h1:p+7uk6oE07mpE/Ik1b8EckO0O4ZXiGAfshKBWLUM9Xg= +github.com/mitchellh/copystructure v1.2.0 h1:vpKXTN4ewci03Vljg/q9QvCGUDttBOGBIa15WveJJGw= +github.com/mitchellh/copystructure v1.2.0/go.mod h1:qLl+cE2AmVv+CoeAwDPye/v+N2HKCj9FbZEVFJRxO9s= +github.com/mitchellh/reflectwalk v1.0.2 h1:G2LzWKi524PWgd3mLHV8Y5k7s6XUvT0Gef6zxSIeXaQ= +github.com/mitchellh/reflectwalk v1.0.2/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw= github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0= github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo= github.com/moby/go-archive v0.2.0 h1:zg5QDUM2mi0JIM9fdQZWC7U8+2ZfixfTYoHL7rWUcP8= @@ -252,13 +272,13 @@ github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8 github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040= github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M= -github.com/openkcm/api-sdk v0.17.0 h1:AwDrJvaEj6f35JAohwzILU3nMTeD19J3ycebq0weOMw= -github.com/openkcm/api-sdk v0.17.0/go.mod h1:ffjao8Qr0k9FbtYWmPWWHf52tfBec2CRN4IBgj+ncqo= +github.com/openkcm/api-sdk v0.17.1-0.20260518093831-a872a7e182ca h1:XsZ2DB1EpXnT5XQaHT29acprdvH1kRFXsE/FHVisHx4= +github.com/openkcm/api-sdk v0.17.1-0.20260518093831-a872a7e182ca/go.mod h1:DeG8HQLN6QjzCpluI3B0xZCXqXEHv+0eSFg1+R5BQPo= github.com/openkcm/common-sdk v1.16.0 h1:pmLXRHvjqg+8ATEyzXarCRiRghw/8pXGn2OtoYuMEIU= github.com/openkcm/common-sdk v1.16.0/go.mod h1:4umveCyatAaTi6dSQgwaBg1O/wqHr4sjzuMIQhEuX1o= github.com/pborman/getopt v0.0.0-20170112200414-7148bc3a4c30/go.mod h1:85jBQOZwpVEaDAr341tbn15RS4fCAsIst0qp7i8ex1o= -github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= -github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= +github.com/pelletier/go-toml/v2 v2.3.1 h1:MYEvvGnQjeNkRF1qUuGolNtNExTDwct51yp7olPtrEc= +github.com/pelletier/go-toml/v2 v2.3.1/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= github.com/perimeterx/marshmallow v1.1.5 h1:a2LALqQ1BlHM8PZblsDdidgv1mWi1DgC2UmX50IvK2s= github.com/perimeterx/marshmallow v1.1.5/go.mod h1:dsXbUu8CRzfYP5a87xpp0xq9S3u0Vchtcl8we9tYaXw= github.com/pganalyze/pg_query_go/v6 v6.1.0 h1:jG5ZLhcVgL1FAw4C/0VNQaVmX1SUJx71wBGdtTtBvls= @@ -304,8 +324,8 @@ github.com/samber/lo v1.53.0 h1:t975lj2py4kJPQ6haz1QMgtId2gtmfktACxIXArw3HM= github.com/samber/lo v1.53.0/go.mod h1:4+MXEGsJzbKGaUEQFKBq2xtfuznW9oz/WrgyzMzRoM0= github.com/samber/oops v1.21.0 h1:18atcO4oEigNFuGXqr3NZWZ6P0XOSEXyBSAMXdQRxTc= github.com/samber/oops v1.21.0/go.mod h1:Hsm/sKPxtCfPh0w/cE3xVoRfSiE1joDRiStPAsmG9bo= -github.com/samber/slog-common v0.21.0 h1:Wo2hTly1Br5RjYqX/BTWJJeDnTE85oWk/7vqlpZuAUc= -github.com/samber/slog-common v0.21.0/go.mod h1:d/6OaSlzdkl9PFpfRLgn8FwY1OW6EFmPtBpsHX4MrU0= +github.com/samber/slog-common v0.22.0 h1:WyPxYRg/c5xUmxZJbtd0QgysHlLBhRA+MngKdJieHxE= +github.com/samber/slog-common v0.22.0/go.mod h1:d/6OaSlzdkl9PFpfRLgn8FwY1OW6EFmPtBpsHX4MrU0= github.com/samber/slog-formatter v1.3.0 h1:dpvLVSX883WSI222gUtEMLRd5ZOlKkmQlkrovdVZ9uA= github.com/samber/slog-formatter v1.3.0/go.mod h1:9y2j6qgrCpa7B5Kbv/sKp1ak7wJ91tsswp1BHOUSukc= github.com/samber/slog-multi v1.8.0 h1:E05c1wnQ+8M58oQDBABlJ4TEIJWssNgtckso3zlaLlI= @@ -314,8 +334,8 @@ github.com/sergi/go-diff v1.1.0 h1:we8PVUC3FE2uYfodKH/nBHMSetSfHDR6scGdBi+erh0= github.com/sergi/go-diff v1.1.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= github.com/sethvargo/go-retry v0.3.0 h1:EEt31A35QhrcRZtrYFDTBg91cqZVnFL2navjDrah2SE= github.com/sethvargo/go-retry v0.3.0/go.mod h1:mNX17F0C/HguQMyMyJxcnU471gOZGxCLyYaFyAZraas= -github.com/shirou/gopsutil/v4 v4.26.3 h1:2ESdQt90yU3oXF/CdOlRCJxrP+Am1aBYubTMTfxJ1qc= -github.com/shirou/gopsutil/v4 v4.26.3/go.mod h1:LZ6ewCSkBqUpvSOf+LsTGnRinC6iaNUNMGBtDkJBaLQ= +github.com/shirou/gopsutil/v4 v4.26.4 h1:B4SXVbcwTyrocPHEmWBC4uCYr4Xcu3MK1TXqbprAOWY= +github.com/shirou/gopsutil/v4 v4.26.4/go.mod h1:LZ6ewCSkBqUpvSOf+LsTGnRinC6iaNUNMGBtDkJBaLQ= github.com/sirupsen/logrus v1.9.4 h1:TsZE7l11zFCLZnZ+teH4Umoq5BhEIfIzfRDZ1Uzql2w= github.com/sirupsen/logrus v1.9.4/go.mod h1:ftWc9WdOfJ0a92nsE2jF5u5ZwH8Bv2zdeOC42RjbV2g= github.com/speakeasy-api/jsonpath v0.6.0 h1:IhtFOV9EbXplhyRqsVhHoBmmYjblIRh5D1/g8DHMXJ8= @@ -357,10 +377,10 @@ github.com/testcontainers/testcontainers-go/modules/valkey v0.42.0 h1:SL15Mh/Jmo github.com/testcontainers/testcontainers-go/modules/valkey v0.42.0/go.mod h1:rKMKPmE5065l6Jk/HWu4D27cDysitXO8MQoWfwKMPg8= github.com/tetratelabs/wazero v1.9.0 h1:IcZ56OuxrtaEz8UYNRHBrUa9bYeX9oVY93KspZZBf/I= github.com/tetratelabs/wazero v1.9.0/go.mod h1:TSbcXCfFP0L2FGkRPxHphadXPjo1T6W+CseNNY7EkjM= -github.com/tklauser/go-sysconf v0.3.16 h1:frioLaCQSsF5Cy1jgRBrzr6t502KIIwQ0MArYICU0nA= -github.com/tklauser/go-sysconf v0.3.16/go.mod h1:/qNL9xxDhc7tx3HSRsLWNnuzbVfh3e7gh/BmM179nYI= -github.com/tklauser/numcpus v0.11.0 h1:nSTwhKH5e1dMNsCdVBukSZrURJRoHbSEQjdEbY+9RXw= -github.com/tklauser/numcpus v0.11.0/go.mod h1:z+LwcLq54uWZTX0u/bGobaV34u6V7KNlTZejzM6/3MQ= +github.com/tklauser/go-sysconf v0.4.0 h1:7H0uAN+7RkwWRaxhYXDLqa5V3LPrJeV8wmD9dRUgPQU= +github.com/tklauser/go-sysconf v0.4.0/go.mod h1:8mTNWyog7H+MpKijp4VmKJAd2bbYQ2zuUwkYRbUArPI= +github.com/tklauser/numcpus v0.12.0 h1:NR85qdvHA9pFse3x3weVZ0r0ST8R6l5RHbZrlRaqob4= +github.com/tklauser/numcpus v0.12.0/go.mod h1:ABHeXzJnr/qqwguhClkZKT1/8VABcYrsyUiUGobwWJg= github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= github.com/valkey-io/valkey-go v1.0.75 h1:cfq9DODW2ntuUgyHJmFWb4/p+xpLpQB1t5SQyWM9uJ4= @@ -380,12 +400,12 @@ github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= -go.opentelemetry.io/collector/featuregate v1.57.0 h1:KPDSUKYn6MHwgyGRSGPPcW/G96HH93pxuvvPwM+R8nY= -go.opentelemetry.io/collector/featuregate v1.57.0/go.mod h1:4ga1QBMPEejXXmpyJS8lmaRpknJ3Lb9Bvk6e420bUFU= -go.opentelemetry.io/collector/internal/testutil v0.151.0 h1:CFjDItLuqzblItOsnK6IPSdrsOaZCaDjYpB8qWG+XHI= -go.opentelemetry.io/collector/internal/testutil v0.151.0/go.mod h1:Jkjs6rkqs973LqgZ0Fe3zrokQRKULYXPIf4HuqStiEE= -go.opentelemetry.io/collector/pdata v1.57.0 h1:oDWBMjEIqyJO3GJEB+iwqxj47rxDK19OKzwaFEaE4sg= -go.opentelemetry.io/collector/pdata v1.57.0/go.mod h1:wZojinP6mNhLXudH8QXx/bjWzOsKMxi/FXwnk+12G/w= +go.opentelemetry.io/collector/featuregate v1.58.0 h1:Kh6Dpgbxywv/Q3D6qPehaSxNCxvr/U/ki7CL4y3udCo= +go.opentelemetry.io/collector/featuregate v1.58.0/go.mod h1:4ga1QBMPEejXXmpyJS8lmaRpknJ3Lb9Bvk6e420bUFU= +go.opentelemetry.io/collector/internal/testutil v0.152.0 h1:8LGwekR7mLcUDhT1ofLmdnrHRFuUa3U7PBd95ZvJEjQ= +go.opentelemetry.io/collector/internal/testutil v0.152.0/go.mod h1:Jkjs6rkqs973LqgZ0Fe3zrokQRKULYXPIf4HuqStiEE= +go.opentelemetry.io/collector/pdata v1.58.0 h1:5Lxut3NxKp87066Pzt+3q7+JUuFI5B3teCyLZIF8wIs= +go.opentelemetry.io/collector/pdata v1.58.0/go.mod h1:4vZtODINbC/JF3eGocnatdImzbRHseOywIcr+aULjCg= go.opentelemetry.io/contrib/bridges/otelslog v0.18.0 h1:hhPGP3zvvy1xWT9RTy970wlniSxFttBIsAK1gvMguJM= go.opentelemetry.io/contrib/bridges/otelslog v0.18.0/go.mod h1:twJF7inoMza6kxMcF8JOdL3mPmtOZu7GEr34CUNE6Dg= go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.68.0 h1:0Qx7VGBacMm9ZENQ7TnNObTYI4ShC+lHI16seduaxZo= @@ -456,8 +476,8 @@ go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI= -golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q= +golang.org/x/crypto v0.51.0 h1:IBPXwPfKxY7cWQZ38ZCIRPI50YLeevDLlLnyC5wRGTI= +golang.org/x/crypto v0.51.0/go.mod h1:8AdwkbraGNABw2kOX6YFPs3WM22XqI4EXEd8g+x7Oc8= golang.org/x/exp v0.0.0-20260410095643-746e56fc9e2f h1:W3F4c+6OLc6H2lb//N1q4WpJkhzJCK5J6kUi1NTVXfM= golang.org/x/exp v0.0.0-20260410095643-746e56fc9e2f/go.mod h1:J1xhfL/vlindoeF/aINzNzt2Bket5bjo9sdOYzOsU80= golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= @@ -472,8 +492,8 @@ golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/ golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= -golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA= -golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs= +golang.org/x/net v0.54.0 h1:2zJIZAxAHV/OHCDTCOHAYehQzLfSXuf/5SoL/Dv6w/w= +golang.org/x/net v0.54.0/go.mod h1:Sj4oj8jK6XmHpBZU/zWHw3BV3abl4Kvi+Ut7cQcY+cQ= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -497,18 +517,18 @@ golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI= -golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/sys v0.44.0 h1:ildZl3J4uzeKP07r2F++Op7E9B29JRUy+a27EibtBTQ= +golang.org/x/sys v0.44.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.42.0 h1:UiKe+zDFmJobeJ5ggPwOshJIVt6/Ft0rcfrXZDLWAWY= -golang.org/x/term v0.42.0/go.mod h1:Dq/D+snpsbazcBG5+F9Q1n2rXV8Ma+71xEjTRufARgY= +golang.org/x/term v0.43.0 h1:S4RLU2sB31O/NCl+zFN9Aru9A/Cq2aqKpTZJ6B+DwT4= +golang.org/x/term v0.43.0/go.mod h1:lrhlHNdQJHO+1qVYiHfFKVuVioJIheAc3fBSMFYEIsk= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg= -golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164= +golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc= +golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38= golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U= golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -525,10 +545,10 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4= gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E= -google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9 h1:VPWxll4HlMw1Vs/qXtN7BvhZqsS9cdAittCNvVENElA= -google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:7QBABkRtR8z+TEnmXTqIqwJLlzrZKVfAUm7tY3yGv0M= -google.golang.org/genproto/googleapis/rpc v0.0.0-20260420184626-e10c466a9529 h1:XF8+t6QQiS0o9ArVan/HW8Q7cycNPGsJf6GA2nXxYAg= -google.golang.org/genproto/googleapis/rpc v0.0.0-20260420184626-e10c466a9529/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8= +google.golang.org/genproto/googleapis/api v0.0.0-20260511170946-3700d4141b60 h1:3WsB1FAbiRIf2tOxscWKs3pQBD9he1NsrnbhMuWfekc= +google.golang.org/genproto/googleapis/api v0.0.0-20260511170946-3700d4141b60/go.mod h1:7yoXV7RIh5gblj/xVYoogxAWvA9wUeVbpsK/M694l00= +google.golang.org/genproto/googleapis/rpc v0.0.0-20260511170946-3700d4141b60 h1:seT2EwLWM78plQ7wcDfuWBc/4FAEAXDDiaSol4ku4qo= +google.golang.org/genproto/googleapis/rpc v0.0.0-20260511170946-3700d4141b60/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8= google.golang.org/grpc v1.81.1 h1:VnnIIZ88UzOOKLukQi+ImGz8O1Wdp8nAGGnvOfEIWQQ= google.golang.org/grpc v1.81.1/go.mod h1:xGH9GfzOyMTGIOXBJmXt+BX/V0kcdQbdcuwQ/zNw42I= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= @@ -565,6 +585,8 @@ gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gotest.tools/gotestsum v1.13.0 h1:+Lh454O9mu9AMG1APV4o0y7oDYKyik/3kBOiCqiEpRo= +gotest.tools/gotestsum v1.13.0/go.mod h1:7f0NS5hFb0dWr4NtcsAsF0y1kzjEFfAil0HiBQJE03Q= gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q= gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA= modernc.org/cc/v4 v4.28.1 h1:XpLbkYVQ24E8tX5u8+yWGvaxerxkR/S4zqxI8ZoSBuc= diff --git a/integration/api-server-status_test.go b/integration/api-server-status_test.go index b6389395..c3398735 100644 --- a/integration/api-server-status_test.go +++ b/integration/api-server-status_test.go @@ -55,8 +55,8 @@ func TestStatusServer(t *testing.T) { } // defer the graceful stop of the service so that coverprofiles are written defer func() { - cmd.Process.Signal(os.Interrupt) - cmd.Wait() + _ = cmd.Process.Signal(os.Interrupt) + _ = cmd.Wait() }() // create the test cases @@ -85,7 +85,12 @@ func TestStatusServer(t *testing.T) { if i < 1 { t.Fatalf("could not connect to server: %s", err) } - if _, err := http.Get("http://localhost:8888"); err == nil { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8888", nil) + if err != nil { + t.Fatal(err) + } + + if _, err := http.DefaultClient.Do(req); err == nil { break } time.Sleep(100 * time.Millisecond) @@ -100,7 +105,12 @@ func TestStatusServer(t *testing.T) { t.Fatalf("could not construct a request url: %s", err) } - resp, err := http.Get(u) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil) + if err != nil { + t.Fatal(err) + } + + resp, err := http.DefaultClient.Do(req) if err != nil { t.Fatalf("could not send request: %s", err) } @@ -112,21 +122,14 @@ func TestStatusServer(t *testing.T) { // Assert if tc.wantError { - if err == nil { - t.Error("expected error, but got nil") - } if got != nil { t.Errorf("expected nil response, but got: %+v", got) } } else { - if err != nil { - t.Errorf("unexpected error: %s", err) - } else { - t.Logf("response: %s", got) - var js json.RawMessage - if json.Unmarshal([]byte(got), &js) != nil { - t.Errorf("response is not valid json: %s", got) - } + t.Logf("response: %s", got) + var js json.RawMessage + if json.Unmarshal(got, &js) != nil { + t.Errorf("response is not valid json: %s", got) } } }) diff --git a/integration/grpc_test.go b/integration/grpc_test.go index d80d275a..d482c1bd 100644 --- a/integration/grpc_test.go +++ b/integration/grpc_test.go @@ -4,227 +4,297 @@ package integration_test import ( "context" + "errors" "fmt" - "net" + "os" + "os/exec" + "path/filepath" "testing" "time" "github.com/gofrs/uuid/v5" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" - oidcmappingv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/sessionmanager/oidcmapping/v1" - sessionv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/sessionmanager/session/v1" - slogctx "github.com/veqryn/slog-context" - stdgrpc "google.golang.org/grpc" - - "github.com/openkcm/session-manager/internal/dbtest/postgrestest" - "github.com/openkcm/session-manager/internal/grpc" - "github.com/openkcm/session-manager/internal/trust" - "github.com/openkcm/session-manager/internal/trust/trustsql" + trustmappingv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/sessionmanager/trustmapping/v1" + healthpb "google.golang.org/grpc/health/grpc_health_v1" ) func TestGRPCServer(t *testing.T) { + const cmdName = "api-server" + const port = 9091 + // given ctx := t.Context() - port := 9091 - // create grpc server - srv, _, terminateFn, err := startServer(t, port) - require.NoError(t, err) - defer srv.Stop() - defer terminateFn(ctx) + istat := initInfra(t) + defer istat.Close(ctx) + + istat.Cfg.GRPC.Address = fmt.Sprintf(":%d", port) + + istat.PreparePostgres(t) + istat.PrepareValKey(t) + istat.PrepareConfig(t) + + currdir, err := os.Getwd() + require.NoError(t, err, "failed to get wd") + + t.Chdir(istat.Procdir) + + commandCtx, cancelCommand := context.WithCancel(ctx) + defer cancelCommand() + + cmd := exec.CommandContext(commandCtx, filepath.Join(currdir, "./session-manager"), cmdName) + cmd.WaitDelay = 5 * time.Second + cmd.Cancel = func() error { return cmd.Process.Signal(os.Interrupt) } + + cmdOutPath := filepath.Join(currdir, "grpc.log") + cmdOut, err := os.Create(cmdOutPath) + if err != nil { + t.Fatalf("failed to create an log file") + } + defer cmdOut.Close() + + cmd.Stdout = cmdOut + cmd.Stderr = cmdOut + + t.Logf("starting an app process. Logs will be saved into %s", cmdOutPath) + if err := cmd.Start(); err != nil { + t.Fatalf("failed to start the server: %s", err) + } + + errCh := make(chan error) + go func() { + if err := cmd.Wait(); err != nil && !errors.Is(err, context.Canceled) { + errCh <- fmt.Errorf("executing command: %w", err) + } + close(errCh) + }() // grpc client connection - conn, err := createClientConn(t, port) + cc, err := createClientConn(t, port) require.NoError(t, err) - defer conn.Close() + defer cc.Close() + + waitCtx, cancelWait := context.WithTimeout(commandCtx, 10*time.Second) + defer cancelWait() + if err := waitGRPCServerReady(waitCtx, cc); err != nil { + t.Fatalf("waiting for the server readiness: %s", err) + } - mappingClient := oidcmappingv1.NewServiceClient(conn) + trust := trustmappingv1.NewServiceClient(cc) - t.Run("ApplyOIDCMapping", func(t *testing.T) { + t.Run("ApplyTrustMapping", func(t *testing.T) { expJwks := "jks" expTenantID := uuid.Must(uuid.NewV4()).String() expIssuer := uuid.Must(uuid.NewV4()).String() - applyResp, err := mappingClient.ApplyOIDCMapping(ctx, &oidcmappingv1.ApplyOIDCMappingRequest{ - TenantId: expTenantID, - Issuer: expIssuer, - JwksUri: &expJwks, - Audiences: []string{"aud"}, - }) + applyResp, err := trust.ApplyTrustMapping(ctx, trustmappingv1.ApplyTrustMappingRequest_builder{ + TenantId: &expTenantID, + Oidc: trustmappingv1.ApplyTrustMappingRequest_ApplyOIDCTrust_builder{ + Issuer: &expIssuer, + JwksUri: &expJwks, + Audiences: []string{"aud"}, + }.Build(), + }.Build()) assert.NoError(t, err) assert.True(t, applyResp.GetSuccess()) }) - t.Run("BlockOIDCMapping", func(t *testing.T) { + t.Run("BlockTrustMapping", func(t *testing.T) { expJwks := "jks" expTenantID := uuid.Must(uuid.NewV4()).String() expIssuer := uuid.Must(uuid.NewV4()).String() - applyResp, err := mappingClient.ApplyOIDCMapping(ctx, &oidcmappingv1.ApplyOIDCMappingRequest{ - TenantId: expTenantID, - Issuer: expIssuer, - JwksUri: &expJwks, - Audiences: []string{"aud"}, - }) + applyResp, err := trust.ApplyTrustMapping(ctx, trustmappingv1.ApplyTrustMappingRequest_builder{ + TenantId: &expTenantID, + Oidc: trustmappingv1.ApplyTrustMappingRequest_ApplyOIDCTrust_builder{ + Issuer: &expIssuer, + JwksUri: &expJwks, + Audiences: []string{"aud"}, + }.Build(), + }.Build()) assert.NoError(t, err) assert.True(t, applyResp.GetSuccess()) - blockResp, err := mappingClient.BlockOIDCMapping(ctx, &oidcmappingv1.BlockOIDCMappingRequest{ - TenantId: expTenantID, - }) + blockResp, err := trust.BlockTrustMapping(ctx, trustmappingv1.BlockTrustMappingRequest_builder{ + TenantId: &expTenantID, + }.Build()) assert.NoError(t, err) assert.True(t, blockResp.GetSuccess()) }) - t.Run("UnblockOIDCMapping", func(t *testing.T) { + t.Run("UnblockTrustMapping", func(t *testing.T) { expJwks := "jks" expTenantID := uuid.Must(uuid.NewV4()).String() expIssuer1 := uuid.Must(uuid.NewV4()).String() - applyRes, err := mappingClient.ApplyOIDCMapping(ctx, &oidcmappingv1.ApplyOIDCMappingRequest{ - TenantId: expTenantID, - Issuer: expIssuer1, - JwksUri: &expJwks, - Audiences: []string{"audience"}, - }) + applyRes, err := trust.ApplyTrustMapping(ctx, trustmappingv1.ApplyTrustMappingRequest_builder{ + TenantId: &expTenantID, + Oidc: trustmappingv1.ApplyTrustMappingRequest_ApplyOIDCTrust_builder{ + Issuer: &expIssuer1, + JwksUri: &expJwks, + Audiences: []string{"audience"}, + }.Build(), + }.Build()) assert.NoError(t, err) assert.True(t, applyRes.GetSuccess()) - blockRes, err := mappingClient.BlockOIDCMapping(ctx, &oidcmappingv1.BlockOIDCMappingRequest{ - TenantId: expTenantID, - }) + blockRes, err := trust.BlockTrustMapping(ctx, trustmappingv1.BlockTrustMappingRequest_builder{ + TenantId: &expTenantID, + }.Build()) assert.NoError(t, err) assert.True(t, blockRes.GetSuccess()) - unblockRes, err := mappingClient.UnblockOIDCMapping(ctx, &oidcmappingv1.UnblockOIDCMappingRequest{ - TenantId: expTenantID, - }) + unblockRes, err := trust.UnblockTrustMapping(ctx, trustmappingv1.UnblockTrustMappingRequest_builder{ + TenantId: &expTenantID, + }.Build()) assert.NoError(t, err) assert.True(t, unblockRes.GetSuccess()) }) - t.Run("RemoveOIDCMapping", func(t *testing.T) { + t.Run("RemoveTrustMapping", func(t *testing.T) { expJwks := "jks" expTenantID := uuid.Must(uuid.NewV4()).String() expIssuer := uuid.Must(uuid.NewV4()).String() - applyRes, err := mappingClient.ApplyOIDCMapping(ctx, &oidcmappingv1.ApplyOIDCMappingRequest{ - TenantId: expTenantID, - Issuer: expIssuer, - JwksUri: &expJwks, - Audiences: []string{"audience"}, - }) + applyRes, err := trust.ApplyTrustMapping(ctx, trustmappingv1.ApplyTrustMappingRequest_builder{ + TenantId: &expTenantID, + Oidc: trustmappingv1.ApplyTrustMappingRequest_ApplyOIDCTrust_builder{ + Issuer: &expIssuer, + JwksUri: &expJwks, + Audiences: []string{"audience"}, + }.Build(), + }.Build()) assert.NoError(t, err) assert.True(t, applyRes.GetSuccess()) - removeRes, err := mappingClient.RemoveOIDCMapping(ctx, &oidcmappingv1.RemoveOIDCMappingRequest{ - TenantId: expTenantID, - }) + removeRes, err := trust.RemoveTrustMapping(ctx, trustmappingv1.RemoveTrustMappingRequest_builder{ + TenantId: &expTenantID, + }.Build()) assert.NoError(t, err) assert.True(t, removeRes.GetSuccess()) }) - t.Run("ApplyOIDCMapping with multiple audiences", func(t *testing.T) { + t.Run("ApplyTrustMapping with multiple audiences", func(t *testing.T) { expJwks := "jks-multi" expTenantID := uuid.Must(uuid.NewV4()).String() expIssuer := uuid.Must(uuid.NewV4()).String() audiences := []string{"aud1", "aud2", "aud3"} - applyResp, err := mappingClient.ApplyOIDCMapping(ctx, &oidcmappingv1.ApplyOIDCMappingRequest{ - TenantId: expTenantID, - Issuer: expIssuer, - JwksUri: &expJwks, - Audiences: audiences, - }) + applyResp, err := trust.ApplyTrustMapping(ctx, trustmappingv1.ApplyTrustMappingRequest_builder{ + TenantId: &expTenantID, + Oidc: trustmappingv1.ApplyTrustMappingRequest_ApplyOIDCTrust_builder{ + Issuer: &expIssuer, + JwksUri: &expJwks, + Audiences: audiences, + }.Build(), + }.Build()) assert.NoError(t, err) assert.True(t, applyResp.GetSuccess()) }) - t.Run("ApplyOIDCMapping idempotent - applying same mapping twice", func(t *testing.T) { + t.Run("ApplyTrustMapping idempotent - applying same mapping twice", func(t *testing.T) { expJwks := "jks-idempotent" expTenantID := uuid.Must(uuid.NewV4()).String() expIssuer := uuid.Must(uuid.NewV4()).String() // First application - applyResp1, err := mappingClient.ApplyOIDCMapping(ctx, &oidcmappingv1.ApplyOIDCMappingRequest{ - TenantId: expTenantID, - Issuer: expIssuer, - JwksUri: &expJwks, - Audiences: []string{"aud"}, - }) + applyResp1, err := trust.ApplyTrustMapping(ctx, trustmappingv1.ApplyTrustMappingRequest_builder{ + TenantId: &expTenantID, + Oidc: trustmappingv1.ApplyTrustMappingRequest_ApplyOIDCTrust_builder{ + Issuer: &expIssuer, + JwksUri: &expJwks, + Audiences: []string{"aud"}, + }.Build(), + }.Build()) + assert.NoError(t, err) assert.True(t, applyResp1.GetSuccess()) // Second application (should be idempotent) - applyResp2, err := mappingClient.ApplyOIDCMapping(ctx, &oidcmappingv1.ApplyOIDCMappingRequest{ - TenantId: expTenantID, - Issuer: expIssuer, - JwksUri: &expJwks, - Audiences: []string{"aud"}, - }) + applyResp2, err := trust.ApplyTrustMapping(ctx, trustmappingv1.ApplyTrustMappingRequest_builder{ + TenantId: &expTenantID, + Oidc: trustmappingv1.ApplyTrustMappingRequest_ApplyOIDCTrust_builder{ + Issuer: &expIssuer, + JwksUri: &expJwks, + Audiences: []string{"aud"}, + }.Build(), + }.Build()) + assert.NoError(t, err) assert.True(t, applyResp2.GetSuccess()) }) - t.Run("BlockOIDCMapping idempotent - blocking twice", func(t *testing.T) { + t.Run("BlockTrustMapping idempotent - blocking twice", func(t *testing.T) { expJwks := "jks-block-twice" expTenantID := uuid.Must(uuid.NewV4()).String() expIssuer := uuid.Must(uuid.NewV4()).String() - applyResp, err := mappingClient.ApplyOIDCMapping(ctx, &oidcmappingv1.ApplyOIDCMappingRequest{ - TenantId: expTenantID, - Issuer: expIssuer, - JwksUri: &expJwks, - Audiences: []string{"aud"}, - }) + applyResp, err := trust.ApplyTrustMapping(ctx, trustmappingv1.ApplyTrustMappingRequest_builder{ + TenantId: &expTenantID, + Oidc: trustmappingv1.ApplyTrustMappingRequest_ApplyOIDCTrust_builder{ + Issuer: &expIssuer, + JwksUri: &expJwks, + Audiences: []string{"aud"}, + }.Build(), + }.Build()) + assert.NoError(t, err) assert.True(t, applyResp.GetSuccess()) // First block - blockResp1, err := mappingClient.BlockOIDCMapping(ctx, &oidcmappingv1.BlockOIDCMappingRequest{ - TenantId: expTenantID, - }) + blockResp1, err := trust.BlockTrustMapping(ctx, trustmappingv1.BlockTrustMappingRequest_builder{ + TenantId: &expTenantID, + }.Build()) + assert.NoError(t, err) assert.True(t, blockResp1.GetSuccess()) // Second block (should be idempotent) - blockResp2, err := mappingClient.BlockOIDCMapping(ctx, &oidcmappingv1.BlockOIDCMappingRequest{ - TenantId: expTenantID, - }) + blockResp2, err := trust.BlockTrustMapping(ctx, trustmappingv1.BlockTrustMappingRequest_builder{ + TenantId: &expTenantID, + }.Build()) + assert.NoError(t, err) assert.True(t, blockResp2.GetSuccess()) }) - t.Run("UnblockOIDCMapping idempotent - unblocking twice", func(t *testing.T) { + t.Run("UnblockTrustMapping idempotent - unblocking twice", func(t *testing.T) { expJwks := "jks-unblock-twice" expTenantID := uuid.Must(uuid.NewV4()).String() expIssuer := uuid.Must(uuid.NewV4()).String() - applyRes, err := mappingClient.ApplyOIDCMapping(ctx, &oidcmappingv1.ApplyOIDCMappingRequest{ - TenantId: expTenantID, - Issuer: expIssuer, - JwksUri: &expJwks, - Audiences: []string{"audience"}, - }) + applyRes, err := trust.ApplyTrustMapping(ctx, trustmappingv1.ApplyTrustMappingRequest_builder{ + TenantId: &expTenantID, + Oidc: trustmappingv1.ApplyTrustMappingRequest_ApplyOIDCTrust_builder{ + Issuer: &expIssuer, + JwksUri: &expJwks, + Audiences: []string{"audience"}, + }.Build(), + }.Build()) + assert.NoError(t, err) assert.True(t, applyRes.GetSuccess()) - blockRes, err := mappingClient.BlockOIDCMapping(ctx, &oidcmappingv1.BlockOIDCMappingRequest{ - TenantId: expTenantID, - }) + blockRes, err := trust.BlockTrustMapping(ctx, trustmappingv1.BlockTrustMappingRequest_builder{ + TenantId: &expTenantID, + }.Build()) + assert.NoError(t, err) assert.True(t, blockRes.GetSuccess()) // First unblock - unblockRes1, err := mappingClient.UnblockOIDCMapping(ctx, &oidcmappingv1.UnblockOIDCMappingRequest{ - TenantId: expTenantID, - }) + unblockRes1, err := trust.UnblockTrustMapping(ctx, trustmappingv1.UnblockTrustMappingRequest_builder{ + TenantId: &expTenantID, + }.Build()) + assert.NoError(t, err) assert.True(t, unblockRes1.GetSuccess()) // Second unblock (should be idempotent) - unblockRes2, err := mappingClient.UnblockOIDCMapping(ctx, &oidcmappingv1.UnblockOIDCMappingRequest{ - TenantId: expTenantID, - }) + unblockRes2, err := trust.UnblockTrustMapping(ctx, trustmappingv1.UnblockTrustMappingRequest_builder{ + TenantId: &expTenantID, + }.Build()) + assert.NoError(t, err) assert.True(t, unblockRes2.GetSuccess()) }) @@ -235,107 +305,118 @@ func TestGRPCServer(t *testing.T) { expIssuer := uuid.Must(uuid.NewV4()).String() // Apply mapping - applyRes, err := mappingClient.ApplyOIDCMapping(ctx, &oidcmappingv1.ApplyOIDCMappingRequest{ - TenantId: expTenantID, - Issuer: expIssuer, - JwksUri: &expJwks, - Audiences: []string{"audience"}, - }) + applyRes, err := trust.ApplyTrustMapping(ctx, trustmappingv1.ApplyTrustMappingRequest_builder{ + TenantId: &expTenantID, + Oidc: trustmappingv1.ApplyTrustMappingRequest_ApplyOIDCTrust_builder{ + Issuer: &expIssuer, + JwksUri: &expJwks, + Audiences: []string{"audience"}, + }.Build(), + }.Build()) + assert.NoError(t, err) assert.True(t, applyRes.GetSuccess()) // Block it - blockRes, err := mappingClient.BlockOIDCMapping(ctx, &oidcmappingv1.BlockOIDCMappingRequest{ - TenantId: expTenantID, - }) + blockRes, err := trust.BlockTrustMapping(ctx, trustmappingv1.BlockTrustMappingRequest_builder{ + TenantId: &expTenantID, + }.Build()) + assert.NoError(t, err) assert.True(t, blockRes.GetSuccess()) // Unblock it - unblockRes, err := mappingClient.UnblockOIDCMapping(ctx, &oidcmappingv1.UnblockOIDCMappingRequest{ - TenantId: expTenantID, - }) + unblockRes, err := trust.UnblockTrustMapping(ctx, trustmappingv1.UnblockTrustMappingRequest_builder{ + TenantId: &expTenantID, + }.Build()) + assert.NoError(t, err) assert.True(t, unblockRes.GetSuccess()) // Block again - blockRes2, err := mappingClient.BlockOIDCMapping(ctx, &oidcmappingv1.BlockOIDCMappingRequest{ - TenantId: expTenantID, - }) + blockRes2, err := trust.BlockTrustMapping(ctx, trustmappingv1.BlockTrustMappingRequest_builder{ + TenantId: &expTenantID, + }.Build()) + assert.NoError(t, err) assert.True(t, blockRes2.GetSuccess()) // Unblock again - unblockRes2, err := mappingClient.UnblockOIDCMapping(ctx, &oidcmappingv1.UnblockOIDCMappingRequest{ - TenantId: expTenantID, - }) + unblockRes2, err := trust.UnblockTrustMapping(ctx, trustmappingv1.UnblockTrustMappingRequest_builder{ + TenantId: &expTenantID, + }.Build()) + assert.NoError(t, err) assert.True(t, unblockRes2.GetSuccess()) }) - t.Run("RemoveOIDCMapping idempotent - removing twice", func(t *testing.T) { + t.Run("RemoveTrustMapping idempotent - removing twice", func(t *testing.T) { expJwks := "jks-remove-twice" expTenantID := uuid.Must(uuid.NewV4()).String() expIssuer := uuid.Must(uuid.NewV4()).String() - applyRes, err := mappingClient.ApplyOIDCMapping(ctx, &oidcmappingv1.ApplyOIDCMappingRequest{ - TenantId: expTenantID, - Issuer: expIssuer, - JwksUri: &expJwks, - Audiences: []string{"audience"}, - }) + applyRes, err := trust.ApplyTrustMapping(ctx, trustmappingv1.ApplyTrustMappingRequest_builder{ + TenantId: &expTenantID, + Oidc: trustmappingv1.ApplyTrustMappingRequest_ApplyOIDCTrust_builder{ + Issuer: &expIssuer, + JwksUri: &expJwks, + Audiences: []string{"audience"}, + }.Build(), + }.Build()) + assert.NoError(t, err) assert.True(t, applyRes.GetSuccess()) // First remove - removeRes1, err := mappingClient.RemoveOIDCMapping(ctx, &oidcmappingv1.RemoveOIDCMappingRequest{ - TenantId: expTenantID, - }) + removeRes1, err := trust.RemoveTrustMapping(ctx, trustmappingv1.RemoveTrustMappingRequest_builder{ + TenantId: &expTenantID, + }.Build()) + assert.NoError(t, err) assert.True(t, removeRes1.GetSuccess()) // Second remove - idempotence should not cause an error - removeRes2, err := mappingClient.RemoveOIDCMapping(ctx, &oidcmappingv1.RemoveOIDCMappingRequest{ - TenantId: expTenantID, - }) + removeRes2, err := trust.RemoveTrustMapping(ctx, trustmappingv1.RemoveTrustMappingRequest_builder{ + TenantId: &expTenantID, + }.Build()) + assert.NoError(t, err) assert.True(t, removeRes2.GetSuccess()) }) + + cancelCommand() + select { + case err := <-errCh: + if err != nil { + t.Fatalf("error executing command: %s", err) + } + case <-time.After(10 * time.Second): + t.Fatalf("timeout exceeded") + } } -func createClientConn(t *testing.T, port int) (*stdgrpc.ClientConn, error) { +func createClientConn(t *testing.T, port int) (*grpc.ClientConn, error) { t.Helper() - conn, err := stdgrpc.NewClient(fmt.Sprintf("localhost:%d", port), - stdgrpc.WithTransportCredentials(insecure.NewCredentials()), + conn, err := grpc.NewClient(fmt.Sprintf("localhost:%d", port), + grpc.WithTransportCredentials(insecure.NewCredentials()), ) return conn, err } -func startServer(t *testing.T, port int) (*stdgrpc.Server, *trust.Service, func(context.Context), error) { - t.Helper() - ctx := t.Context() - // start postgres - db, _, terminateFn := postgrestest.Start(ctx) - trustRepo := trustsql.NewRepository(db) - service := trust.NewService(trustRepo) - - lstConf := net.ListenConfig{} - lis, err := lstConf.Listen(ctx, "tcp", fmt.Sprintf("localhost:%d", port)) - if err != nil { - return nil, nil, nil, err - } +func waitGRPCServerReady(ctx context.Context, cc *grpc.ClientConn) error { + healthClient := healthpb.NewHealthClient(cc) - srv := stdgrpc.NewServer() - oidcmappingv1.RegisterServiceServer(srv, grpc.NewOIDCMappingServer(service)) - sessionv1.RegisterServiceServer(srv, grpc.NewSessionServer(ctx, nil, trustRepo, time.Hour, "")) - - // start - go func() { - err = srv.Serve(lis) + const maxAttempts = 100 + for range maxAttempts { + out, err := healthClient.Check(ctx, new(healthpb.HealthCheckRequest), grpc.WaitForReady(true)) if err != nil { - slogctx.Error(ctx, "error while starting server", "error", err) + return fmt.Errorf("checking health status: %w", err) } - }() - return srv, service, terminateFn, nil + if out.GetStatus() == healthpb.HealthCheckResponse_SERVING { + return nil + } + } + + return errors.New("exceeded max attempts number") } diff --git a/integration/infra_test.go b/integration/infra_test.go index fbb6806a..858e9d5c 100644 --- a/integration/infra_test.go +++ b/integration/infra_test.go @@ -9,25 +9,43 @@ import ( "os" "path/filepath" "testing" + "time" "github.com/goccy/go-yaml" "github.com/moby/moby/api/types/network" "github.com/openkcm/common-sdk/pkg/commoncfg" "github.com/stretchr/testify/require" + "github.com/valkey-io/valkey-go" "github.com/openkcm/session-manager/internal/config" "github.com/openkcm/session-manager/internal/dbtest/postgrestest" "github.com/openkcm/session-manager/internal/dbtest/valkeytest" + "github.com/openkcm/session-manager/modules/database/pgxpool" ) type closeFunc func(ctx context.Context) +type Config struct { + config.Config `yaml:",inline"` + + Database pgxpool.PostgresModule `yaml:"database"` +} + +func loadExtendedConfig(t *testing.T, dir string) *Config { + t.Helper() + + cfg, err := config.Load("", dir) + require.NoError(t, err, "failed to load config") + + return &Config{Config: *cfg} +} + type infraStat struct { PostgresPort network.Port ValKeyPort network.Port ConfigFilePath string Procdir string - Cfg config.Config + Cfg *Config closeFuncs []closeFunc } @@ -38,21 +56,22 @@ func initInfra(t *testing.T) (istat infraStat) { // Since the config is read from the file $PWD/config.yaml, // we're running a process in a temporary subdirectory // so that we aren't interfering with the other tests. - procDir, err := os.MkdirTemp("", "*") - require.NoError(t, err, "failed to create a temp dir") + procDir := t.TempDir() istat.Procdir = procDir istat.ConfigFilePath = filepath.Join(procDir, "config.yaml") - err = os.WriteFile(istat.ConfigFilePath, []byte(validConfig), fs.ModePerm) + err := os.WriteFile(istat.ConfigFilePath, []byte(validConfig), fs.ModePerm) require.NoError(t, err, "failed to write config file") - err = commoncfg.LoadConfig(&istat.Cfg, nil, istat.Procdir) - require.NoError(t, err, "failed to load config") + istat.Cfg = loadExtendedConfig(t, istat.Procdir) // Let OS choose a free port istat.Cfg.HTTP.Address = "unix://" + filepath.Join(procDir, "unix.sock") istat.Cfg.GRPC.Address = ":0" istat.Cfg.Logger.Format = commoncfg.TextLoggerFormat + istat.Cfg.Logger.Level = "debug" + istat.Cfg.Logger.Formatter.Time.Type = commoncfg.PatternTimeLogger + istat.Cfg.Logger.Formatter.Time.Pattern = time.Stamp // There's a hard limit of 108 symbols on a unix socket filepath on Linux/macOS. if len(istat.Cfg.HTTP.Address) > 108 { @@ -69,35 +88,33 @@ func (istat *infraStat) PreparePostgres(t *testing.T) { const dbpass = "secret" const dbname = "session_manager" - wd, err := os.Getwd() - require.NoError(t, err, "getting wd") - pgClient, pgPort, pgTerminate := postgrestest.Start(t.Context()) pgClient.Close() istat.PostgresPort = pgPort istat.closeFuncs = append(istat.closeFuncs, pgTerminate) + istat.Cfg.Database.Mod = "database.module.pgxpool" istat.Cfg.Database.Name = dbname istat.Cfg.Database.User = commoncfg.SourceRef{Source: "embedded", Value: dbuser} istat.Cfg.Database.Password = commoncfg.SourceRef{Source: "embedded", Value: dbpass} istat.Cfg.Database.Host = commoncfg.SourceRef{Source: "embedded", Value: "localhost"} istat.Cfg.Database.Port = pgPort.Port() - istat.Cfg.Migrate.Source = "file://" + filepath.Join(wd, "../sql") } -func (istat *infraStat) PrepareValKey(t *testing.T) { +func (istat *infraStat) PrepareValKey(t *testing.T) valkey.Client { t.Helper() vkClient, vkPort, vkTerminate := valkeytest.Start(t.Context()) - vkClient.Close() istat.ValKeyPort = vkPort - istat.closeFuncs = append(istat.closeFuncs, vkTerminate) + istat.closeFuncs = append(istat.closeFuncs, vkTerminate, func(_ context.Context) { vkClient.Close() }) istat.Cfg.ValKey.Host = commoncfg.SourceRef{Source: "embedded", Value: net.JoinHostPort("localhost", vkPort.Port())} istat.Cfg.ValKey.User = commoncfg.SourceRef{Source: "embedded", Value: ""} istat.Cfg.ValKey.Password = commoncfg.SourceRef{Source: "embedded", Value: ""} + + return vkClient } // PrepareConfig writes a config file for running the test into the ConfigFilePath. diff --git a/integration/migrate_test.go b/integration/migrate_test.go index b21490dd..4761470f 100644 --- a/integration/migrate_test.go +++ b/integration/migrate_test.go @@ -8,13 +8,12 @@ import ( "os/exec" "path/filepath" "testing" + "time" "github.com/go-viper/mapstructure/v2" "github.com/goccy/go-yaml" "github.com/openkcm/common-sdk/pkg/commoncfg" "github.com/testcontainers/testcontainers-go/modules/postgres" - - "github.com/openkcm/session-manager/internal/config" ) func TestMigrate(t *testing.T) { @@ -39,6 +38,7 @@ func TestMigrate(t *testing.T) { if err != nil { t.Fatalf("failed to start PostgreSQL: %s", err) } + defer func() { _ = pgContainer.Terminate(ctx) }() port, err := pgContainer.MappedPort(ctx, "5432") if err != nil { @@ -46,40 +46,51 @@ func TestMigrate(t *testing.T) { } // Prepare config - _ = os.MkdirAll(testdir, fs.ModePerm) - defer os.RemoveAll(testdir) - - err = os.WriteFile(configFilePath, []byte(validConfig), fs.ModePerm) + currdir, err := os.Getwd() if err != nil { - t.Fatalf("failed to write config file: %s", err) + t.Fatalf("failed to get wd: %s", err) } - defer os.Remove(configFilePath) - var cfg config.Config - err = commoncfg.LoadConfig(&cfg, nil, testdir) - if err != nil { - t.Fatalf("failed to load config: %s", err) - } + abstestdir := filepath.Join(currdir, testdir) + _ = os.MkdirAll(abstestdir, fs.ModePerm) + defer os.RemoveAll(abstestdir) - currdir, err := os.Getwd() + absConfigFilePath := filepath.Join(currdir, configFilePath) + err = os.WriteFile(absConfigFilePath, []byte(validConfig), fs.ModePerm) if err != nil { - t.Fatalf("failed to get wd: %s", err) + t.Fatalf("failed to write config file: %s", err) } + cfg := loadExtendedConfig(t, abstestdir) + cfg.Logger.Level = "debug" + cfg.Logger.Format = commoncfg.TextLoggerFormat + cfg.Logger.Formatter.Time.Type = commoncfg.PatternTimeLogger + cfg.Logger.Formatter.Time.Pattern = time.Stamp + cfg.Database.Name = dbname cfg.Database.User = commoncfg.SourceRef{Source: "embedded", Value: dbuser} cfg.Database.Password = commoncfg.SourceRef{Source: "embedded", Value: dbpass} cfg.Database.Host = commoncfg.SourceRef{Source: "embedded", Value: "localhost"} cfg.Database.Port = port.Port() - cfg.Migrate.Source = "file://" + filepath.Join(currdir, "../sql") cfgMap := make(map[any]any) - err = mapstructure.Decode(cfg, &cfgMap) - if err != nil { + decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ + Result: &cfgMap, + DecodeHook: mapstructure.ComposeDecodeHookFunc( + mapstructure.StringToTimeDurationHookFunc(), + mapstructure.TextUnmarshallerHookFunc()), + WeaklyTypedInput: true, + TagName: "yaml", + SquashTagOption: "inline", + }) + if err != nil { + t.Fatalf("failed to create mapstructure decoder: %s", err) + } + if err := decoder.Decode(cfg); err != nil { t.Fatalf("failed to decode mapstructure: %s", err) } - f, err := os.Create(configFilePath) + f, err := os.Create(absConfigFilePath) if err != nil { t.Fatalf("failed to create config file: %s", err) } @@ -90,9 +101,7 @@ func TestMigrate(t *testing.T) { t.Fatalf("failed to write config: %s", err) } - wd, _ := os.Getwd() - t.Chdir(testdir) - defer os.Chdir(wd) + t.Chdir(abstestdir) // Run the migrations cmd := exec.CommandContext(ctx, filepath.Join(currdir, "./session-manager"), cmdName) @@ -145,40 +154,51 @@ func TestMigrateIdempotent(t *testing.T) { } // Prepare config - _ = os.MkdirAll(testdir, fs.ModePerm) - defer os.RemoveAll(testdir) - - err = os.WriteFile(configFilePath, []byte(validConfig), fs.ModePerm) + currdir, err := os.Getwd() if err != nil { - t.Fatalf("failed to write config file: %s", err) + t.Fatalf("failed to get wd: %s", err) } - defer os.Remove(configFilePath) - var cfg config.Config - err = commoncfg.LoadConfig(&cfg, nil, testdir) - if err != nil { - t.Fatalf("failed to load config: %s", err) - } + abstestdir := filepath.Join(currdir, testdir) + _ = os.MkdirAll(abstestdir, fs.ModePerm) + defer os.RemoveAll(abstestdir) - currdir, err := os.Getwd() + absConfigFilePath := filepath.Join(currdir, configFilePath) + err = os.WriteFile(absConfigFilePath, []byte(validConfig), fs.ModePerm) if err != nil { - t.Fatalf("failed to get wd: %s", err) + t.Fatalf("failed to write config file: %s", err) } + cfg := loadExtendedConfig(t, abstestdir) + cfg.Logger.Level = "debug" + cfg.Logger.Format = commoncfg.TextLoggerFormat + cfg.Logger.Formatter.Time.Type = commoncfg.PatternTimeLogger + cfg.Logger.Formatter.Time.Pattern = time.Stamp + cfg.Database.Name = dbname cfg.Database.User = commoncfg.SourceRef{Source: "embedded", Value: dbuser} cfg.Database.Password = commoncfg.SourceRef{Source: "embedded", Value: dbpass} cfg.Database.Host = commoncfg.SourceRef{Source: "embedded", Value: "localhost"} cfg.Database.Port = port.Port() - cfg.Migrate.Source = "file://" + filepath.Join(currdir, "../sql") cfgMap := make(map[any]any) - err = mapstructure.Decode(cfg, &cfgMap) - if err != nil { + decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ + Result: &cfgMap, + DecodeHook: mapstructure.ComposeDecodeHookFunc( + mapstructure.StringToTimeDurationHookFunc(), + mapstructure.TextUnmarshallerHookFunc()), + WeaklyTypedInput: true, + TagName: "yaml", + SquashTagOption: "inline", + }) + if err != nil { + t.Fatalf("failed to create mapstructure decoder: %s", err) + } + if err := decoder.Decode(cfg); err != nil { t.Fatalf("failed to decode mapstructure: %s", err) } - f, err := os.Create(configFilePath) + f, err := os.Create(absConfigFilePath) if err != nil { t.Fatalf("failed to create config file: %s", err) } @@ -189,9 +209,7 @@ func TestMigrateIdempotent(t *testing.T) { t.Fatalf("failed to write config: %s", err) } - wd, _ := os.Getwd() - t.Chdir(testdir) - defer os.Chdir(wd) + t.Chdir(abstestdir) // Run migrations the first time cmd := exec.CommandContext(ctx, filepath.Join(currdir, "./session-manager"), cmdName) diff --git a/integration/session_grpc_test.go b/integration/session_grpc_test.go index 7a235e4e..d5658adf 100644 --- a/integration/session_grpc_test.go +++ b/integration/session_grpc_test.go @@ -4,8 +4,11 @@ package integration_test import ( "context" + "errors" "fmt" - "net" + "os" + "os/exec" + "path/filepath" "testing" "time" @@ -14,33 +17,76 @@ import ( "github.com/stretchr/testify/require" sessionv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/sessionmanager/session/v1" - slogctx "github.com/veqryn/slog-context" - stdgrpc "google.golang.org/grpc" - "github.com/openkcm/session-manager/internal/dbtest/postgrestest" - "github.com/openkcm/session-manager/internal/dbtest/valkeytest" - "github.com/openkcm/session-manager/internal/grpc" "github.com/openkcm/session-manager/internal/session" sessionvalkey "github.com/openkcm/session-manager/internal/session/valkey" - "github.com/openkcm/session-manager/internal/trust/trustsql" ) func TestSessionGRPC(t *testing.T) { + const cmdName = "api-server" + const port = 9092 + // given ctx := t.Context() - port := 9092 - // create grpc server with session support - srv, sessionRepo, terminateFn, err := startSessionServer(t, port) - require.NoError(t, err) - defer srv.Stop() - defer terminateFn(ctx) + istat := initInfra(t) + defer istat.Close(ctx) + + istat.Cfg.GRPC.Address = fmt.Sprintf(":%d", port) + + istat.PreparePostgres(t) + valkeyClient := istat.PrepareValKey(t) + istat.PrepareConfig(t) + + sessionRepo := sessionvalkey.NewRepository(valkeyClient, "session") + + currdir, err := os.Getwd() + require.NoError(t, err, "failed to get wd") + + t.Chdir(istat.Procdir) + + commandCtx, cancelCommand := context.WithCancel(ctx) + defer cancelCommand() + + cmd := exec.CommandContext(commandCtx, filepath.Join(currdir, "./session-manager"), cmdName) + cmd.WaitDelay = 5 * time.Second + cmd.Cancel = func() error { return cmd.Process.Signal(os.Interrupt) } + + cmdOutPath := filepath.Join(currdir, "session-grpc.log") + cmdOut, err := os.Create(cmdOutPath) + if err != nil { + t.Fatalf("failed to create an log file") + } + defer cmdOut.Close() + + cmd.Stdout = cmdOut + cmd.Stderr = cmdOut + + t.Logf("starting an app process. Logs will be saved into %s", cmdOutPath) + if err := cmd.Start(); err != nil { + t.Fatalf("failed to start the server: %s", err) + } + + errCh := make(chan error) + go func() { + if err := cmd.Wait(); err != nil && !errors.Is(err, context.Canceled) { + errCh <- fmt.Errorf("executing command: %w", err) + } + close(errCh) + }() // grpc client connection - conn, err := createClientConn(t, port) + cc, err := createClientConn(t, port) require.NoError(t, err) - defer conn.Close() - sessionClient := sessionv1.NewServiceClient(conn) + defer cc.Close() + + waitCtx, cancelWait := context.WithTimeout(commandCtx, 10*time.Second) + defer cancelWait() + if err := waitGRPCServerReady(waitCtx, cc); err != nil { + t.Fatalf("waiting for the server readiness: %s", err) + } + + sessionClient := sessionv1.NewServiceClient(cc) t.Run("GetSession - session not found", func(t *testing.T) { resp, err := sessionClient.GetSession(ctx, &sessionv1.GetSessionRequest{ @@ -171,43 +217,3 @@ func TestSessionGRPC(t *testing.T) { assert.False(t, resp.GetValid()) }) } - -func startSessionServer(t *testing.T, port int) (*stdgrpc.Server, session.Repository, func(context.Context), error) { - t.Helper() - ctx := t.Context() - - // start postgres - db, _, terminatePG := postgrestest.Start(ctx) - - // start valkey - valkeyClient, _, terminateValkey := valkeytest.Start(ctx) - - terminateFn := func(ctx context.Context) { - terminatePG(ctx) - terminateValkey(ctx) - db.Close() - valkeyClient.Close() - } - - trustRepo := trustsql.NewRepository(db) - sessionRepo := sessionvalkey.NewRepository(valkeyClient, "session") - - lstConf := net.ListenConfig{} - lis, err := lstConf.Listen(ctx, "tcp", fmt.Sprintf("localhost:%d", port)) - if err != nil { - return nil, nil, nil, err - } - - srv := stdgrpc.NewServer() - sessionv1.RegisterServiceServer(srv, grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute, "")) - - // start - go func() { - err = srv.Serve(lis) - if err != nil { - slogctx.Error(ctx, "error while starting session server", "error", err) - } - }() - - return srv, sessionRepo, terminateFn, nil -} diff --git a/internal/business/business.go b/internal/business/business.go index 1a9780d2..b08d17eb 100644 --- a/internal/business/business.go +++ b/internal/business/business.go @@ -7,30 +7,40 @@ import ( "log/slog" "sync" - "github.com/exaring/otelpgx" - "github.com/jackc/pgx/v5/pgxpool" "github.com/openkcm/common-sdk/pkg/commoncfg" "github.com/valkey-io/valkey-go" otlpaudit "github.com/openkcm/common-sdk/pkg/otlp/audit" slogctx "github.com/veqryn/slog-context" + sessionmanager "github.com/openkcm/session-manager" "github.com/openkcm/session-manager/internal/business/server" "github.com/openkcm/session-manager/internal/config" "github.com/openkcm/session-manager/internal/credentials" "github.com/openkcm/session-manager/internal/grpc" "github.com/openkcm/session-manager/internal/session" sessionvalkey "github.com/openkcm/session-manager/internal/session/valkey" - "github.com/openkcm/session-manager/internal/trust" - "github.com/openkcm/session-manager/internal/trust/trustsql" ) -const clientAuthTypeInsecure = "insecure" +const ( + insecure = "insecure" + mtls = "mtls" + clientSecret = "client_secret" // An alias to clientSecretPost. Prefer using clientSecretPost. + clientSecretPost = "client_secret_post" +) // Main starts both API servers func Main(ctx context.Context, cfg *config.Config) error { - ctx, cancel := context.WithCancel(ctx) - defer cancel() + c, cancelCause := sessionmanager.NewContext(ctx) + defer cancelCause(nil) + + if _, err := c.LoadModule(&cfg.Database); err != nil { + return fmt.Errorf("loading database module: %w", err) + } + + if _, err := c.LoadModule(&cfg.Trust); err != nil { + return fmt.Errorf("loading trust module: %w", err) + } // errChan is used to capture the first error and shutdown the servers. errChan := make(chan error, 1) @@ -40,12 +50,12 @@ func Main(ctx context.Context, cfg *config.Config) error { // start public HTTP REST API server wg.Go(func() { - errChan <- publicMain(ctx, cfg) + errChan <- publicMain(c, cfg) }) // start internal gRPC API server wg.Go(func() { - errChan <- internalMain(ctx, cfg) + errChan <- internalMain(c, cfg) }) // wait for any error to initiate the shutdown @@ -53,7 +63,7 @@ func Main(ctx context.Context, cfg *config.Config) error { if err != nil { slogctx.Error(ctx, "Shutting down servers", "error", err) } - cancel() + cancelCause(err) // wait for all servers to shutdown wg.Wait() @@ -62,7 +72,7 @@ func Main(ctx context.Context, cfg *config.Config) error { } // publicMain starts the HTTP REST public API server. -func publicMain(ctx context.Context, cfg *config.Config) error { +func publicMain(ctx *sessionmanager.Context, cfg *config.Config) error { csrfSecret, err := commoncfg.LoadValueFromSourceRef(cfg.SessionManager.CSRFSecret) if err != nil { return fmt.Errorf("loading csrf token from source ref: %w", err) @@ -73,7 +83,15 @@ func publicMain(ctx context.Context, cfg *config.Config) error { cfg.SessionManager.CSRFSecretParsed = csrfSecret - sessionManager, closeFn, err := initSessionManager(ctx, cfg) + trustMod, err := ctx.GetModule(cfg.Trust.Module()) + if err != nil { + return fmt.Errorf("getting trust module: %w", err) + } + + //nolint:forcetypeassert + trust := trustMod.(sessionmanager.Trust) + + sessionManager, closeFn, err := initSessionManager(ctx, cfg, trust) if err != nil { return fmt.Errorf("failed to initialise the session manager: %w", err) } @@ -84,14 +102,7 @@ func publicMain(ctx context.Context, cfg *config.Config) error { } // internalMain starts the gRPC private API server. -func internalMain(ctx context.Context, cfg *config.Config) error { - // Create trust service - trustRepo, err := trustRepoFromConfig(ctx, cfg) - if err != nil { - return fmt.Errorf("failed to create trust service: %w", err) - } - trustService := trust.NewService(trustRepo) - +func internalMain(ctx *sessionmanager.Context, cfg *config.Config) error { // Create session repository valkeyClient, err := valkeyClientFromConfig(cfg) if err != nil { @@ -105,26 +116,28 @@ func internalMain(ctx context.Context, cfg *config.Config) error { return fmt.Errorf("failed to create a credentials builder: %w", err) } + trustMod, err := ctx.GetModule(cfg.Trust.Module()) + if err != nil { + return fmt.Errorf("getting trust module: %w", err) + } + + //nolint:forcetypeassert + trust := trustMod.(sessionmanager.Trust) + // Initialize the gRPC servers. - oidcmappingsrv := grpc.NewOIDCMappingServer(trustService) + oidcmappingsrv := grpc.NewTrustMappingServer(trust) sessionsrv := grpc.NewSessionServer(ctx, sessionRepo, - trustRepo, + trust, cfg.SessionManager.IdleSessionTimeout, cfg.SessionManager.ClientAuth.ClientID, - grpc.WithQueryParametersIntrospect(cfg.SessionManager.AdditionalQueryParametersIntrospect), grpc.WithTransportCredentials(credsBuilder), ) + return server.StartGRPCServer(ctx, cfg, oidcmappingsrv, sessionsrv) } -func initSessionManager(ctx context.Context, cfg *config.Config) (_ *session.Manager, closeFn func(), _ error) { - // Create trust repository - trustRepo, err := trustRepoFromConfig(ctx, cfg) - if err != nil { - return nil, nil, fmt.Errorf("failed to create trust repository: %w", err) - } - +func initSessionManager(ctx context.Context, cfg *config.Config, trust sessionmanager.Trust) (_ *session.Manager, closeFn func(), _ error) { // Create session repository valkeyClient, err := valkeyClientFromConfig(cfg) if err != nil { @@ -144,7 +157,7 @@ func initSessionManager(ctx context.Context, cfg *config.Config) (_ *session.Man sessManager, err := session.NewManager(ctx, &cfg.SessionManager, - trustRepo, + trust, sessionRepo, auditLogger, session.WithTransportCredentials(credsBuilder), @@ -156,31 +169,6 @@ func initSessionManager(ctx context.Context, cfg *config.Config) (_ *session.Man return sessManager, valkeyClient.Close, nil } -func trustRepoFromConfig(ctx context.Context, cfg *config.Config) (*trustsql.Repository, error) { - connStr, err := config.MakeConnStr(cfg.Database) - if err != nil { - return nil, fmt.Errorf("failed to make dsn from config: %w", err) - } - - pgxpoolCfg, err := pgxpool.ParseConfig(connStr) - if err != nil { - return nil, fmt.Errorf("parsing pgxpool config: %w", err) - } - - pgxpoolCfg.ConnConfig.Tracer = otelpgx.NewTracer() - - db, err := pgxpool.NewWithConfig(ctx, pgxpoolCfg) - if err != nil { - return nil, fmt.Errorf("failed to initialise pgxpool connection: %w", err) - } - - if err := otelpgx.RecordStats(db); err != nil { - return nil, fmt.Errorf("recording database stat: %w", err) - } - - return trustsql.NewRepository(db), nil -} - func valkeyClientFromConfig(cfg *config.Config) (valkey.Client, error) { valkeyHost, err := commoncfg.LoadValueFromSourceRef(cfg.ValKey.Host) if err != nil { @@ -221,14 +209,14 @@ func valkeyClientFromConfig(cfg *config.Config) (valkey.Client, error) { func newCredsBuilder(cfg *config.Config) (credentials.Builder, error) { switch cfg.SessionManager.ClientAuth.Type { - case "mtls": + case mtls: tlsConfig, err := commoncfg.LoadMTLSConfig(cfg.SessionManager.ClientAuth.MTLS) if err != nil { return nil, fmt.Errorf("failed to load mTLS config: %w", err) } return func(clientID string) credentials.TransportCredentials { return credentials.NewTLS(clientID, tlsConfig) }, nil - case "client_secret", "client_secret_post": + case clientSecretPost, clientSecret: secret, err := commoncfg.LoadValueFromSourceRef(cfg.SessionManager.ClientAuth.ClientSecret) if err != nil { return nil, fmt.Errorf("failed to load client secret: %w", err) @@ -237,7 +225,7 @@ func newCredsBuilder(cfg *config.Config) (credentials.Builder, error) { return func(clientID string) credentials.TransportCredentials { return credentials.NewClientSecretPost(clientID, string(secret)) }, nil - case clientAuthTypeInsecure: + case insecure: slog.Warn("insecure credentials are used. Do not use this in production") return func(clientID string) credentials.TransportCredentials { return credentials.NewInsecure(clientID) }, nil default: diff --git a/internal/business/business_test.go b/internal/business/business_test.go index fa9ecb1d..cb14c057 100644 --- a/internal/business/business_test.go +++ b/internal/business/business_test.go @@ -12,6 +12,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + sessionmanager "github.com/openkcm/session-manager" "github.com/openkcm/session-manager/internal/config" "github.com/openkcm/session-manager/internal/credentials" ) @@ -65,7 +66,7 @@ func TestLoadHTTPClient_Insecure(t *testing.T) { cfg := &config.Config{ SessionManager: config.SessionManager{ ClientAuth: config.ClientAuth{ - Type: clientAuthTypeInsecure, + Type: insecure, ClientID: "test-client", }, }, @@ -277,90 +278,6 @@ func TestValkeyClientFromConfig_WithMTLS(t *testing.T) { assert.Contains(t, err.Error(), "failed to load valkey mTLS config from secret ref") } -func TestTrustRepoFromConfig_InvalidDatabaseConfig(t *testing.T) { - cfg := &config.Config{ - Database: config.Database{ - Host: commoncfg.SourceRef{Source: "file", File: commoncfg.CredentialFile{Path: "/nonexistent/file"}}, - Port: "5432", - Name: "testdb", - User: commoncfg.SourceRef{Source: "embedded", Value: "user"}, - Password: commoncfg.SourceRef{Source: "embedded", Value: "pass"}, - }, - } - - _, err := trustRepoFromConfig(t.Context(), cfg) - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to make dsn from config") -} - -func TestInitSessionManager_InvalidOIDCConfig(t *testing.T) { - cfg := &config.Config{ - Database: config.Database{ - Host: commoncfg.SourceRef{Source: "file", File: commoncfg.CredentialFile{Path: "/nonexistent/file"}}, - Port: "5432", - Name: "testdb", - User: commoncfg.SourceRef{Source: "embedded", Value: "user"}, - Password: commoncfg.SourceRef{Source: "embedded", Value: "pass"}, - }, - } - - _, closeFn, err := initSessionManager(t.Context(), cfg) - assert.Error(t, err) - assert.Nil(t, closeFn) - assert.Contains(t, err.Error(), "failed to create trust repository") -} - -func TestInitSessionManager_InvalidValkeyConfig(t *testing.T) { - cfg := &config.Config{ - Database: config.Database{ - Host: commoncfg.SourceRef{Source: "embedded", Value: "localhost"}, - Port: "5432", - Name: "testdb", - User: commoncfg.SourceRef{Source: "embedded", Value: "user"}, - Password: commoncfg.SourceRef{Source: "embedded", Value: "pass"}, - }, - ValKey: config.ValKey{ - Host: commoncfg.SourceRef{Source: "file", File: commoncfg.CredentialFile{Path: "/nonexistent/file"}}, - User: commoncfg.SourceRef{Source: "embedded", Value: "user"}, - Password: commoncfg.SourceRef{Source: "embedded", Value: "pass"}, - }, - } - - _, closeFn, err := initSessionManager(t.Context(), cfg) - assert.Error(t, err) - assert.Nil(t, closeFn) - // Will fail on either DB connection or valkey config - // Error details depend on which step fails -} - -func TestInitSessionManager_InvalidHTTPClientConfig(t *testing.T) { - cfg := &config.Config{ - Database: config.Database{ - Host: commoncfg.SourceRef{Source: "embedded", Value: "localhost"}, - Port: "5432", - Name: "testdb", - User: commoncfg.SourceRef{Source: "embedded", Value: "user"}, - Password: commoncfg.SourceRef{Source: "embedded", Value: "pass"}, - }, - ValKey: config.ValKey{ - Host: commoncfg.SourceRef{Source: "embedded", Value: "localhost:6379"}, - User: commoncfg.SourceRef{Source: "embedded", Value: "user"}, - Password: commoncfg.SourceRef{Source: "embedded", Value: "pass"}, - }, - SessionManager: config.SessionManager{ - ClientAuth: config.ClientAuth{ - Type: "invalid-type", - }, - }, - } - - _, closeFn, err := initSessionManager(t.Context(), cfg) - assert.Error(t, err) - assert.Nil(t, closeFn) - // Should fail on one of the earlier steps (DB or valkey) or on HTTP client - // Error details depend on which step fails -} - func TestPublicMain_InvalidCSRFSecret(t *testing.T) { cfg := &config.Config{ SessionManager: config.SessionManager{ @@ -368,7 +285,10 @@ func TestPublicMain_InvalidCSRFSecret(t *testing.T) { }, } - err := publicMain(t.Context(), cfg) + ctx, cancel := sessionmanager.NewContext(t.Context()) + defer cancel(nil) + + err := publicMain(ctx, cfg) assert.Error(t, err) assert.Contains(t, err.Error(), "loading csrf token from source ref") } @@ -380,36 +300,16 @@ func TestPublicMain_ShortCSRFSecret(t *testing.T) { }, } - err := publicMain(t.Context(), cfg) - assert.Error(t, err) - assert.Contains(t, err.Error(), "CSRF secret must be at least 32 bytes") -} + ctx, cancel := sessionmanager.NewContext(t.Context()) + defer cancel(nil) -func TestInternalMain_InvalidOIDCConfig(t *testing.T) { - cfg := &config.Config{ - Database: config.Database{ - Host: commoncfg.SourceRef{Source: "file", File: commoncfg.CredentialFile{Path: "/nonexistent/file"}}, - Port: "5432", - Name: "testdb", - User: commoncfg.SourceRef{Source: "embedded", Value: "user"}, - Password: commoncfg.SourceRef{Source: "embedded", Value: "pass"}, - }, - } - - err := internalMain(t.Context(), cfg) + err := publicMain(ctx, cfg) assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to create trust service") + assert.Contains(t, err.Error(), "CSRF secret must be at least 32 bytes") } func TestInternalMain_InvalidValkeyConfig(t *testing.T) { cfg := &config.Config{ - Database: config.Database{ - Host: commoncfg.SourceRef{Source: "embedded", Value: "localhost"}, - Port: "5432", - Name: "testdb", - User: commoncfg.SourceRef{Source: "embedded", Value: "user"}, - Password: commoncfg.SourceRef{Source: "embedded", Value: "pass"}, - }, ValKey: config.ValKey{ Host: commoncfg.SourceRef{Source: "file", File: commoncfg.CredentialFile{Path: "/nonexistent/file"}}, User: commoncfg.SourceRef{Source: "embedded", Value: "user"}, @@ -417,7 +317,10 @@ func TestInternalMain_InvalidValkeyConfig(t *testing.T) { }, } - err := internalMain(t.Context(), cfg) + ctx, cancel := sessionmanager.NewContext(t.Context()) + defer cancel(nil) + + err := internalMain(ctx, cfg) assert.Error(t, err) assert.Contains(t, err.Error(), "failed to create valkey client") // Could fail on OIDC (DB connection) or valkey @@ -440,39 +343,6 @@ func TestMain_PublicServerInvalidCSRF(t *testing.T) { SessionManager: config.SessionManager{ CSRFSecret: commoncfg.SourceRef{Source: "file", File: commoncfg.CredentialFile{Path: "/nonexistent/file"}}, }, - Database: config.Database{ - Host: commoncfg.SourceRef{Source: "embedded", Value: "localhost"}, - Port: "5432", - Name: "testdb", - User: commoncfg.SourceRef{Source: "embedded", Value: "user"}, - Password: commoncfg.SourceRef{Source: "embedded", Value: "pass"}, - }, - } - - err := Main(t.Context(), cfg) - assert.Error(t, err) -} - -func TestMain_InternalServerInvalidDatabase(t *testing.T) { - cfg := &config.Config{ - SessionManager: config.SessionManager{ - CSRFSecret: commoncfg.SourceRef{Source: "embedded", Value: "this-is-a-very-long-secret-that-is-at-least-32-bytes-long"}, - ClientAuth: config.ClientAuth{ - Type: clientAuthTypeInsecure, - }, - }, - Database: config.Database{ - Host: commoncfg.SourceRef{Source: "file", File: commoncfg.CredentialFile{Path: "/nonexistent/file"}}, - Port: "5432", - Name: "testdb", - User: commoncfg.SourceRef{Source: "embedded", Value: "user"}, - Password: commoncfg.SourceRef{Source: "embedded", Value: "pass"}, - }, - ValKey: config.ValKey{ - Host: commoncfg.SourceRef{Source: "embedded", Value: "localhost:6379"}, - User: commoncfg.SourceRef{Source: "embedded", Value: "user"}, - Password: commoncfg.SourceRef{Source: "embedded", Value: "pass"}, - }, } err := Main(t.Context(), cfg) diff --git a/internal/business/housekeeper.go b/internal/business/housekeeper.go index 197fc42a..18fe69d3 100644 --- a/internal/business/housekeeper.go +++ b/internal/business/housekeeper.go @@ -7,31 +7,48 @@ import ( slogctx "github.com/veqryn/slog-context" + sessionmanager "github.com/openkcm/session-manager" "github.com/openkcm/session-manager/internal/config" ) // HousekeeperMain starts the house keeping jobs func HousekeeperMain(ctx context.Context, cfg *config.Config) error { - sessionManager, closeFn, err := initSessionManager(ctx, cfg) + c, cancelCause := sessionmanager.NewContext(ctx) + defer cancelCause(nil) + + _, err := c.LoadModule(&cfg.Database) + if err != nil { + return fmt.Errorf("loading database module: %w", err) + } + + trustMod, err := c.LoadModule(&cfg.Trust) + if err != nil { + return fmt.Errorf("loading trust module: %w", err) + } + + //nolint:forcetypeassert + trust := trustMod.(sessionmanager.Trust) + + sessionManager, closeFn, err := initSessionManager(ctx, cfg, trust) if err != nil { return fmt.Errorf("failed to initialise the session manager: %w", err) } defer closeFn() // Start the housekeeper loop - c := time.Tick(cfg.Housekeeper.TriggerInterval) + tick := time.Tick(cfg.Housekeeper.TriggerInterval) refreshTriggerInterval := cfg.Housekeeper.TokenRefreshTriggerInterval concurrencyLimit := cfg.Housekeeper.ConcurrencyLimit for { - err := sessionManager.TriggerHousekeeping(ctx, concurrencyLimit, refreshTriggerInterval) + err := sessionManager.TriggerHousekeeping(c, concurrencyLimit, refreshTriggerInterval) if err != nil { slogctx.Error(ctx, "Error during session housekeeping", "error", err) } select { - case <-c: + case <-tick: continue - case <-ctx.Done(): + case <-c.Done(): return nil } } diff --git a/internal/business/housekeeper_test.go b/internal/business/housekeeper_test.go index 820ca9a4..cc960403 100644 --- a/internal/business/housekeeper_test.go +++ b/internal/business/housekeeper_test.go @@ -10,52 +10,8 @@ import ( "github.com/openkcm/session-manager/internal/config" ) -func TestHousekeeperMain_InvalidDatabaseConfig(t *testing.T) { - cfg := &config.Config{ - Database: config.Database{ - Host: commoncfg.SourceRef{Source: "file", File: commoncfg.CredentialFile{Path: "/nonexistent/file"}}, - Port: "5432", - Name: "testdb", - User: commoncfg.SourceRef{Source: "embedded", Value: "user"}, - Password: commoncfg.SourceRef{Source: "embedded", Value: "pass"}, - }, - } - - err := HousekeeperMain(t.Context(), cfg) - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to initialise the session manager") -} - -func TestHousekeeperMain_InvalidValkeyConfig(t *testing.T) { - cfg := &config.Config{ - Database: config.Database{ - Host: commoncfg.SourceRef{Source: "embedded", Value: "localhost"}, - Port: "5432", - Name: "testdb", - User: commoncfg.SourceRef{Source: "embedded", Value: "user"}, - Password: commoncfg.SourceRef{Source: "embedded", Value: "pass"}, - }, - ValKey: config.ValKey{ - Host: commoncfg.SourceRef{Source: "file", File: commoncfg.CredentialFile{Path: "/nonexistent/file"}}, - User: commoncfg.SourceRef{Source: "embedded", Value: "user"}, - Password: commoncfg.SourceRef{Source: "embedded", Value: "pass"}, - }, - } - - err := HousekeeperMain(t.Context(), cfg) - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to initialise the session manager") -} - func TestHousekeeperMain_CancelledContext(t *testing.T) { cfg := &config.Config{ - Database: config.Database{ - Host: commoncfg.SourceRef{Source: "embedded", Value: "localhost"}, - Port: "5432", - Name: "testdb", - User: commoncfg.SourceRef{Source: "embedded", Value: "user"}, - Password: commoncfg.SourceRef{Source: "embedded", Value: "pass"}, - }, ValKey: config.ValKey{ Host: commoncfg.SourceRef{Source: "embedded", Value: "localhost:6379"}, User: commoncfg.SourceRef{Source: "embedded", Value: "user"}, @@ -63,7 +19,7 @@ func TestHousekeeperMain_CancelledContext(t *testing.T) { }, SessionManager: config.SessionManager{ ClientAuth: config.ClientAuth{ - Type: clientAuthTypeInsecure, + Type: insecure, }, }, } diff --git a/internal/business/migrate.go b/internal/business/migrate.go index 9c70059d..a03d5271 100644 --- a/internal/business/migrate.go +++ b/internal/business/migrate.go @@ -4,57 +4,41 @@ import ( "context" "fmt" - "github.com/XSAM/otelsql" - "github.com/pressly/goose/v3" - "github.com/samber/oops" - // Register pgx driver _ "github.com/jackc/pgx/v5/stdlib" slogctx "github.com/veqryn/slog-context" - semconv "go.opentelemetry.io/otel/semconv/v1.37.0" + sessionmanager "github.com/openkcm/session-manager" "github.com/openkcm/session-manager/internal/config" - migrations "github.com/openkcm/session-manager/sql" ) // MigrateMain starts the database migration func MigrateMain(ctx context.Context, cfg *config.Config) error { - const dialect = "pgx" - dbSystemName := semconv.DBSystemNamePostgreSQL - - connStr, err := config.MakeConnStr(cfg.Database) - if err != nil { - return fmt.Errorf("making connection string from config: %w", err) - } - - db, err := otelsql.Open(dialect, connStr, otelsql.WithAttributes(dbSystemName)) - if err != nil { - return oops.In("main").Wrapf(err, "opening DB connection") - } - - reg, err := otelsql.RegisterDBStatsMetrics(db, otelsql.WithAttributes(dbSystemName)) - if err != nil { - return fmt.Errorf("registering db stats metrics: %w", err) - } + c, cancel := sessionmanager.NewContext(ctx) + var err error defer func() { - err = reg.Unregister() - if err != nil { - slogctx.Error(ctx, "failed to unregister db stats metrics", "error", err) - } + cancel(err) }() - goose.SetBaseFS(migrations.FS) - - err = goose.SetDialect(dialect) + slogctx.Debug(c, "loading db") + _, err = c.LoadModule(&cfg.Database) if err != nil { - return fmt.Errorf("setting goose dialect: %w", err) + return fmt.Errorf("loading database module: %w", err) } - err = goose.UpContext(ctx, db, ".") + slogctx.Debug(c, "loading migrate") + mod, err := c.LoadModule(&cfg.Migrate) if err != nil { - return fmt.Errorf("applying migrations: %w", err) + return fmt.Errorf("loading migration module: %w", err) + } + + //nolint:forcetypeassert + migrate := mod.(sessionmanager.Migrate) + slogctx.Debug(c, "executing migration") + if err := migrate.Migrate(ctx); err != nil { + return fmt.Errorf("executing migrations: %w", err) } return nil diff --git a/internal/business/migrate_test.go b/internal/business/migrate_test.go deleted file mode 100644 index ea28d0b7..00000000 --- a/internal/business/migrate_test.go +++ /dev/null @@ -1,58 +0,0 @@ -package business - -import ( - "testing" - - "github.com/openkcm/common-sdk/pkg/commoncfg" - "github.com/stretchr/testify/assert" - - "github.com/openkcm/session-manager/internal/config" -) - -func TestMigrateMain_InvalidDatabaseConfig(t *testing.T) { - cfg := &config.Config{ - Database: config.Database{ - Host: commoncfg.SourceRef{Source: "file", File: commoncfg.CredentialFile{Path: "/nonexistent/file"}}, - Port: "5432", - Name: "testdb", - User: commoncfg.SourceRef{Source: "embedded", Value: "user"}, - Password: commoncfg.SourceRef{Source: "embedded", Value: "pass"}, - }, - } - - err := MigrateMain(t.Context(), cfg) - assert.Error(t, err) - assert.Contains(t, err.Error(), "making connection string from config") -} - -func TestMigrateMain_InvalidUserRef(t *testing.T) { - cfg := &config.Config{ - Database: config.Database{ - Host: commoncfg.SourceRef{Source: "embedded", Value: "localhost"}, - Port: "5432", - Name: "testdb", - User: commoncfg.SourceRef{Source: "file", File: commoncfg.CredentialFile{Path: "/nonexistent/file"}}, - Password: commoncfg.SourceRef{Source: "embedded", Value: "pass"}, - }, - } - - err := MigrateMain(t.Context(), cfg) - assert.Error(t, err) - assert.Contains(t, err.Error(), "making connection string from config") -} - -func TestMigrateMain_InvalidPasswordRef(t *testing.T) { - cfg := &config.Config{ - Database: config.Database{ - Host: commoncfg.SourceRef{Source: "embedded", Value: "localhost"}, - Port: "5432", - Name: "testdb", - User: commoncfg.SourceRef{Source: "embedded", Value: "user"}, - Password: commoncfg.SourceRef{Source: "file", File: commoncfg.CredentialFile{Path: "/nonexistent/file"}}, - }, - } - - err := MigrateMain(t.Context(), cfg) - assert.Error(t, err) - assert.Contains(t, err.Error(), "making connection string from config") -} diff --git a/internal/business/server/grpc_server.go b/internal/business/server/grpc_server.go index 64bfccae..596bbc20 100644 --- a/internal/business/server/grpc_server.go +++ b/internal/business/server/grpc_server.go @@ -7,8 +7,8 @@ import ( "github.com/openkcm/common-sdk/pkg/commongrpc" "github.com/samber/oops" - oidcmappingv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/sessionmanager/oidcmapping/v1" sessionv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/sessionmanager/session/v1" + trustmappingv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/sessionmanager/trustmapping/v1" slogctx "github.com/veqryn/slog-context" "github.com/openkcm/session-manager/internal/config" @@ -16,13 +16,13 @@ import ( ) func StartGRPCServer(ctx context.Context, cfg *config.Config, - oidcmappingsrv *grpc.OIDCMappingServer, + oidcmappingsrv *grpc.TrustMappingServer, sessionsrv *grpc.SessionServer, ) error { grpcServer := commongrpc.NewServer(ctx, &cfg.GRPC.GRPCServer) // Register OIDC mapping server for the regional tenant manager - oidcmappingv1.RegisterServiceServer(grpcServer, oidcmappingsrv) + trustmappingv1.RegisterServiceServer(grpcServer, oidcmappingsrv) // Register Session server for ExtAuthZ sessionv1.RegisterServiceServer(grpcServer, sessionsrv) diff --git a/internal/business/server/grpc_server_test.go b/internal/business/server/grpc_server_test.go index 10440854..01cb928d 100644 --- a/internal/business/server/grpc_server_test.go +++ b/internal/business/server/grpc_server_test.go @@ -26,7 +26,7 @@ func TestStartGRPCServer_ContextCancellation(t *testing.T) { } // Create minimal server instances - oidcmappingsrv := grpc.NewOIDCMappingServer(nil) + oidcmappingsrv := grpc.NewTrustMappingServer(nil) sessionsrv := grpc.NewSessionServer(ctx, nil, nil, 0, "") // Start the server in a goroutine diff --git a/internal/business/server/openapi.go b/internal/business/server/openapi.go index 5bf9752e..d9695139 100644 --- a/internal/business/server/openapi.go +++ b/internal/business/server/openapi.go @@ -16,7 +16,7 @@ import ( "github.com/openkcm/session-manager/internal/middleware" "github.com/openkcm/session-manager/internal/openapi" - "github.com/openkcm/session-manager/internal/serviceerr" + "github.com/openkcm/session-manager/pkg/serviceerr" "github.com/openkcm/session-manager/internal/session" ) diff --git a/internal/business/server/openapi_test.go b/internal/business/server/openapi_test.go index 59fccdd6..59e9b207 100644 --- a/internal/business/server/openapi_test.go +++ b/internal/business/server/openapi_test.go @@ -15,7 +15,7 @@ import ( "github.com/openkcm/session-manager/internal/middleware" "github.com/openkcm/session-manager/internal/openapi" - "github.com/openkcm/session-manager/internal/serviceerr" + "github.com/openkcm/session-manager/pkg/serviceerr" "github.com/openkcm/session-manager/internal/session" ) diff --git a/internal/cmdutils/cmdutils.go b/internal/cmdutils/cmdutils.go index 2f37ef7f..ebcd1864 100644 --- a/internal/cmdutils/cmdutils.go +++ b/internal/cmdutils/cmdutils.go @@ -3,11 +3,9 @@ package cmdutils import ( "context" "fmt" - "log/slog" "syscall" "time" - "github.com/openkcm/common-sdk/pkg/commoncfg" "github.com/openkcm/common-sdk/pkg/health" "github.com/openkcm/common-sdk/pkg/logger" "github.com/openkcm/common-sdk/pkg/otlp" @@ -35,7 +33,11 @@ func CobraCommand( Long: long, SilenceUsage: true, RunE: func(cmd *cobra.Command, _ []string) error { - cfg, err := loadConfig(buildInfo) + cfg, err := config.Load(buildInfo, + "/etc/session-manager/", + "$HOME/.session-manager/", + "./", + ) if err != nil { return fmt.Errorf("loading config: %w", err) } @@ -65,7 +67,7 @@ func run(ctx context.Context, withTelemetry, withStatusServer bool, fn func(cont return oops.In("main"). Wrapf(err, "Failed to initialise the logger") } - slogctx.Debug(ctx, "Starting the application", slog.Any("config", cfg)) + slogctx.Debug(ctx, "Starting the application") // OpenTelemetry if withTelemetry { @@ -95,33 +97,6 @@ func run(ctx context.Context, withTelemetry, withStatusServer bool, fn func(cont return nil } -func loadConfig(buildInfo string) (*config.Config, error) { - defaultValues := map[string]any{} - cfg := &config.Config{} - - err := commoncfg.LoadConfig( - cfg, - defaultValues, - "/etc/session-manager", - "$HOME/.session-manager", - ".", - ) - if err != nil { - return nil, fmt.Errorf("loading configuration: %w", err) - } - - // Update Version - err = commoncfg.UpdateConfigVersion( - &cfg.BaseConfig, - buildInfo, - ) - if err != nil { - return nil, fmt.Errorf("updating the version configuration: %w", err) - } - - return cfg, nil -} - func statusListener(ctx context.Context, state health.State) { subctx := slogctx.With(ctx, "status", state.Status) //nolint:fatcontext @@ -136,11 +111,6 @@ func statusListener(ctx context.Context, state health.State) { } func startStatusServer(ctx context.Context, cfg *config.Config) error { - connStr, err := config.MakeConnStr(cfg.Database) - if err != nil { - return fmt.Errorf("making connection string from config: %w", err) - } - liveness := status.WithLiveness( health.NewHandler( health.NewChecker(health.WithDisabledAutostart()), @@ -150,7 +120,6 @@ func startStatusServer(ctx context.Context, cfg *config.Config) error { healthOptions := []health.Option{ health.WithDisabledAutostart(), health.WithTimeout(healthStatusTimeout), - health.WithDatabaseChecker("pgx", connStr), health.WithStatusListener(statusListener), } @@ -160,8 +129,7 @@ func startStatusServer(ctx context.Context, cfg *config.Config) error { ), ) - err = status.Start(ctx, &cfg.BaseConfig, liveness, readiness) - if err != nil { + if err := status.Start(ctx, &cfg.BaseConfig, liveness, readiness); err != nil { return fmt.Errorf("starting status server: %w", err) } diff --git a/internal/cmdutils/cmdutils_test.go b/internal/cmdutils/cmdutils_test.go index a0cb5a3c..b1f46deb 100644 --- a/internal/cmdutils/cmdutils_test.go +++ b/internal/cmdutils/cmdutils_test.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "testing" - "time" "github.com/openkcm/common-sdk/pkg/health" "github.com/stretchr/testify/assert" @@ -121,29 +120,6 @@ func TestStatusListener(t *testing.T) { }) } -func TestStartStatusServer(t *testing.T) { - t.Run("returns error when connection string creation fails", func(t *testing.T) { - cfg := &config.Config{ - Database: config.Database{ - Name: "", - Port: "", - }, - } - ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) - defer cancel() - - err := startStatusServer(ctx, cfg) - assert.Error(t, err) - assert.Contains(t, err.Error(), "making connection string from config") - }) -} - -func TestHealthStatusTimeout(t *testing.T) { - t.Run("has correct value", func(t *testing.T) { - assert.Equal(t, 5*time.Second, healthStatusTimeout) - }) -} - func ExampleCobraCommand() { businessFunc := func(ctx context.Context, cfg *config.Config) error { fmt.Println("Running business logic") diff --git a/internal/config/config.go b/internal/config/config.go index 4d5cd63a..5442c96b 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -3,13 +3,19 @@ package config import ( + "fmt" + "reflect" "time" + "github.com/creasty/defaults" + "github.com/knadh/koanf/v2" "github.com/openkcm/common-sdk/pkg/commoncfg" + + sessionmanager "github.com/openkcm/session-manager" ) type Config struct { - commoncfg.BaseConfig `mapstructure:",squash" yaml:",inline"` + commoncfg.BaseConfig `yaml:",inline"` HTTP HTTPServer `yaml:"http"` GRPC GRPCServer `yaml:"grpc"` @@ -19,6 +25,36 @@ type Config struct { Migrate Migrate `yaml:"migrate"` SessionManager SessionManager `yaml:"sessionManager"` Housekeeper Housekeeper `yaml:"housekeeper"` + Trust Trust `yaml:"trust"` +} + +type Trust struct { + Mod string `yaml:"module" default:"trust.module.oidc"` + koanf *koanf.Koanf +} + +func (c *Trust) setKoanf(ko *koanf.Koanf) { + c.koanf = ko +} + +func (c *Trust) Module() string { + return c.Mod +} + +func (c *Trust) UnmarshalExtension(into sessionmanager.Module) error { + return unmarshalExtension(into, c.koanf) +} + +func unmarshalExtension(out any, ko *koanf.Koanf) error { + if err := ko.UnmarshalWithConf("", out, koanfUnmarshalConf); err != nil { + return fmt.Errorf("unmarshaling into a structure: %w", err) + } + + setKoanf(reflect.ValueOf(out), ko) + if err := defaults.Set(out); err != nil { + return fmt.Errorf("setting defaults: %w", err) + } + return nil } type Housekeeper struct { @@ -37,17 +73,27 @@ type HTTPServer struct { } type GRPCServer struct { - commoncfg.GRPCServer `mapstructure:",squash" yaml:",inline"` + commoncfg.GRPCServer `yaml:",inline"` ShutdownTimeout time.Duration `yaml:"shutdownTimeout" default:"5s"` } type Database struct { - Name string `yaml:"name"` - Port string `yaml:"port"` - Host commoncfg.SourceRef `yaml:"host"` - User commoncfg.SourceRef `yaml:"user"` - Password commoncfg.SourceRef `yaml:"password"` + Mod string `yaml:"module" default:"database.module.pgxpool"` + + koanf *koanf.Koanf +} + +func (c *Database) setKoanf(ko *koanf.Koanf) { + c.koanf = ko +} + +func (c *Database) Module() string { + return c.Mod +} + +func (c *Database) UnmarshalExtension(into sessionmanager.Module) error { + return unmarshalExtension(into, c.koanf) } type ValKey struct { @@ -63,15 +109,10 @@ type SessionManager struct { SessionDuration time.Duration `yaml:"sessionDuration" default:"12h"` // CallbackURL is the URL path for the OAuth2 callback endpoint, where we receive the authorization code. - CallbackURL string `yaml:"callbackURL" default:"/sm/callback"` - ClientAuth ClientAuth `yaml:"clientAuth"` - CSRFSecret commoncfg.SourceRef `yaml:"csrfSecret"` - CSRFSecretParsed []byte `yaml:"-"` - AdditionalQueryParametersAuthorize []string `yaml:"additionalQueryParametersAuthorize"` - AdditionalQueryParametersToken []string `yaml:"additionalQueryParametersToken"` - AdditionalQueryParametersIntrospect []string `yaml:"additionalQueryParametersIntrospect"` - AdditionalQueryParametersLogout []string `yaml:"additionalQueryParametersLogout"` - AdditionalAuthContextKeys []string `yaml:"additionalAuthContextKeys"` + CallbackURL string `yaml:"callbackURL" default:"/sm/callback"` + ClientAuth ClientAuth `yaml:"clientAuth"` + CSRFSecret commoncfg.SourceRef `yaml:"csrfSecret"` + CSRFSecretParsed []byte `yaml:"-"` // SessionCookieTemplate defines the template attributes for the session cookie. SessionCookieTemplate CookieTemplate `yaml:"sessionCookieTemplate"` // CSRFCookieTemplate defines the template attributes for the CSRF cookie. @@ -83,15 +124,6 @@ type SessionManager struct { // during the authorization flow and post logout. This is used to validate the redirect // URLs provided in the authorization request and post logout requests. AllowedRedirectBaseURLs []string `yaml:"allowedRedirectBaseURLs"` - - // Deprecated: use AllowedRedirectBaseURLs instead. - PostLogoutRedirectURL string `yaml:"postLogoutRedirectURL"` - // Deprecated: not used anymore. Kept for a helm issue with the migrate job. - RedirectURL string `yaml:"redirectURL" default:"/sm/redirect"` - // Deprecated: use AdditionalQueryParametersAuthorize instead. - AdditionalGetParametersAuthorize []string `yaml:"additionalGetParametersAuthorize"` - // Deprecated: use AdditionalQueryParametersToken instead. - AdditionalGetParametersToken []string `yaml:"additionalGetParametersToken"` } type CookieSameSiteValue string @@ -126,5 +158,18 @@ type ClientAuth struct { } type Migrate struct { - Source string `yaml:"source" default:"file://./sql"` + Mod string `yaml:"module" default:"trust.migration.module.oidc"` + koanf *koanf.Koanf +} + +func (c *Migrate) setKoanf(ko *koanf.Koanf) { + c.koanf = ko +} + +func (c *Migrate) Module() string { + return c.Mod +} + +func (c *Migrate) UnmarshalExtension(into sessionmanager.Module) error { + return unmarshalExtension(into, c.koanf) } diff --git a/internal/config/connstr.go b/internal/config/connstr.go deleted file mode 100644 index 5e4f496c..00000000 --- a/internal/config/connstr.go +++ /dev/null @@ -1,27 +0,0 @@ -package config - -import ( - "fmt" - - "github.com/openkcm/common-sdk/pkg/commoncfg" -) - -func MakeConnStr(conf Database) (string, error) { - host, err := commoncfg.LoadValueFromSourceRef(conf.Host) - if err != nil { - return "", fmt.Errorf("loading db host: %w", err) - } - - user, err := commoncfg.LoadValueFromSourceRef(conf.User) - if err != nil { - return "", fmt.Errorf("loading db user: %w", err) - } - - password, err := commoncfg.LoadValueFromSourceRef(conf.Password) - if err != nil { - return "", fmt.Errorf("loading db password: %w", err) - } - - return fmt.Sprintf("host=%s user=%s password=%s dbname=%s port=%s", - host, user, string(password), conf.Name, conf.Port), nil -} diff --git a/internal/config/connstr_test.go b/internal/config/connstr_test.go deleted file mode 100644 index a22dec5e..00000000 --- a/internal/config/connstr_test.go +++ /dev/null @@ -1,113 +0,0 @@ -package config - -import ( - "fmt" - "testing" - - "github.com/openkcm/common-sdk/pkg/commoncfg" - "github.com/stretchr/testify/assert" -) - -func TestMakeConnStr(t *testing.T) { - tests := []struct { - name string - conf Database - wantConnStr string - assertErr assert.ErrorAssertionFunc - }{ - { - name: "Make connection string", - conf: Database{ - Host: commoncfg.SourceRef{ - Source: "embedded", - Value: "my_host", - }, - User: commoncfg.SourceRef{ - Source: "embedded", - Value: "my_user", - }, - Password: commoncfg.SourceRef{ - Source: "embedded", - Value: "my_password", - }, - Name: "my_db_name", - Port: "5432", - }, - wantConnStr: "host=my_host user=my_user password=my_password dbname=my_db_name port=5432", - assertErr: assert.NoError, - }, - { - name: "Error - invalid host source", - conf: Database{ - Host: commoncfg.SourceRef{ - Source: "invalid-source", - Value: "my_host", - }, - User: commoncfg.SourceRef{ - Source: "embedded", - Value: "my_user", - }, - Password: commoncfg.SourceRef{ - Source: "embedded", - Value: "my_password", - }, - Name: "my_db_name", - Port: "5432", - }, - wantConnStr: "", - assertErr: assert.Error, - }, - { - name: "Error - invalid user source", - conf: Database{ - Host: commoncfg.SourceRef{ - Source: "embedded", - Value: "my_host", - }, - User: commoncfg.SourceRef{ - Source: "invalid-source", - Value: "my_user", - }, - Password: commoncfg.SourceRef{ - Source: "embedded", - Value: "my_password", - }, - Name: "my_db_name", - Port: "5432", - }, - wantConnStr: "", - assertErr: assert.Error, - }, - { - name: "Error - invalid password source", - conf: Database{ - Host: commoncfg.SourceRef{ - Source: "embedded", - Value: "my_host", - }, - User: commoncfg.SourceRef{ - Source: "embedded", - Value: "my_user", - }, - Password: commoncfg.SourceRef{ - Source: "invalid-source", - Value: "my_password", - }, - Name: "my_db_name", - Port: "5432", - }, - wantConnStr: "", - assertErr: assert.Error, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - connStr, err := MakeConnStr(tt.conf) - if !tt.assertErr(t, err, fmt.Sprintf("MakeConnStr() error = %v", err)) || err != nil { - return - } - - assert.Equal(t, tt.wantConnStr, connStr, "MakeConnStr() = %v", connStr) - }) - } -} diff --git a/internal/config/load.go b/internal/config/load.go new file mode 100644 index 00000000..8f4fc893 --- /dev/null +++ b/internal/config/load.go @@ -0,0 +1,106 @@ +package config + +import ( + "errors" + "fmt" + "os" + "path/filepath" + "reflect" + "strings" + "unicode" + + "github.com/creasty/defaults" + "github.com/go-viper/mapstructure/v2" + "github.com/knadh/koanf/providers/file" + "github.com/knadh/koanf/v2" + "github.com/openkcm/common-sdk/pkg/commoncfg" +) + +type koanfSetter interface { + setKoanf(ko *koanf.Koanf) +} + +const configFile = "config.yaml" + +var koanfUnmarshalConf = koanf.UnmarshalConf{ + Tag: "yaml", + DecoderConfig: &mapstructure.DecoderConfig{ + DecodeHook: mapstructure.ComposeDecodeHookFunc( + mapstructure.StringToTimeDurationHookFunc(), + mapstructure.TextUnmarshallerHookFunc()), + Metadata: nil, + WeaklyTypedInput: true, + SquashTagOption: "inline", + }, +} + +func Load(buildInfo string, paths ...string) (*Config, error) { + for i, path := range paths { + paths[i] = filepath.Join(path, configFile) + } + + ko := koanf.New(".") + var loaded bool + for _, path := range paths { + if err := ko.Load(file.Provider(path), yamlParser{}); err != nil { + if !errors.Is(err, os.ErrNotExist) { + return nil, fmt.Errorf("loading configuration from file %s: %w", path, err) + } + } else { + loaded = true + } + } + + if !loaded { + return nil, fmt.Errorf("no config file found at the paths %q: %w", strings.Join(paths, ", "), os.ErrNotExist) + } + + cfg := new(Config) + if err := ko.UnmarshalWithConf("", cfg, koanfUnmarshalConf); err != nil { + return nil, fmt.Errorf("unmarshaling configuration: %w", err) + } + + if err := defaults.Set(cfg); err != nil { + return nil, fmt.Errorf("setting defaults: %w", err) + } + + if buildInfo != "" { + if err := commoncfg.UpdateConfigVersion( + &cfg.BaseConfig, + buildInfo, + ); err != nil { + return nil, fmt.Errorf("updating the version configuration: %w", err) + } + } + + setKoanf(reflect.ValueOf(cfg), ko) + + return cfg, nil +} + +var koanfSetterType = reflect.TypeFor[koanfSetter]() + +func setKoanf(v reflect.Value, ko *koanf.Koanf) { + if v.Type().Implements(koanfSetterType) { + //nolint:forcetypeassert // Checked above + v.Interface().(koanfSetter).setKoanf(ko) + } + + elem := reflect.Indirect(v) + if elem.Kind() == reflect.Struct { + for field, val := range elem.Fields() { + name, _, _ := strings.Cut(field.Tag.Get(koanfUnmarshalConf.Tag), ",") + if name == "" { + runes := []rune(field.Name) + runes[0] = unicode.ToLower(runes[0]) + name = string(runes) + } + + if val.Kind() != reflect.Pointer { + val = val.Addr() + } + + setKoanf(val, ko.Cut(name)) + } + } +} diff --git a/internal/config/parser.go b/internal/config/parser.go new file mode 100644 index 00000000..d3d56cf3 --- /dev/null +++ b/internal/config/parser.go @@ -0,0 +1,23 @@ +package config + +import ( + "github.com/goccy/go-yaml" +) + +// yamlParser implements a yamlParser parser. +type yamlParser struct{} + +// Unmarshal parses the given YAML bytes. +func (yamlParser) Unmarshal(b []byte) (map[string]any, error) { + var out map[string]any + if err := yaml.UnmarshalWithOptions(b, &out); err != nil { + return nil, err + } + + return out, nil +} + +// Marshal marshals the given config map to YAML bytes. +func (yamlParser) Marshal(o map[string]any) ([]byte, error) { + return yaml.Marshal(o) +} diff --git a/internal/dbtest/postgrestest/postgres.go b/internal/dbtest/postgrestest/postgres.go index 5c836f59..6d6e6263 100644 --- a/internal/dbtest/postgrestest/postgres.go +++ b/internal/dbtest/postgrestest/postgres.go @@ -17,7 +17,7 @@ import ( slogctx "github.com/veqryn/slog-context" - migrations "github.com/openkcm/session-manager/sql" + "github.com/openkcm/session-manager/modules/oidctrust/migrations" ) const ( @@ -112,9 +112,9 @@ func prepareDB(ctx context.Context, dbPool *pgxpool.Pool, port network.Port) { migrateDB(ctx, port) b := new(pgx.Batch) - b.Queue(`INSERT INTO trust (tenant_id, blocked, issuer, jwks_uri, audiences, properties) VALUES ('tenant1-id', false, 'url-one', '', '{}', '{}');`) - b.Queue(`INSERT INTO trust (tenant_id, blocked, issuer, jwks_uri, audiences, properties) VALUES ('tenant2-id', false, 'url-two', '', '{}', '{}');`) - b.Queue(`INSERT INTO trust (tenant_id, blocked, issuer, jwks_uri, audiences, properties) VALUES ('tenant3-id', false, 'url-three', '', '{}', '{}');`) + b.Queue(`INSERT INTO trust (tenant_id, blocked, issuer, jwks_uri, audiences) VALUES ('tenant1-id', false, 'url-one', '', '{}');`) + b.Queue(`INSERT INTO trust (tenant_id, blocked, issuer, jwks_uri, audiences) VALUES ('tenant2-id', false, 'url-two', '', '{}');`) + b.Queue(`INSERT INTO trust (tenant_id, blocked, issuer, jwks_uri, audiences) VALUES ('tenant3-id', false, 'url-three', '', '{}');`) res := dbPool.SendBatch(ctx, b) err := res.Close() diff --git a/internal/grpc/import_test.go b/internal/grpc/import_test.go new file mode 100644 index 00000000..6c5da46f --- /dev/null +++ b/internal/grpc/import_test.go @@ -0,0 +1,12 @@ +package grpc_test + +import ( + _ "unsafe" + + sessionmanager "github.com/openkcm/session-manager" + "github.com/openkcm/session-manager/modules/oidctrust" + _ "github.com/openkcm/session-manager/modules/standard" +) + +//go:linkname newTrust github.com/openkcm/session-manager/modules/oidctrust.newOIDCTrustModuleWithRepo +func newTrust(r oidctrust.TrustRepository) sessionmanager.Trust diff --git a/internal/grpc/oidcmapping.go b/internal/grpc/oidcmapping.go deleted file mode 100644 index 55221284..00000000 --- a/internal/grpc/oidcmapping.go +++ /dev/null @@ -1,129 +0,0 @@ -package grpc - -import ( - "context" - "errors" - - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" - - oidcmappingv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/sessionmanager/oidcmapping/v1" - slogctx "github.com/veqryn/slog-context" - - "github.com/openkcm/session-manager/internal/serviceerr" - "github.com/openkcm/session-manager/internal/trust" -) - -type OIDCMappingServer struct { - oidcmappingv1.UnimplementedServiceServer - - oidc *trust.Service -} - -func NewOIDCMappingServer(oidc *trust.Service) *OIDCMappingServer { - srv := &OIDCMappingServer{ - oidc: oidc, - } - - return srv -} - -func (srv *OIDCMappingServer) ApplyOIDCMapping(ctx context.Context, req *oidcmappingv1.ApplyOIDCMappingRequest) (*oidcmappingv1.ApplyOIDCMappingResponse, error) { - ctx = slogctx.With(ctx, - "tenantId", req.GetTenantId(), - "issuer", req.GetIssuer(), - "jwksUri", req.GetJwksUri(), - "audiences", req.GetAudiences(), - "properties", req.GetProperties(), - "client_id", req.GetClientId(), - ) - slogctx.Debug(ctx, "ApplyOIDCMapping called") - - response := &oidcmappingv1.ApplyOIDCMappingResponse{} - - mapping := trust.OIDCMapping{ - IssuerURL: req.GetIssuer(), - Blocked: false, - JWKSURI: req.GetJwksUri(), - Audiences: req.GetAudiences(), - Properties: req.GetProperties(), - ClientID: req.GetClientId(), - } - err := srv.oidc.ApplyMapping(ctx, req.GetTenantId(), mapping) - if err != nil { - slogctx.Error(ctx, "Could not apply OIDC mapping", "error", err) - if errors.Is(err, serviceerr.ErrNotFound) { - msg := serviceerr.ErrNotFound.Error() - response.Message = &msg - return response, nil - } - - return nil, status.Errorf(codes.Internal, "failed to apply OIDC mapping: %v", err) - } - - response.Success = true - - return response, nil -} - -// BlockOIDCMapping blocks the OIDC mapping for the specified tenant. -// It calls the underlying service to set the mapping as blocked. -// Returns a response containing an optional error message if blocking fails. -func (srv *OIDCMappingServer) BlockOIDCMapping(ctx context.Context, req *oidcmappingv1.BlockOIDCMappingRequest) (*oidcmappingv1.BlockOIDCMappingResponse, error) { - ctx = slogctx.With(ctx, "tenantId", req.GetTenantId()) - slogctx.Debug(ctx, "BlockOIDCMapping called") - - resp := &oidcmappingv1.BlockOIDCMappingResponse{} - err := srv.oidc.BlockMapping(ctx, req.GetTenantId()) - if err != nil { - slogctx.Error(ctx, "Could not block OIDC mapping", "error", err) - msg := err.Error() - resp.Message = &msg - return resp, status.Error(codes.Internal, "failed to block OIDC mapping: "+msg) - } - resp.Success = true - return resp, nil -} - -// RemoveOIDCMapping removes the OIDC configuration for the tenant. -// It calls the underlying service to remove the mapping. -// Returns a respose containing an optional error message if removing fails. -func (srv *OIDCMappingServer) RemoveOIDCMapping(ctx context.Context, req *oidcmappingv1.RemoveOIDCMappingRequest) (*oidcmappingv1.RemoveOIDCMappingResponse, error) { - ctx = slogctx.With(ctx, "tenantId", req.GetTenantId()) - slogctx.Debug(ctx, "RemoveOIDCMapping called") - - resp := &oidcmappingv1.RemoveOIDCMappingResponse{} - err := srv.oidc.RemoveMapping(ctx, req.GetTenantId()) - if err != nil { - if !errors.Is(err, serviceerr.ErrNotFound) { - slogctx.Error(ctx, "Could not remove OIDC mapping", "error", err) - msg := err.Error() - resp.Message = &msg - return resp, status.Error(codes.Internal, "failed to remove OIDC mapping: "+msg) - } else { - slogctx.Warn(ctx, "RemoveOIDCMapping is called but the tenant does not exist", "error", err) - } - } - - resp.Success = true - return resp, nil -} - -// UnblockOIDCMapping unblocks the OIDC mapping for the specified tenant. -// It calls the underlying service to set the mapping as unblocked. -// Returns a response containing an optional error message if unblocking fails. -func (srv *OIDCMappingServer) UnblockOIDCMapping(ctx context.Context, req *oidcmappingv1.UnblockOIDCMappingRequest) (*oidcmappingv1.UnblockOIDCMappingResponse, error) { - ctx = slogctx.With(ctx, "tenantId", req.GetTenantId()) - slogctx.Debug(ctx, "UnblockOIDCMapping called") - - resp := &oidcmappingv1.UnblockOIDCMappingResponse{} - err := srv.oidc.UnblockMapping(ctx, req.GetTenantId()) - if err != nil { - slogctx.Error(ctx, "Could not unblock OIDC mapping", "error", err) - msg := err.Error() - resp.Message = &msg - return resp, status.Error(codes.Internal, "failed to unblock OIDC mapping: "+msg) - } - resp.Success = true - return resp, nil -} diff --git a/internal/grpc/oidcmapping_test.go b/internal/grpc/oidcmapping_test.go index 703422b9..f17796ed 100644 --- a/internal/grpc/oidcmapping_test.go +++ b/internal/grpc/oidcmapping_test.go @@ -9,45 +9,44 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/status" - oidcmappingv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/sessionmanager/oidcmapping/v1" + trustmappingv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/sessionmanager/trustmapping/v1" + oidcv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/oidc/v1" + trustv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/v1" "github.com/openkcm/session-manager/internal/grpc" - "github.com/openkcm/session-manager/internal/serviceerr" - "github.com/openkcm/session-manager/internal/trust" - "github.com/openkcm/session-manager/internal/trust/trustmock" + "github.com/openkcm/session-manager/pkg/serviceerr" + mocktrust "github.com/openkcm/session-manager/modules/oidctrust/mocks" ) -func TestNewOIDCMappingServer(t *testing.T) { +func TestNewTrustMappingServer(t *testing.T) { t.Run("creates server successfully", func(t *testing.T) { - repo := trustmock.NewInMemRepository() - svc := trust.NewService(repo) - server := grpc.NewOIDCMappingServer(svc) + repo := mocktrust.NewInMemRepository() + svc := newTrust(repo) + server := grpc.NewTrustMappingServer(svc) assert.NotNil(t, server) }) } -func TestApplyOIDCMapping(t *testing.T) { +func TestApplyTrustMapping(t *testing.T) { ctx := t.Context() t.Run("success - creates new mapping", func(t *testing.T) { - repo := trustmock.NewInMemRepository() - svc := trust.NewService(repo) - server := grpc.NewOIDCMappingServer(svc) + repo := mocktrust.NewInMemRepository() + svc := newTrust(repo) + server := grpc.NewTrustMappingServer(svc) jwksUri := "https://issuer.example.com/.well-known/jwks.json" - req := &oidcmappingv1.ApplyOIDCMappingRequest{ - TenantId: "tenant-123", - Issuer: "https://issuer.example.com", - JwksUri: &jwksUri, - Audiences: []string{"audience1", "audience2"}, - Properties: map[string]string{ - "prop1": "value1", - "prop2": "value2", - }, - } - - resp, err := server.ApplyOIDCMapping(ctx, req) + req := trustmappingv1.ApplyTrustMappingRequest_builder{ + TenantId: new("tenant-123"), + Oidc: trustmappingv1.ApplyTrustMappingRequest_ApplyOIDCTrust_builder{ + Issuer: new("https://issuer.example.com"), + JwksUri: &jwksUri, + Audiences: []string{"audience1", "audience2"}, + }.Build(), + }.Build() + + resp, err := server.ApplyTrustMapping(ctx, req) require.NoError(t, err) assert.NotNil(t, resp) @@ -56,26 +55,31 @@ func TestApplyOIDCMapping(t *testing.T) { }) t.Run("success - updates existing mapping", func(t *testing.T) { - existingMapping := trust.OIDCMapping{ - IssuerURL: "https://old-issuer.example.com", - JWKSURI: "https://old-issuer.example.com/jwks.json", - Audiences: []string{"old-audience"}, - } - repo := trustmock.NewInMemRepository( - trustmock.WithTrust("tenant-123", existingMapping), + existingMapping := trustv1.Trust_builder{ + TenantId: new("tenant-123"), + Oidc: oidcv1.OIDC_builder{ + Issuer: new("https://old-issuer.example.com"), + JwksUri: new("https://old-issuer.example.com/jwks.json"), + Audiences: []string{"old-audience"}, + }.Build(), + }.Build() + repo := mocktrust.NewInMemRepository( + mocktrust.WithTrust(existingMapping), ) - svc := trust.NewService(repo) - server := grpc.NewOIDCMappingServer(svc) + svc := newTrust(repo) + server := grpc.NewTrustMappingServer(svc) jwksUri := "https://new-issuer.example.com/jwks.json" - req := &oidcmappingv1.ApplyOIDCMappingRequest{ - TenantId: "tenant-123", - Issuer: "https://new-issuer.example.com", - JwksUri: &jwksUri, - Audiences: []string{"new-audience"}, - } + req := trustmappingv1.ApplyTrustMappingRequest_builder{ + TenantId: new("tenant-123"), + Oidc: trustmappingv1.ApplyTrustMappingRequest_ApplyOIDCTrust_builder{ + Issuer: new("https://new-issuer.example.com"), + JwksUri: new(jwksUri), + Audiences: []string{"new-audience"}, + }.Build(), + }.Build() - resp, err := server.ApplyOIDCMapping(ctx, req) + resp, err := server.ApplyTrustMapping(ctx, req) require.NoError(t, err) assert.NotNil(t, resp) @@ -83,20 +87,22 @@ func TestApplyOIDCMapping(t *testing.T) { }) t.Run("not found error - returns response with message", func(t *testing.T) { - repo := trustmock.NewInMemRepository( - trustmock.WithCreateError(serviceerr.ErrNotFound), + repo := mocktrust.NewInMemRepository( + mocktrust.WithCreateError(serviceerr.ErrNotFound), ) - svc := trust.NewService(repo) - server := grpc.NewOIDCMappingServer(svc) + svc := newTrust(repo) + server := grpc.NewTrustMappingServer(svc) jwksUri := "https://issuer.example.com/jwks.json" - req := &oidcmappingv1.ApplyOIDCMappingRequest{ - TenantId: "tenant-123", - Issuer: "https://issuer.example.com", - JwksUri: &jwksUri, - } + req := trustmappingv1.ApplyTrustMappingRequest_builder{ + TenantId: new("tenant-123"), + Oidc: trustmappingv1.ApplyTrustMappingRequest_ApplyOIDCTrust_builder{ + Issuer: new("https://issuer.example.com"), + JwksUri: new(jwksUri), + }.Build(), + }.Build() - resp, err := server.ApplyOIDCMapping(ctx, req) + resp, err := server.ApplyTrustMapping(ctx, req) require.NoError(t, err) assert.NotNil(t, resp) @@ -107,20 +113,22 @@ func TestApplyOIDCMapping(t *testing.T) { t.Run("internal error - returns grpc error", func(t *testing.T) { internalErr := errors.New("database connection failed") - repo := trustmock.NewInMemRepository( - trustmock.WithCreateError(internalErr), + repo := mocktrust.NewInMemRepository( + mocktrust.WithCreateError(internalErr), ) - svc := trust.NewService(repo) - server := grpc.NewOIDCMappingServer(svc) + svc := newTrust(repo) + server := grpc.NewTrustMappingServer(svc) jwksUri := "https://issuer.example.com/jwks.json" - req := &oidcmappingv1.ApplyOIDCMappingRequest{ - TenantId: "tenant-123", - Issuer: "https://issuer.example.com", - JwksUri: &jwksUri, - } + req := trustmappingv1.ApplyTrustMappingRequest_builder{ + TenantId: new("tenant-123"), + Oidc: trustmappingv1.ApplyTrustMappingRequest_ApplyOIDCTrust_builder{ + Issuer: new("https://issuer.example.com"), + JwksUri: new(jwksUri), + }.Build(), + }.Build() - resp, err := server.ApplyOIDCMapping(ctx, req) + resp, err := server.ApplyTrustMapping(ctx, req) assert.Nil(t, resp) require.Error(t, err) @@ -128,29 +136,34 @@ func TestApplyOIDCMapping(t *testing.T) { st, ok := status.FromError(err) require.True(t, ok) assert.Equal(t, codes.Internal, st.Code()) - assert.Contains(t, st.Message(), "failed to apply OIDC mapping") + assert.Contains(t, st.Message(), "failed to apply Trust mapping") }) t.Run("update error - returns grpc error", func(t *testing.T) { - existingMapping := trust.OIDCMapping{ - IssuerURL: "https://issuer.example.com", - } + existingMapping := trustv1.Trust_builder{ + TenantId: new("tenant-123"), + Oidc: oidcv1.OIDC_builder{ + Issuer: new("https://issuer.example.com"), + }.Build(), + }.Build() updateErr := errors.New("update failed") - repo := trustmock.NewInMemRepository( - trustmock.WithTrust("tenant-123", existingMapping), - trustmock.WithUpdateError(updateErr), + repo := mocktrust.NewInMemRepository( + mocktrust.WithTrust(existingMapping), + mocktrust.WithUpdateError(updateErr), ) - svc := trust.NewService(repo) - server := grpc.NewOIDCMappingServer(svc) + svc := newTrust(repo) + server := grpc.NewTrustMappingServer(svc) jwksUri := "https://new-issuer.example.com/jwks.json" - req := &oidcmappingv1.ApplyOIDCMappingRequest{ - TenantId: "tenant-123", - Issuer: "https://new-issuer.example.com", - JwksUri: &jwksUri, - } + req := trustmappingv1.ApplyTrustMappingRequest_builder{ + TenantId: new("tenant-123"), + Oidc: trustmappingv1.ApplyTrustMappingRequest_ApplyOIDCTrust_builder{ + Issuer: new("https://new-issuer.example.com"), + JwksUri: new(jwksUri), + }.Build(), + }.Build() - resp, err := server.ApplyOIDCMapping(ctx, req) + resp, err := server.ApplyTrustMapping(ctx, req) assert.Nil(t, resp) require.Error(t, err) @@ -161,25 +174,28 @@ func TestApplyOIDCMapping(t *testing.T) { }) } -func TestBlockOIDCMapping(t *testing.T) { +func TestBlockTrustMapping(t *testing.T) { ctx := t.Context() t.Run("success - blocks existing mapping", func(t *testing.T) { - existingMapping := trust.OIDCMapping{ - IssuerURL: "https://issuer.example.com", - Blocked: false, - } - repo := trustmock.NewInMemRepository( - trustmock.WithTrust("tenant-123", existingMapping), + existingMapping := trustv1.Trust_builder{ + TenantId: new("tenant-123"), + Blocked: new(false), + Oidc: oidcv1.OIDC_builder{ + Issuer: new("https://issuer.example.com"), + }.Build(), + }.Build() + repo := mocktrust.NewInMemRepository( + mocktrust.WithTrust(existingMapping), ) - svc := trust.NewService(repo) - server := grpc.NewOIDCMappingServer(svc) + svc := newTrust(repo) + server := grpc.NewTrustMappingServer(svc) - req := &oidcmappingv1.BlockOIDCMappingRequest{ - TenantId: "tenant-123", - } + req := trustmappingv1.BlockTrustMappingRequest_builder{ + TenantId: new("tenant-123"), + }.Build() - resp, err := server.BlockOIDCMapping(ctx, req) + resp, err := server.BlockTrustMapping(ctx, req) require.NoError(t, err) assert.NotNil(t, resp) @@ -188,21 +204,24 @@ func TestBlockOIDCMapping(t *testing.T) { }) t.Run("success - already blocked", func(t *testing.T) { - existingMapping := trust.OIDCMapping{ - IssuerURL: "https://issuer.example.com", - Blocked: true, - } - repo := trustmock.NewInMemRepository( - trustmock.WithTrust("tenant-123", existingMapping), + existingMapping := trustv1.Trust_builder{ + TenantId: new("tenant-123"), + Blocked: new(true), + Oidc: oidcv1.OIDC_builder{ + Issuer: new("https://issuer.example.com"), + }.Build(), + }.Build() + repo := mocktrust.NewInMemRepository( + mocktrust.WithTrust(existingMapping), ) - svc := trust.NewService(repo) - server := grpc.NewOIDCMappingServer(svc) + svc := newTrust(repo) + server := grpc.NewTrustMappingServer(svc) - req := &oidcmappingv1.BlockOIDCMappingRequest{ - TenantId: "tenant-123", - } + req := trustmappingv1.BlockTrustMappingRequest_builder{ + TenantId: new("tenant-123"), + }.Build() - resp, err := server.BlockOIDCMapping(ctx, req) + resp, err := server.BlockTrustMapping(ctx, req) require.NoError(t, err) assert.NotNil(t, resp) @@ -210,17 +229,17 @@ func TestBlockOIDCMapping(t *testing.T) { }) t.Run("not found - returns success", func(t *testing.T) { - repo := trustmock.NewInMemRepository( - trustmock.WithGetError(serviceerr.ErrNotFound), + repo := mocktrust.NewInMemRepository( + mocktrust.WithGetError(serviceerr.ErrNotFound), ) - svc := trust.NewService(repo) - server := grpc.NewOIDCMappingServer(svc) + svc := newTrust(repo) + server := grpc.NewTrustMappingServer(svc) - req := &oidcmappingv1.BlockOIDCMappingRequest{ - TenantId: "tenant-123", - } + req := trustmappingv1.BlockTrustMappingRequest_builder{ + TenantId: new("tenant-123"), + }.Build() - resp, err := server.BlockOIDCMapping(ctx, req) + resp, err := server.BlockTrustMapping(ctx, req) require.NoError(t, err) assert.NotNil(t, resp) @@ -229,17 +248,17 @@ func TestBlockOIDCMapping(t *testing.T) { t.Run("error - returns grpc error with message", func(t *testing.T) { internalErr := errors.New("database error") - repo := trustmock.NewInMemRepository( - trustmock.WithGetError(internalErr), + repo := mocktrust.NewInMemRepository( + mocktrust.WithGetError(internalErr), ) - svc := trust.NewService(repo) - server := grpc.NewOIDCMappingServer(svc) + svc := newTrust(repo) + server := grpc.NewTrustMappingServer(svc) - req := &oidcmappingv1.BlockOIDCMappingRequest{ - TenantId: "tenant-123", - } + req := trustmappingv1.BlockTrustMappingRequest_builder{ + TenantId: new("tenant-123"), + }.Build() - resp, err := server.BlockOIDCMapping(ctx, req) + resp, err := server.BlockTrustMapping(ctx, req) require.Error(t, err) assert.NotNil(t, resp) @@ -249,28 +268,31 @@ func TestBlockOIDCMapping(t *testing.T) { st, ok := status.FromError(err) require.True(t, ok) assert.Equal(t, codes.Internal, st.Code()) - assert.Contains(t, st.Message(), "failed to block OIDC mapping") + assert.Contains(t, st.Message(), "failed to block Trust mapping") }) } -func TestRemoveOIDCMapping(t *testing.T) { +func TestRemoveTrustMapping(t *testing.T) { ctx := t.Context() t.Run("success - removes existing mapping", func(t *testing.T) { - existingMapping := trust.OIDCMapping{ - IssuerURL: "https://issuer.example.com", - } - repo := trustmock.NewInMemRepository( - trustmock.WithTrust("tenant-123", existingMapping), + existingMapping := trustv1.Trust_builder{ + TenantId: new("tenant-123"), + Oidc: oidcv1.OIDC_builder{ + Issuer: new("https://issuer.example.com"), + }.Build(), + }.Build() + repo := mocktrust.NewInMemRepository( + mocktrust.WithTrust(existingMapping), ) - svc := trust.NewService(repo) - server := grpc.NewOIDCMappingServer(svc) + svc := newTrust(repo) + server := grpc.NewTrustMappingServer(svc) - req := &oidcmappingv1.RemoveOIDCMappingRequest{ - TenantId: "tenant-123", - } + req := trustmappingv1.RemoveTrustMappingRequest_builder{ + TenantId: new("tenant-123"), + }.Build() - resp, err := server.RemoveOIDCMapping(ctx, req) + resp, err := server.RemoveTrustMapping(ctx, req) require.NoError(t, err) assert.NotNil(t, resp) @@ -280,17 +302,17 @@ func TestRemoveOIDCMapping(t *testing.T) { t.Run("error - returns grpc error with message", func(t *testing.T) { deleteErr := errors.New("delete failed") - repo := trustmock.NewInMemRepository( - trustmock.WithDeleteError(deleteErr), + repo := mocktrust.NewInMemRepository( + mocktrust.WithDeleteError(deleteErr), ) - svc := trust.NewService(repo) - server := grpc.NewOIDCMappingServer(svc) + svc := newTrust(repo) + server := grpc.NewTrustMappingServer(svc) - req := &oidcmappingv1.RemoveOIDCMappingRequest{ - TenantId: "tenant-123", - } + req := trustmappingv1.RemoveTrustMappingRequest_builder{ + TenantId: new("tenant-123"), + }.Build() - resp, err := server.RemoveOIDCMapping(ctx, req) + resp, err := server.RemoveTrustMapping(ctx, req) require.Error(t, err) assert.NotNil(t, resp) @@ -300,21 +322,21 @@ func TestRemoveOIDCMapping(t *testing.T) { st, ok := status.FromError(err) require.True(t, ok) assert.Equal(t, codes.Internal, st.Code()) - assert.Contains(t, st.Message(), "failed to remove OIDC mapping") + assert.Contains(t, st.Message(), "failed to remove Trust mapping") }) t.Run("error - delete is indempotent", func(t *testing.T) { - repo := trustmock.NewInMemRepository( - trustmock.WithDeleteError(serviceerr.ErrNotFound), + repo := mocktrust.NewInMemRepository( + mocktrust.WithDeleteError(serviceerr.ErrNotFound), ) - svc := trust.NewService(repo) - server := grpc.NewOIDCMappingServer(svc) + svc := newTrust(repo) + server := grpc.NewTrustMappingServer(svc) - req := &oidcmappingv1.RemoveOIDCMappingRequest{ - TenantId: "tenant-123", - } + req := trustmappingv1.RemoveTrustMappingRequest_builder{ + TenantId: new("tenant-123"), + }.Build() - resp, err := server.RemoveOIDCMapping(ctx, req) + resp, err := server.RemoveTrustMapping(ctx, req) require.NoError(t, err) assert.NotNil(t, resp) @@ -323,25 +345,28 @@ func TestRemoveOIDCMapping(t *testing.T) { }) } -func TestUnblockOIDCMapping(t *testing.T) { +func TestUnblockTrustMapping(t *testing.T) { ctx := t.Context() t.Run("success - unblocks blocked mapping", func(t *testing.T) { - existingMapping := trust.OIDCMapping{ - IssuerURL: "https://issuer.example.com", - Blocked: true, - } - repo := trustmock.NewInMemRepository( - trustmock.WithTrust("tenant-123", existingMapping), + existingMapping := trustv1.Trust_builder{ + TenantId: new("tenant-123"), + Blocked: new(true), + Oidc: oidcv1.OIDC_builder{ + Issuer: new("https://issuer.example.com"), + }.Build(), + }.Build() + repo := mocktrust.NewInMemRepository( + mocktrust.WithTrust(existingMapping), ) - svc := trust.NewService(repo) - server := grpc.NewOIDCMappingServer(svc) + svc := newTrust(repo) + server := grpc.NewTrustMappingServer(svc) - req := &oidcmappingv1.UnblockOIDCMappingRequest{ - TenantId: "tenant-123", - } + req := trustmappingv1.UnblockTrustMappingRequest_builder{ + TenantId: new("tenant-123"), + }.Build() - resp, err := server.UnblockOIDCMapping(ctx, req) + resp, err := server.UnblockTrustMapping(ctx, req) require.NoError(t, err) assert.NotNil(t, resp) @@ -350,21 +375,24 @@ func TestUnblockOIDCMapping(t *testing.T) { }) t.Run("success - already unblocked", func(t *testing.T) { - existingMapping := trust.OIDCMapping{ - IssuerURL: "https://issuer.example.com", - Blocked: false, - } - repo := trustmock.NewInMemRepository( - trustmock.WithTrust("tenant-123", existingMapping), + existingMapping := trustv1.Trust_builder{ + TenantId: new("tenant-123"), + Blocked: new(false), + Oidc: oidcv1.OIDC_builder{ + Issuer: new("https://issuer.example.com"), + }.Build(), + }.Build() + repo := mocktrust.NewInMemRepository( + mocktrust.WithTrust(existingMapping), ) - svc := trust.NewService(repo) - server := grpc.NewOIDCMappingServer(svc) + svc := newTrust(repo) + server := grpc.NewTrustMappingServer(svc) - req := &oidcmappingv1.UnblockOIDCMappingRequest{ - TenantId: "tenant-123", - } + req := trustmappingv1.UnblockTrustMappingRequest_builder{ + TenantId: new("tenant-123"), + }.Build() - resp, err := server.UnblockOIDCMapping(ctx, req) + resp, err := server.UnblockTrustMapping(ctx, req) require.NoError(t, err) assert.NotNil(t, resp) @@ -372,17 +400,17 @@ func TestUnblockOIDCMapping(t *testing.T) { }) t.Run("not found - returns success", func(t *testing.T) { - repo := trustmock.NewInMemRepository( - trustmock.WithGetError(serviceerr.ErrNotFound), + repo := mocktrust.NewInMemRepository( + mocktrust.WithGetError(serviceerr.ErrNotFound), ) - svc := trust.NewService(repo) - server := grpc.NewOIDCMappingServer(svc) + svc := newTrust(repo) + server := grpc.NewTrustMappingServer(svc) - req := &oidcmappingv1.UnblockOIDCMappingRequest{ - TenantId: "tenant-123", - } + req := trustmappingv1.UnblockTrustMappingRequest_builder{ + TenantId: new("tenant-123"), + }.Build() - resp, err := server.UnblockOIDCMapping(ctx, req) + resp, err := server.UnblockTrustMapping(ctx, req) require.NoError(t, err) assert.NotNil(t, resp) @@ -391,22 +419,25 @@ func TestUnblockOIDCMapping(t *testing.T) { t.Run("error - returns grpc error with message", func(t *testing.T) { internalErr := errors.New("update failed") - existingMapping := trust.OIDCMapping{ - IssuerURL: "https://issuer.example.com", - Blocked: true, - } - repo := trustmock.NewInMemRepository( - trustmock.WithTrust("tenant-123", existingMapping), - trustmock.WithUpdateError(internalErr), + existingMapping := trustv1.Trust_builder{ + TenantId: new("tenant-123"), + Blocked: new(true), + Oidc: oidcv1.OIDC_builder{ + Issuer: new("https://issuer.example.com"), + }.Build(), + }.Build() + repo := mocktrust.NewInMemRepository( + mocktrust.WithTrust(existingMapping), + mocktrust.WithUpdateError(internalErr), ) - svc := trust.NewService(repo) - server := grpc.NewOIDCMappingServer(svc) + svc := newTrust(repo) + server := grpc.NewTrustMappingServer(svc) - req := &oidcmappingv1.UnblockOIDCMappingRequest{ - TenantId: "tenant-123", - } + req := trustmappingv1.UnblockTrustMappingRequest_builder{ + TenantId: new("tenant-123"), + }.Build() - resp, err := server.UnblockOIDCMapping(ctx, req) + resp, err := server.UnblockTrustMapping(ctx, req) require.Error(t, err) assert.NotNil(t, resp) @@ -416,6 +447,6 @@ func TestUnblockOIDCMapping(t *testing.T) { st, ok := status.FromError(err) require.True(t, ok) assert.Equal(t, codes.Internal, st.Code()) - assert.Contains(t, st.Message(), "failed to unblock OIDC mapping") + assert.Contains(t, st.Message(), "failed to unblock Trust mapping") }) } diff --git a/internal/grpc/session.go b/internal/grpc/session.go index ab226a3f..10728e6b 100644 --- a/internal/grpc/session.go +++ b/internal/grpc/session.go @@ -17,14 +17,15 @@ import ( rpcv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/rpc/v1" sessionv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/sessionmanager/session/v1" + oidcv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/oidc/v1" typesv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/types/v1" slogctx "github.com/veqryn/slog-context" grpccodes "google.golang.org/grpc/codes" + sessionmanager "github.com/openkcm/session-manager" "github.com/openkcm/session-manager/internal/credentials" "github.com/openkcm/session-manager/internal/debugtools" "github.com/openkcm/session-manager/internal/session" - "github.com/openkcm/session-manager/internal/trust" ) const defaultIntrospectionCacheExpiration = 30 * time.Second @@ -35,7 +36,7 @@ type SessionServer struct { sessionv1.UnimplementedServiceServer sessionRepo session.Repository - trustRepo trust.OIDCMappingRepository + trust sessionmanager.Trust newCreds credentials.Builder queryParametersIntrospect []string @@ -50,14 +51,14 @@ type SessionServer struct { func NewSessionServer( ctx context.Context, sessionRepo session.Repository, - trustRepo trust.OIDCMappingRepository, + trust sessionmanager.Trust, idleSessionTimeout time.Duration, clientID string, opts ...SessionServerOption, ) *SessionServer { s := &SessionServer{ sessionRepo: sessionRepo, - trustRepo: trustRepo, + trust: trust, idleSessionTimeout: idleSessionTimeout, newCreds: func(clientID string) credentials.TransportCredentials { return credentials.NewInsecure(clientID) }, clientID: clientID, @@ -108,15 +109,15 @@ func (s *SessionServer) GetSession(ctx context.Context, req *sessionv1.GetSessio return &sessionv1.GetSessionResponse{Valid: false}, nil } - // Get trust mapping for the given tenant ID - mapping, err := s.trustRepo.Get(ctx, req.GetTenantId()) + // Get trust for the given tenant ID + trust, err := s.trust.Get(ctx, req.GetTenantId()) if err != nil { span.RecordError(err) span.SetStatus(codes.Error, "failed to get an oidc mapping") slogctx.Warn(ctx, "Is this an attack? Could not get trust mapping", "issuer", sess.Issuer, "error", err) return &sessionv1.GetSessionResponse{Valid: false}, nil } - if mapping.Blocked { + if trust.GetBlocked() { slogctx.Warn(ctx, "Tenant is blocked", "issuer", sess.Issuer) span.SetStatus(codes.Ok, "the tenant is blocked") st := status.New(grpccodes.FailedPrecondition, "the tenant is blocked") @@ -163,7 +164,7 @@ func (s *SessionServer) GetSession(ctx context.Context, req *sessionv1.GetSessio } // Introspect access token - result, err := s.introspectToken(ctx, sess.AccessToken, &mapping) + result, err := s.introspectToken(ctx, sess.AccessToken, trust.GetOidc()) if err != nil { span.RecordError(err) span.SetStatus(codes.Error, "failed to introspect an access token") @@ -198,33 +199,35 @@ func (s *SessionServer) GetOIDCProvider(ctx context.Context, req *sessionv1.GetO ctx, span := tracer.Tracer("").Start(ctx, "get_oidc_provider") defer span.End() - provider, err := s.trustRepo.Get(ctx, req.GetTenantId()) + provider, err := s.trust.Get(ctx, req.GetTenantId()) if err != nil { span.RecordError(err) span.SetStatus(codes.Error, "failed to get an oidc provider") return nil, fmt.Errorf("getting odic provider: %w", err) } + oidc := provider.GetOidc() + span.SetStatus(codes.Ok, "") return &sessionv1.GetOIDCProviderResponse{ Provider: &typesv1.OIDCProvider{ - IssuerUrl: provider.IssuerURL, - JwksUri: provider.JWKSURI, - Audiences: provider.Audiences, + IssuerUrl: oidc.GetIssuer(), + JwksUri: oidc.GetJwksUri(), + Audiences: oidc.GetAudiences(), }, }, nil } -func (s *SessionServer) getClientID(mapping *trust.OIDCMapping) string { - if mapping.ClientID != "" { - return mapping.ClientID +func (s *SessionServer) getClientID(oidcTrust *oidcv1.OIDC) string { + if clientID := oidcTrust.GetClientId(); clientID != "" { + return clientID } return s.clientID } -func (s *SessionServer) httpClient(mapping *trust.OIDCMapping) *http.Client { - creds := s.newCreds(s.getClientID(mapping)) +func (s *SessionServer) httpClient(oidcTrust *oidcv1.OIDC) *http.Client { + creds := s.newCreds(s.getClientID(oidcTrust)) transport := creds.Transport() if debugSettingSMDumpTransport.Value() == "1" { transport = debugtools.NewTransport(transport) @@ -235,7 +238,7 @@ func (s *SessionServer) httpClient(mapping *trust.OIDCMapping) *http.Client { } } -func (s *SessionServer) introspectToken(ctx context.Context, token string, oidcTrust *trust.OIDCMapping) (oidc.Introspection, error) { +func (s *SessionServer) introspectToken(ctx context.Context, token string, oidcTrust *oidcv1.OIDC) (oidc.Introspection, error) { // first check the cache for a recent introspection result for this token hashedSuffix := sha256.Sum256([]byte(token)) cacheKey := base64.RawURLEncoding.EncodeToString(hashedSuffix[:]) @@ -246,13 +249,12 @@ func (s *SessionServer) introspectToken(ctx context.Context, token string, oidcT httpClient := s.httpClient(oidcTrust) // create the provider for the given issuer - provider, err := oidc.NewProvider(oidcTrust.IssuerURL, oidcTrust.Audiences, - oidc.WithIntrospectQueryParameters(oidcTrust.GetIntrospectParameters(s.queryParametersIntrospect)), + provider, err := oidc.NewProvider(oidcTrust.GetIssuer(), oidcTrust.GetAudiences(), oidc.WithAllowHttpScheme(s.allowHttpScheme), oidc.WithSecureHTTPClient(httpClient), ) if err != nil { - slogctx.Error(ctx, "Could not create OpenID provider", "issuer", oidcTrust.IssuerURL, "error", err) + slogctx.Error(ctx, "Could not create OpenID provider", "issuer", oidcTrust.GetIssuer(), "error", err) return oidc.Introspection{Active: false}, err } @@ -263,7 +265,7 @@ func (s *SessionServer) introspectToken(ctx context.Context, token string, oidcT slogctx.Debug(ctx, "No introspection endpoint configured", "issuer", provider.Issuer) return oidc.Introspection{Active: true}, nil } - slogctx.Error(ctx, "Could not introspect access token", "error", err) + slogctx.Error(ctx, "Could not introspect token", "error", err) return oidc.Introspection{Active: false}, err } diff --git a/internal/grpc/session_test.go b/internal/grpc/session_test.go index 2dda10fe..69335f25 100644 --- a/internal/grpc/session_test.go +++ b/internal/grpc/session_test.go @@ -16,34 +16,37 @@ import ( rpcv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/rpc/v1" sessionv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/sessionmanager/session/v1" + oidcv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/oidc/v1" + trustv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/v1" "github.com/openkcm/session-manager/internal/grpc" "github.com/openkcm/session-manager/internal/session" sessionmock "github.com/openkcm/session-manager/internal/session/mock" - "github.com/openkcm/session-manager/internal/trust" - "github.com/openkcm/session-manager/internal/trust/trustmock" + mocktrust "github.com/openkcm/session-manager/modules/oidctrust/mocks" ) func TestNewSessionServer(t *testing.T) { ctx := t.Context() t.Run("creates server successfully", func(t *testing.T) { sessionRepo := sessionmock.NewInMemRepository() - trustRepo := trustmock.NewInMemRepository() + trustRepo := mocktrust.NewInMemRepository() + trust := newTrust(trustRepo) idleSessionTimeout := 90 * time.Minute - server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, idleSessionTimeout, "") + server := grpc.NewSessionServer(ctx, sessionRepo, trust, idleSessionTimeout, "") assert.NotNil(t, server) }) t.Run("creates server with options", func(t *testing.T) { sessionRepo := sessionmock.NewInMemRepository() - trustRepo := trustmock.NewInMemRepository() + trustRepo := mocktrust.NewInMemRepository() + trust := newTrust(trustRepo) idleSessionTimeout := 90 * time.Minute server := grpc.NewSessionServer(ctx, sessionRepo, - trustRepo, + trust, idleSessionTimeout, "", grpc.WithQueryParametersIntrospect([]string{"param1", "param2"}), @@ -54,12 +57,13 @@ func TestNewSessionServer(t *testing.T) { t.Run("handles nil option gracefully", func(t *testing.T) { sessionRepo := sessionmock.NewInMemRepository() - trustRepo := trustmock.NewInMemRepository() + trustRepo := mocktrust.NewInMemRepository() + trust := newTrust(trustRepo) idleSessionTimeout := 90 * time.Minute server := grpc.NewSessionServer(ctx, sessionRepo, - trustRepo, + trust, idleSessionTimeout, "", nil, @@ -108,10 +112,13 @@ func TestGetSession(t *testing.T) { AuthContext: map[string]string{"key": "value"}, } - mapping := trust.OIDCMapping{ - IssuerURL: testServer.URL, - Blocked: false, - } + mapping := trustv1.Trust_builder{ + TenantId: new(sess.TenantID), + Blocked: new(false), + Oidc: oidcv1.OIDC_builder{ + Issuer: new(testServer.URL), + }.Build(), + }.Build() sessionRepo := sessionmock.NewInMemRepository( sessionmock.WithSession(sess), @@ -119,14 +126,11 @@ func TestGetSession(t *testing.T) { // Mark session as active _ = sessionRepo.BumpActive(ctx, sess.ID, 1*time.Hour) - trustRepo := trustmock.NewInMemRepository( - trustmock.WithTrust(sess.TenantID, mapping), - ) - - server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute, "", + trustRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(mapping)) + trust := newTrust(trustRepo) + server := grpc.NewSessionServer(ctx, sessionRepo, trust, 90*time.Minute, "", grpc.WithAllowHttpScheme(true), ) - req := &sessionv1.GetSessionRequest{ SessionId: "session-123", TenantId: "tenant-123", @@ -180,21 +184,22 @@ func TestGetSession(t *testing.T) { }, } - mapping := trust.OIDCMapping{ - IssuerURL: testServer.URL, - Blocked: false, - } + mapping := trustv1.Trust_builder{ + TenantId: new(sess.TenantID), + Blocked: new(false), + Oidc: oidcv1.OIDC_builder{ + Issuer: new(testServer.URL), + }.Build(), + }.Build() sessionRepo := sessionmock.NewInMemRepository( sessionmock.WithSession(sess), ) _ = sessionRepo.BumpActive(ctx, sess.ID, 1*time.Hour) - trustRepo := trustmock.NewInMemRepository( - trustmock.WithTrust(sess.TenantID, mapping), - ) - - server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute, "", + trustRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(mapping)) + trust := newTrust(trustRepo) + server := grpc.NewSessionServer(ctx, sessionRepo, trust, 90*time.Minute, "", grpc.WithAllowHttpScheme(true), ) @@ -236,21 +241,22 @@ func TestGetSession(t *testing.T) { }, } - mapping := trust.OIDCMapping{ - IssuerURL: testServer.URL, - Blocked: false, - } + mapping := trustv1.Trust_builder{ + TenantId: new(sess.TenantID), + Blocked: new(false), + Oidc: oidcv1.OIDC_builder{ + Issuer: new(testServer.URL), + }.Build(), + }.Build() sessionRepo := sessionmock.NewInMemRepository( sessionmock.WithSession(sess), ) _ = sessionRepo.BumpActive(ctx, sess.ID, 1*time.Hour) - trustRepo := trustmock.NewInMemRepository( - trustmock.WithTrust(sess.TenantID, mapping), - ) - - server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute, "", + trustRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(mapping)) + trust := newTrust(trustRepo) + server := grpc.NewSessionServer(ctx, sessionRepo, trust, 90*time.Minute, "", grpc.WithAllowHttpScheme(true), ) @@ -272,9 +278,9 @@ func TestGetSession(t *testing.T) { sessionRepo := sessionmock.NewInMemRepository( sessionmock.WithIsActiveError(isActiveErr), ) - trustRepo := trustmock.NewInMemRepository() - - server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute, "") + trustRepo := mocktrust.NewInMemRepository() + trust := newTrust(trustRepo) + server := grpc.NewSessionServer(ctx, sessionRepo, trust, 90*time.Minute, "") req := &sessionv1.GetSessionRequest{ SessionId: "session-123", @@ -301,9 +307,9 @@ func TestGetSession(t *testing.T) { ) // Don't bump active - session is not active - trustRepo := trustmock.NewInMemRepository() - - server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute, "") + trustRepo := mocktrust.NewInMemRepository() + trust := newTrust(trustRepo) + server := grpc.NewSessionServer(ctx, sessionRepo, trust, 90*time.Minute, "") req := &sessionv1.GetSessionRequest{ SessionId: "session-789", @@ -328,9 +334,9 @@ func TestGetSession(t *testing.T) { assert.NoError(t, err) _ = sessionRepo.BumpActive(ctx, sess.ID, 1*time.Hour) - trustRepo := trustmock.NewInMemRepository() - - server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute, "") + trustRepo := mocktrust.NewInMemRepository() + trust := newTrust(trustRepo) + server := grpc.NewSessionServer(ctx, sessionRepo, trust, 90*time.Minute, "") req := &sessionv1.GetSessionRequest{ SessionId: "session-fail", @@ -359,9 +365,9 @@ func TestGetSession(t *testing.T) { _ = sessionRepo.BumpActive(ctx, sess.ID, 1*time.Hour) // No mapping added to repo - trustRepo := trustmock.NewInMemRepository() - - server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute, "") + trustRepo := mocktrust.NewInMemRepository() + trust := newTrust(trustRepo) + server := grpc.NewSessionServer(ctx, sessionRepo, trust, 90*time.Minute, "") req := &sessionv1.GetSessionRequest{ SessionId: "session-no-provider", @@ -389,15 +395,17 @@ func TestGetSession(t *testing.T) { ) _ = sessionRepo.BumpActive(ctx, sess.ID, 1*time.Hour) - mapping := trust.OIDCMapping{ - IssuerURL: "https://issuer.example.com", - Blocked: true, // Mapping is blocked - } - trustRepo := trustmock.NewInMemRepository( - trustmock.WithTrust(sess.TenantID, mapping), - ) + mapping := trustv1.Trust_builder{ + TenantId: new(sess.TenantID), + Blocked: new(true), + Oidc: oidcv1.OIDC_builder{ + Issuer: new("https://issuer.example.com"), + }.Build(), + }.Build() + trustRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(mapping)) + trust := newTrust(trustRepo) - server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute, "") + server := grpc.NewSessionServer(ctx, sessionRepo, trust, 90*time.Minute, "") req := &sessionv1.GetSessionRequest{ SessionId: "session-blocked", @@ -437,15 +445,17 @@ func TestGetSession(t *testing.T) { ) _ = sessionRepo.BumpActive(ctx, sess.ID, 1*time.Hour) - mapping := trust.OIDCMapping{ - IssuerURL: "https://issuer.example.com", - Blocked: false, - } - trustRepo := trustmock.NewInMemRepository( - trustmock.WithTrust(sess.TenantID, mapping), - ) + mapping := trustv1.Trust_builder{ + TenantId: new(sess.TenantID), + Blocked: new(false), + Oidc: oidcv1.OIDC_builder{ + Issuer: new("https://issuer.example.com"), + }.Build(), + }.Build() + trustRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(mapping)) + trust := newTrust(trustRepo) - server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute, "") + server := grpc.NewSessionServer(ctx, sessionRepo, trust, 90*time.Minute, "") req := &sessionv1.GetSessionRequest{ SessionId: "session-fingerprint", @@ -473,15 +483,17 @@ func TestGetSession(t *testing.T) { ) _ = sessionRepo.BumpActive(ctx, sess.ID, 1*time.Hour) - mapping := trust.OIDCMapping{ - IssuerURL: "https://issuer.example.com", - Blocked: false, - } - trustRepo := trustmock.NewInMemRepository( - trustmock.WithTrust("wrong-tenant", mapping), - ) + mapping := trustv1.Trust_builder{ + TenantId: new("wrong-tenant"), + Blocked: new(false), + Oidc: oidcv1.OIDC_builder{ + Issuer: new("https://issuer.example.com"), + }.Build(), + }.Build() + trustRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(mapping)) + trust := newTrust(trustRepo) - server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute, "") + server := grpc.NewSessionServer(ctx, sessionRepo, trust, 90*time.Minute, "") req := &sessionv1.GetSessionRequest{ SessionId: "session-tenant", @@ -509,15 +521,17 @@ func TestGetSession(t *testing.T) { ) _ = sessionRepo.BumpActive(ctx, sess.ID, 1*time.Hour) - mapping := trust.OIDCMapping{ - IssuerURL: "https://invalid-issuer-no-server.example.com", - Blocked: false, - } - trustRepo := trustmock.NewInMemRepository( - trustmock.WithTrust(sess.TenantID, mapping), - ) + mapping := trustv1.Trust_builder{ + TenantId: new(sess.TenantID), + Blocked: new(false), + Oidc: oidcv1.OIDC_builder{ + Issuer: new("https://invalid-issuer-no-server.example.com"), + }.Build(), + }.Build() + trustRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(mapping)) + trust := newTrust(trustRepo) - server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute, "") + server := grpc.NewSessionServer(ctx, sessionRepo, trust, 90*time.Minute, "") req := &sessionv1.GetSessionRequest{ SessionId: "session-config-fail", @@ -561,15 +575,17 @@ func TestGetSession(t *testing.T) { ) _ = sessionRepo.BumpActive(ctx, sess.ID, 1*time.Hour) - mapping := trust.OIDCMapping{ - IssuerURL: testServer.URL, - Blocked: false, - } - trustRepo := trustmock.NewInMemRepository( - trustmock.WithTrust(sess.TenantID, mapping), - ) - - server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute, "", + mapping := trustv1.Trust_builder{ + TenantId: new(sess.TenantID), + Blocked: new(false), + Oidc: oidcv1.OIDC_builder{ + Issuer: new(testServer.URL), + }.Build(), + }.Build() + trustRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(mapping)) + trust := newTrust(trustRepo) + + server := grpc.NewSessionServer(ctx, sessionRepo, trust, 90*time.Minute, "", grpc.WithAllowHttpScheme(true), ) @@ -617,15 +633,17 @@ func TestGetSession(t *testing.T) { ) _ = sessionRepo.BumpActive(ctx, sess.ID, 1*time.Hour) - mapping := trust.OIDCMapping{ - IssuerURL: testServer.URL, - Blocked: false, - } - trustRepo := trustmock.NewInMemRepository( - trustmock.WithTrust(sess.TenantID, mapping), - ) - - server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute, "", + mapping := trustv1.Trust_builder{ + TenantId: new(sess.TenantID), + Blocked: new(false), + Oidc: oidcv1.OIDC_builder{ + Issuer: new(testServer.URL), + }.Build(), + }.Build() + trustRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(mapping)) + trust := newTrust(trustRepo) + + server := grpc.NewSessionServer(ctx, sessionRepo, trust, 90*time.Minute, "", grpc.WithAllowHttpScheme(true), ) @@ -666,15 +684,17 @@ func TestGetSession(t *testing.T) { ) _ = sessionRepo.BumpActive(ctx, sess.ID, 1*time.Hour) - mapping := trust.OIDCMapping{ - IssuerURL: testServer.URL, - Blocked: false, - } - trustRepo := trustmock.NewInMemRepository( - trustmock.WithTrust(sess.TenantID, mapping), - ) - - server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute, "", + mapping := trustv1.Trust_builder{ + TenantId: new(sess.TenantID), + Blocked: new(false), + Oidc: oidcv1.OIDC_builder{ + Issuer: new(testServer.URL), + }.Build(), + }.Build() + trustRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(mapping)) + trust := newTrust(trustRepo) + + server := grpc.NewSessionServer(ctx, sessionRepo, trust, 90*time.Minute, "", grpc.WithAllowHttpScheme(true), ) @@ -702,15 +722,10 @@ func TestWithQueryParametersIntrospect(t *testing.T) { // Test that the option actually sets the parameters sessionRepo := sessionmock.NewInMemRepository() - trustRepo := trustmock.NewInMemRepository() + trustRepo := mocktrust.NewInMemRepository() + trust := newTrust(trustRepo) - server := grpc.NewSessionServer(ctx, - sessionRepo, - trustRepo, - 90*time.Minute, - "", - opt, - ) + server := grpc.NewSessionServer(ctx, sessionRepo, trust, 90*time.Minute, "", opt) assert.NotNil(t, server) }) @@ -720,18 +735,20 @@ func TestGetOIDCProvider(t *testing.T) { ctx := t.Context() t.Run("success - returns OIDC provider", func(t *testing.T) { - mapping := trust.OIDCMapping{ - IssuerURL: "https://issuer.example.com", - JWKSURI: "https://issuer.example.com/.well-known/jwks.json", - Audiences: []string{"audience1", "audience2"}, - } - + mapping := trustv1.Trust_builder{ + TenantId: new("tenant-123"), + Blocked: new(false), + Oidc: oidcv1.OIDC_builder{ + Issuer: new("https://issuer.example.com"), + JwksUri: new("https://issuer.example.com/.well-known/jwks.json"), + Audiences: []string{"audience1", "audience2"}, + }.Build(), + }.Build() + trustRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(mapping)) + trust := newTrust(trustRepo) sessionRepo := sessionmock.NewInMemRepository() - trustRepo := trustmock.NewInMemRepository( - trustmock.WithTrust("tenant-123", mapping), - ) - server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute, "") + server := grpc.NewSessionServer(ctx, sessionRepo, trust, 90*time.Minute, "") req := &sessionv1.GetOIDCProviderRequest{ TenantId: "tenant-123", @@ -749,10 +766,9 @@ func TestGetOIDCProvider(t *testing.T) { t.Run("error - provider not found", func(t *testing.T) { sessionRepo := sessionmock.NewInMemRepository() - trustRepo := trustmock.NewInMemRepository() - - server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute, "") - + trustRepo := mocktrust.NewInMemRepository() + trust := newTrust(trustRepo) + server := grpc.NewSessionServer(ctx, sessionRepo, trust, 90*time.Minute, "") req := &sessionv1.GetOIDCProviderRequest{ TenantId: "non-existent-tenant", } @@ -766,12 +782,11 @@ func TestGetOIDCProvider(t *testing.T) { t.Run("error - repository returns error", func(t *testing.T) { sessionRepo := sessionmock.NewInMemRepository() - trustRepo := trustmock.NewInMemRepository( - trustmock.WithGetError(errors.New("database connection error")), + trustRepo := mocktrust.NewInMemRepository( + mocktrust.WithGetError(errors.New("database connection error")), ) - - server := grpc.NewSessionServer(ctx, sessionRepo, trustRepo, 90*time.Minute, "") - + trust := newTrust(trustRepo) + server := grpc.NewSessionServer(ctx, sessionRepo, trust, 90*time.Minute, "") req := &sessionv1.GetOIDCProviderRequest{ TenantId: "tenant-123", } diff --git a/internal/grpc/trustmapping.go b/internal/grpc/trustmapping.go new file mode 100644 index 00000000..54b4a936 --- /dev/null +++ b/internal/grpc/trustmapping.go @@ -0,0 +1,139 @@ +package grpc + +import ( + "context" + "errors" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + trustmappingv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/sessionmanager/trustmapping/v1" + oidcv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/oidc/v1" + trustv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/v1" + slogctx "github.com/veqryn/slog-context" + + sessionmanager "github.com/openkcm/session-manager" + "github.com/openkcm/session-manager/pkg/serviceerr" +) + +type TrustMappingServer struct { + trustmappingv1.UnimplementedServiceServer + + trust sessionmanager.Trust +} + +func NewTrustMappingServer(trust sessionmanager.Trust) *TrustMappingServer { + srv := &TrustMappingServer{ + trust: trust, + } + + return srv +} + +func (srv *TrustMappingServer) ApplyTrustMapping(ctx context.Context, in *trustmappingv1.ApplyTrustMappingRequest) (*trustmappingv1.ApplyTrustMappingResponse, error) { + oidcIn := in.GetOidc() + oidc := oidcv1.OIDC_builder{ + TenantId: new(oidcIn.GetTenantId()), + Issuer: new(oidcIn.GetIssuer()), + JwksUri: new(oidcIn.GetJwksUri()), + Audiences: oidcIn.GetAudiences(), + ClientId: new(oidcIn.GetClientId()), + }.Build() + + trust := trustv1.Trust_builder{ + TenantId: new(in.GetTenantId()), + Oidc: oidc, + }.Build() + + ctx = slogctx.With(ctx, + "tenantId", trust.GetTenantId(), + "issuer", oidc.GetIssuer(), + "jwksUri", oidc.GetJwksUri(), + "audiences", oidc.GetAudiences(), + "client_id", oidc.GetClientId(), + ) + + slogctx.Debug(ctx, "ApplyTrustMapping called") + + response := trustmappingv1.ApplyTrustMappingResponse_builder{}.Build() + + if err := srv.trust.ApplyMapping(ctx, trust); err != nil { + slogctx.Error(ctx, "Could not apply Trust mapping", "error", err) + if errors.Is(err, serviceerr.ErrNotFound) { + msg := serviceerr.ErrNotFound.Error() + response.SetMessage(msg) + return response, nil + } + + return nil, status.Errorf(codes.Internal, "failed to apply Trust mapping: %v", err) + } + + response.SetSuccess(true) + + return response, nil +} + +// BlockTrustMapping blocks the Trust mapping for the specified tenant. +// It calls the underlying service to set the mapping as blocked. +// Returns a response containing an optional error message if blocking fails. +func (srv *TrustMappingServer) BlockTrustMapping(ctx context.Context, req *trustmappingv1.BlockTrustMappingRequest) (*trustmappingv1.BlockTrustMappingResponse, error) { + ctx = slogctx.With(ctx, "tenantId", req.GetTenantId()) + slogctx.Debug(ctx, "BlockTrustMapping called") + + resp := trustmappingv1.BlockTrustMappingResponse_builder{}.Build() + err := srv.trust.BlockMapping(ctx, req.GetTenantId()) + if err != nil { + slogctx.Error(ctx, "Could not block Trust mapping", "error", err) + msg := err.Error() + + resp.SetMessage(msg) + return resp, status.Error(codes.Internal, "failed to block Trust mapping: "+msg) + } + + resp.SetSuccess(true) + return resp, nil +} + +// RemoveTrustMapping removes the Trust configuration for the tenant. +// It calls the underlying service to remove the mapping. +// Returns a respose containing an optional error message if removing fails. +func (srv *TrustMappingServer) RemoveTrustMapping(ctx context.Context, req *trustmappingv1.RemoveTrustMappingRequest) (*trustmappingv1.RemoveTrustMappingResponse, error) { + ctx = slogctx.With(ctx, "tenantId", req.GetTenantId()) + slogctx.Debug(ctx, "RemoveTrustMapping called") + + resp := &trustmappingv1.RemoveTrustMappingResponse{} + err := srv.trust.RemoveMapping(ctx, req.GetTenantId()) + if err != nil { + if !errors.Is(err, serviceerr.ErrNotFound) { + slogctx.Error(ctx, "Could not remove Trust mapping", "error", err) + msg := err.Error() + resp.SetMessage(msg) + return resp, status.Error(codes.Internal, "failed to remove Trust mapping: "+msg) + } else { + slogctx.Warn(ctx, "RemoveTrustMapping is called but the tenant does not exist", "error", err) + } + } + + resp.SetSuccess(true) + return resp, nil +} + +// UnblockTrustMapping unblocks the Trust mapping for the specified tenant. +// It calls the underlying service to set the mapping as unblocked. +// Returns a response containing an optional error message if unblocking fails. +func (srv *TrustMappingServer) UnblockTrustMapping(ctx context.Context, req *trustmappingv1.UnblockTrustMappingRequest) (*trustmappingv1.UnblockTrustMappingResponse, error) { + ctx = slogctx.With(ctx, "tenantId", req.GetTenantId()) + slogctx.Debug(ctx, "UnblockTrustMapping called") + + resp := &trustmappingv1.UnblockTrustMappingResponse{} + err := srv.trust.UnblockMapping(ctx, req.GetTenantId()) + if err != nil { + slogctx.Error(ctx, "Could not unblock Trust mapping", "error", err) + msg := err.Error() + resp.SetMessage(msg) + return resp, status.Error(codes.Internal, "failed to unblock Trust mapping: "+msg) + } + + resp.SetSuccess(true) + return resp, nil +} diff --git a/internal/session/housekeeper.go b/internal/session/housekeeper.go index e70cc9c9..1d0928a1 100644 --- a/internal/session/housekeeper.go +++ b/internal/session/housekeeper.go @@ -96,12 +96,14 @@ func (m *Manager) housekeepSession(ctx context.Context, s Session, refreshTrigge // refreshAccessToken refreshes the access token for the given session using its refresh token. func (m *Manager) refreshAccessToken(ctx context.Context, s Session) error { - mapping, err := m.trustRepo.Get(ctx, s.TenantID) + trust, err := m.trust.Get(ctx, s.TenantID) if err != nil { return fmt.Errorf("could not get trust mapping: %w", err) } - openidConf, err := m.getOpenIDConfig(ctx, mapping.IssuerURL) + oidc := trust.GetOidc() + + openidConf, err := m.getOpenIDConfig(ctx, oidc.GetIssuer()) if err != nil { return fmt.Errorf("could not get OpenID configuration: %w", err) } @@ -109,12 +111,6 @@ func (m *Manager) refreshAccessToken(ctx context.Context, s Session) error { data := url.Values{} data.Set("grant_type", "refresh_token") data.Set("refresh_token", s.RefreshToken) - for _, parameter := range m.queryParametersToken { - value, ok := mapping.Properties[parameter] - if ok { - data.Set(parameter, value) - } - } req, err := http.NewRequestWithContext(ctx, http.MethodPost, openidConf.TokenEndpoint, bytes.NewBufferString(data.Encode())) if err != nil { @@ -122,7 +118,7 @@ func (m *Manager) refreshAccessToken(ctx context.Context, s Session) error { } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - client := m.httpClient(mapping) + client := m.httpClient(oidc) resp, err := client.Do(req) if err != nil { return err diff --git a/internal/session/housekeeper_test.go b/internal/session/housekeeper_test.go index e466a63b..4367d445 100644 --- a/internal/session/housekeeper_test.go +++ b/internal/session/housekeeper_test.go @@ -2,6 +2,7 @@ package session_test import ( "encoding/json" + "errors" "net/http" "net/http/httptest" "testing" @@ -10,12 +11,14 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + oidcv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/oidc/v1" + trustv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/v1" + "github.com/openkcm/session-manager/internal/config" - "github.com/openkcm/session-manager/internal/serviceerr" + "github.com/openkcm/session-manager/pkg/serviceerr" "github.com/openkcm/session-manager/internal/session" sessionmock "github.com/openkcm/session-manager/internal/session/mock" - "github.com/openkcm/session-manager/internal/trust" - "github.com/openkcm/session-manager/internal/trust/trustmock" + mocktrust "github.com/openkcm/session-manager/modules/oidctrust/mocks" ) func TestDeleteIdleSessions(t *testing.T) { @@ -92,7 +95,6 @@ func TestRefreshAccessToken(t *testing.T) { assert.Equal(t, "refresh_token", r.Form.Get("grant_type")) assert.Equal(t, "old-refresh-token", r.Form.Get("refresh_token")) assert.Equal(t, "test-client-id", r.Form.Get("client_id")) - assert.Equal(t, "param-value", r.Form.Get("test-param")) w.Header().Set("Content-Type", "application/json") _ = json.NewEncoder(w).Encode(map[string]any{ @@ -104,14 +106,15 @@ func TestRefreshAccessToken(t *testing.T) { defer tokenServer.Close() tokenServerURL = tokenServer.URL + "/token" - mapping := trust.OIDCMapping{ - IssuerURL: discoveryServerURL, - Properties: map[string]string{ - "test-param": "param-value", - }, - } + mapping := trustv1.Trust_builder{ + TenantId: new(tenantID), + Oidc: oidcv1.OIDC_builder{ + Issuer: new(discoveryServerURL), + }.Build(), + }.Build() - oidcRepo := trustmock.NewInMemRepository(trustmock.WithTrust(tenantID, mapping)) + oidcRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(mapping)) + trust := newTrust(oidcRepo) sess := session.Session{ ID: sessionID, @@ -130,13 +133,12 @@ func TestRefreshAccessToken(t *testing.T) { ClientAuth: config.ClientAuth{ ClientID: "test-client-id", }, - AdditionalQueryParametersToken: []string{"test-param"}, - CSRFSecretParsed: []byte(testCSRFSecret), + CSRFSecretParsed: []byte(testCSRFSecret), } manager, err := session.NewManager(ctx, cfg, - oidcRepo, + trust, sessions, nil, session.WithAllowHttpScheme(true), @@ -155,7 +157,8 @@ func TestRefreshAccessToken(t *testing.T) { }) t.Run("Error - trust mapping not found", func(t *testing.T) { - oidcRepo := trustmock.NewInMemRepository() + oidcRepo := mocktrust.NewInMemRepository() + trust := newTrust(oidcRepo) sess := session.Session{ ID: sessionID, @@ -176,7 +179,7 @@ func TestRefreshAccessToken(t *testing.T) { CSRFSecretParsed: []byte(testCSRFSecret), } - manager, err := session.NewManager(ctx, cfg, oidcRepo, sessions, nil) + manager, err := session.NewManager(ctx, cfg, trust, sessions, nil) require.NoError(t, err) // Trigger housekeeping - should log error but not fail @@ -205,12 +208,14 @@ func TestRefreshAccessToken(t *testing.T) { defer tokenServer.Close() tokenServerURL = tokenServer.URL + "/token" - mapping := trust.OIDCMapping{ - IssuerURL: discoveryServerURL, - Properties: map[string]string{}, - } + mapping := trustv1.Trust_builder{ + TenantId: new(tenantID), + Oidc: oidcv1.OIDC_builder{ + Issuer: new(discoveryServerURL), + }.Build(), + }.Build() - oidcRepo := trustmock.NewInMemRepository(trustmock.WithTrust(tenantID, mapping)) + oidcRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(mapping)) sess := session.Session{ ID: sessionID, @@ -234,7 +239,7 @@ func TestRefreshAccessToken(t *testing.T) { manager, err := session.NewManager(ctx, cfg, - oidcRepo, + newTrust(oidcRepo), sessions, nil, session.WithAllowHttpScheme(true), @@ -263,12 +268,14 @@ func TestRefreshAccessToken(t *testing.T) { defer discoveryServer.Close() discoveryServerURL = discoveryServer.URL - mapping := trust.OIDCMapping{ - IssuerURL: discoveryServer.URL, - Properties: map[string]string{}, // Missing required parameter - } + mapping := trustv1.Trust_builder{ + TenantId: new(tenantID), + Oidc: oidcv1.OIDC_builder{ + Issuer: new(discoveryServer.URL), + }.Build(), + }.Build() - oidcRepo := trustmock.NewInMemRepository(trustmock.WithTrust(tenantID, mapping)) + oidcRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(mapping)) sess := session.Session{ ID: sessionID, @@ -286,13 +293,12 @@ func TestRefreshAccessToken(t *testing.T) { ClientAuth: config.ClientAuth{ ClientID: "test-client-id", }, - AdditionalQueryParametersToken: []string{"missing-param"}, - CSRFSecretParsed: []byte(testCSRFSecret), + CSRFSecretParsed: []byte(testCSRFSecret), } manager, err := session.NewManager(ctx, cfg, - oidcRepo, + newTrust(oidcRepo), sessions, nil, session.WithAllowHttpScheme(true), @@ -367,4 +373,135 @@ func TestHousekeepSession_ErrorCases(t *testing.T) { require.NoError(t, err) assert.Equal(t, "original-token", updatedSess.AccessToken) }) + + t.Run("IsActive returns error — session is skipped, housekeeping does not fail", func(t *testing.T) { + sessionID := "test-session-id-isactive-err" + + sess := session.Session{ + ID: sessionID, + TenantID: "test-tenant", + AccessTokenExpiry: time.Now().Add(2 * time.Hour), + Expiry: time.Now().Add(2 * time.Hour), + } + + sessions := sessionmock.NewInMemRepository( + sessionmock.WithSession(sess), + sessionmock.WithIsActiveError(errors.New("valkey unavailable")), + ) + + cfg := &config.SessionManager{ + CSRFSecretParsed: []byte(testCSRFSecret), + } + + manager, err := session.NewManager(ctx, cfg, nil, sessions, nil) + require.NoError(t, err) + + // Housekeeping must not surface the IsActive error. + err = manager.TriggerHousekeeping(ctx, 1, time.Hour) + require.NoError(t, err) + + // Session must still exist — no deletion was attempted. + _, err = sessions.LoadSession(ctx, sessionID) + require.NoError(t, err) + }) + + t.Run("DeleteSession returns error — housekeeping continues without failing", func(t *testing.T) { + sessionID := "test-session-id-delete-err" + + sess := session.Session{ + ID: sessionID, + TenantID: "test-tenant", + AccessTokenExpiry: time.Now().Add(2 * time.Hour), + Expiry: time.Now().Add(2 * time.Hour), + } + + // Inactive session (no BumpActive) + DeleteSession always errors. + sessions := sessionmock.NewInMemRepository( + sessionmock.WithSession(sess), + sessionmock.WithDeleteSessionError(errors.New("delete failed")), + ) + + cfg := &config.SessionManager{ + CSRFSecretParsed: []byte(testCSRFSecret), + } + + manager, err := session.NewManager(ctx, cfg, nil, sessions, nil) + require.NoError(t, err) + + // Housekeeping must not surface the DeleteSession error. + err = manager.TriggerHousekeeping(ctx, 1, time.Hour) + require.NoError(t, err) + }) + + t.Run("StoreSession error during token refresh — housekeeping continues without failing", func(t *testing.T) { + tenantID := "test-tenant-store-err" + sessionID := "test-session-id-store-err" + + var discoveryServerURL string + discoveryServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "issuer": discoveryServerURL, + "token_endpoint": discoveryServerURL + "/token", + }) + })) + defer discoveryServer.Close() + discoveryServerURL = discoveryServer.URL + + // Token server returns a valid refresh response. + tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "new-access-token", + "refresh_token": "new-refresh-token", + "expires_in": 3600, + }) + })) + defer tokenServer.Close() + + mapping := trustv1.Trust_builder{ + TenantId: new(tenantID), + Oidc: oidcv1.OIDC_builder{ + Issuer: new(discoveryServerURL), + }.Build(), + }.Build() + + oidcRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(mapping)) + trust := newTrust(oidcRepo) + + sess := session.Session{ + ID: sessionID, + TenantID: tenantID, + RefreshToken: "old-refresh-token", + AccessToken: "old-access-token", + AccessTokenExpiry: time.Now().Add(10 * time.Second), // Near expiry. + Expiry: time.Now().Add(1 * time.Hour), + } + + sessions := sessionmock.NewInMemRepository( + sessionmock.WithSession(sess), + sessionmock.WithStoreSessionError(errors.New("store failed")), + ) + err := sessions.BumpActive(ctx, sessionID, time.Hour) + require.NoError(t, err) + + cfg := &config.SessionManager{ + ClientAuth: config.ClientAuth{ClientID: "client-id"}, + CSRFSecretParsed: []byte(testCSRFSecret), + } + + manager, err := session.NewManager(ctx, cfg, trust, sessions, nil, + session.WithAllowHttpScheme(true), + ) + require.NoError(t, err) + + // Housekeeping must not surface the StoreSession error. + err = manager.TriggerHousekeeping(ctx, 1, 1*time.Minute) + require.NoError(t, err) + + // Access token must remain unchanged — store failed. + updatedSess, err := sessions.LoadSession(ctx, sessionID) + require.NoError(t, err) + assert.Equal(t, "old-access-token", updatedSess.AccessToken) + }) } diff --git a/internal/session/import_test.go b/internal/session/import_test.go new file mode 100644 index 00000000..cfe5ca6f --- /dev/null +++ b/internal/session/import_test.go @@ -0,0 +1,12 @@ +package session_test + +import ( + _ "unsafe" + + sessionmanager "github.com/openkcm/session-manager" + "github.com/openkcm/session-manager/modules/oidctrust" + _ "github.com/openkcm/session-manager/modules/standard" +) + +//go:linkname newTrust github.com/openkcm/session-manager/modules/oidctrust.newOIDCTrustModuleWithRepo +func newTrust(r oidctrust.TrustRepository) sessionmanager.Trust diff --git a/internal/session/manager.go b/internal/session/manager.go index 12c2e124..17d5eb60 100644 --- a/internal/session/manager.go +++ b/internal/session/manager.go @@ -19,16 +19,19 @@ import ( "github.com/jellydator/ttlcache/v3" "github.com/openkcm/common-sdk/pkg/csrf" "github.com/openkcm/common-sdk/pkg/oidc" + "google.golang.org/protobuf/proto" + flowv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/oidc/flow/v1" + oidcv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/oidc/v1" otlpaudit "github.com/openkcm/common-sdk/pkg/otlp/audit" slogctx "github.com/veqryn/slog-context" + sessionmanager "github.com/openkcm/session-manager" "github.com/openkcm/session-manager/internal/config" "github.com/openkcm/session-manager/internal/credentials" "github.com/openkcm/session-manager/internal/debugtools" "github.com/openkcm/session-manager/internal/pkce" - "github.com/openkcm/session-manager/internal/serviceerr" - "github.com/openkcm/session-manager/internal/trust" + "github.com/openkcm/session-manager/pkg/serviceerr" ) const defaultWKOCCacheExpiration = 30 * time.Minute @@ -40,20 +43,16 @@ const ( ) type Manager struct { - trustRepo trust.OIDCMappingRepository - sessions Repository - pkce pkce.Source - audit *otlpaudit.AuditLogger - newCreds credentials.Builder - - sessionDuration time.Duration - idleSessionTimeout time.Duration - callbackURL *url.URL - clientID string - queryParametersAuth []string - queryParametersToken []string - authContextKeys []string - queryParametersLogout []string + trust sessionmanager.Trust + sessions Repository + pkce pkce.Source + audit *otlpaudit.AuditLogger + newCreds credentials.Builder + + sessionDuration time.Duration + idleSessionTimeout time.Duration + callbackURL *url.URL + clientID string sessionCookieTemplate config.CookieTemplate csrfCookieTemplate config.CookieTemplate @@ -70,7 +69,7 @@ type Manager struct { func NewManager( ctx context.Context, cfg *config.SessionManager, - trustRepo trust.OIDCMappingRepository, + trust sessionmanager.Trust, sessionsRepo Repository, auditLogger *otlpaudit.AuditLogger, opts ...ManagerOption, @@ -81,15 +80,11 @@ func NewManager( } m := &Manager{ - trustRepo: trustRepo, + trust: trust, sessions: sessionsRepo, audit: auditLogger, sessionDuration: cfg.SessionDuration, idleSessionTimeout: cfg.IdleSessionTimeout, - queryParametersAuth: cfg.AdditionalQueryParametersAuthorize, - queryParametersToken: cfg.AdditionalQueryParametersToken, - authContextKeys: cfg.AdditionalAuthContextKeys, - queryParametersLogout: cfg.AdditionalQueryParametersLogout, sessionCookieTemplate: cfg.SessionCookieTemplate, csrfCookieTemplate: cfg.CSRFCookieTemplate, loginCSRFCookieTemplate: cfg.LoginCSRFCookieTemplate, @@ -117,12 +112,14 @@ func NewManager( // MakeAuthURI returns an OIDC authentication URI. func (m *Manager) MakeAuthURI(ctx context.Context, tenantID, fingerprint, requestURI string) (string, string, error) { - mapping, err := m.trustRepo.Get(ctx, tenantID) + trust, err := m.trust.Get(ctx, tenantID) if err != nil { return "", "", fmt.Errorf("getting trust mapping: %w", err) } - openidConf, err := m.getOpenIDConfig(ctx, mapping.IssuerURL) + oidc := trust.GetOidc() + + openidConf, err := m.getOpenIDConfig(ctx, oidc.GetIssuer()) if err != nil { return "", "", fmt.Errorf("getting an openid config: %w", err) } @@ -146,7 +143,7 @@ func (m *Manager) MakeAuthURI(ctx context.Context, tenantID, fingerprint, reques return "", "", fmt.Errorf("storing session: %w", err) } - u, err := m.authURI(openidConf, state, pkce, mapping) + u, err := m.authURI(openidConf, state, pkce, oidc) if err != nil { return "", "", fmt.Errorf("generating auth uri: %w", err) } @@ -158,7 +155,7 @@ func (m *Manager) LoadState(ctx context.Context, stateID string) (State, error) return m.sessions.LoadState(ctx, stateID) } -func (m *Manager) authURI(openidConf *oidc.Configuration, state State, pkce pkce.PKCE, mapping trust.OIDCMapping) (string, error) { +func (m *Manager) authURI(openidConf *oidc.Configuration, state State, pkce pkce.PKCE, oidc *oidcv1.OIDC) (string, error) { u, err := url.Parse(openidConf.AuthorizationEndpoint) if err != nil { return "", fmt.Errorf("parsing authorisation endpoint url: %w", err) @@ -167,16 +164,15 @@ func (m *Manager) authURI(openidConf *oidc.Configuration, state State, pkce pkce q := u.Query() q.Set("scope", "openid profile email groups") q.Set("response_type", "code") - q.Set("client_id", m.getClientID(mapping)) + q.Set("client_id", m.getClientID(oidc)) q.Set("state", state.ID) q.Set("code_challenge", pkce.Challenge) q.Set("code_challenge_method", pkce.Method) q.Set("redirect_uri", m.callbackURL.String()) - for _, parameter := range m.queryParametersAuth { - value, ok := mapping.Properties[parameter] - if ok { - q.Set(parameter, value) - } + + //nolint:forcetypeassert + for _, param := range proto.GetExtension(oidc, flowv1.E_AuthAttributes).([]*flowv1.Attribute) { + q.Set(param.GetKey(), param.GetValue()) } u.RawQuery = q.Encode() @@ -230,19 +226,21 @@ func (m *Manager) FinaliseOIDCLogin(ctx context.Context, stateID, code, fingerpr return OIDCSessionData{}, serviceerr.ErrFingerprintMismatch } - mapping, err := m.trustRepo.Get(ctx, state.TenantID) + trust, err := m.trust.Get(ctx, state.TenantID) if err != nil { m.sendUserLoginFailureAudit(ctx, metadata, state.TenantID, "failed to get trust mapping") return OIDCSessionData{}, fmt.Errorf("getting trust mapping: %w", err) } - openidConf, err := m.getOpenIDConfig(ctx, mapping.IssuerURL) + oidc := trust.GetOidc() + + openidConf, err := m.getOpenIDConfig(ctx, oidc.GetIssuer()) if err != nil { m.sendUserLoginFailureAudit(ctx, metadata, state.TenantID, "failed to get openid configuration") return OIDCSessionData{}, fmt.Errorf("getting openid configuration: %w", err) } - tokens, err := m.exchangeCode(ctx, openidConf, code, state.PKCEVerifier, mapping) + tokens, err := m.exchangeCode(ctx, openidConf, code, state.PKCEVerifier, oidc) if err != nil { m.sendUserLoginFailureAudit(ctx, metadata, state.TenantID, "failed to exchange code for tokens") return OIDCSessionData{}, fmt.Errorf("exchanging code for tokens: %w", err) @@ -299,14 +297,13 @@ func (m *Manager) FinaliseOIDCLogin(ctx context.Context, stateID, code, fingerpr // prepare the auth context used by ExtAuthZ authContext := map[string]string{ - "issuer": mapping.IssuerURL, - "client_id": m.getClientID(mapping), + "issuer": oidc.GetIssuer(), + "client_id": m.getClientID(oidc), } - for _, parameter := range m.authContextKeys { - value, ok := mapping.Properties[parameter] - if ok { - authContext[parameter] = value - } + + //nolint:forcetypeassert + for _, param := range proto.GetExtension(oidc, flowv1.E_AuthContext).([]*flowv1.Attribute) { + authContext[param.GetKey()] = param.GetValue() } session := Session{ @@ -315,7 +312,7 @@ func (m *Manager) FinaliseOIDCLogin(ctx context.Context, stateID, code, fingerpr ProviderID: customClaims.SID, Fingerprint: fingerprint, CSRFToken: csrfToken, - Issuer: mapping.IssuerURL, + Issuer: oidc.GetIssuer(), Claims: Claims{ Subject: standardClaims.Subject, UserUUID: customClaims.UserUUID, @@ -376,15 +373,16 @@ func (m *Manager) Logout(ctx context.Context, sessionID, postLogoutRedirectURL s ctx = slogctx.With(ctx, "tenantId", session.TenantID) - mapping, err := m.trustRepo.Get(ctx, session.TenantID) + trust, err := m.trust.Get(ctx, session.TenantID) if err != nil { slogctx.Error(ctx, "failed to get trust mapping for a tenant", "error", err) return "", fmt.Errorf("getting trust mapping: %w", err) } - ctx = slogctx.With(ctx, "issuerUrl", mapping.IssuerURL) + oidc := trust.GetOidc() + ctx = slogctx.With(ctx, "issuer", oidc.GetIssuer()) - oidcConf, err := m.getOpenIDConfig(ctx, mapping.IssuerURL) + oidcConf, err := m.getOpenIDConfig(ctx, oidc.GetIssuer()) if err != nil { slogctx.Warn(ctx, "failed to get oidc configuration", "error", err) return "", fmt.Errorf("getting oidc configuration: %w", err) @@ -407,14 +405,12 @@ func (m *Manager) Logout(ctx context.Context, sessionID, postLogoutRedirectURL s } vals := make(url.Values, 2) - vals.Set("client_id", m.getClientID(mapping)) + vals.Set("client_id", m.getClientID(oidc)) vals.Set("post_logout_redirect_uri", postLogoutRedirectURL) - for _, parameter := range m.queryParametersLogout { - value, ok := mapping.Properties[parameter] - if ok { - vals.Set(parameter, value) - } + //nolint:forcetypeassert + for _, param := range proto.GetExtension(oidc, flowv1.E_LogoutAttributes).([]*flowv1.Attribute) { + vals.Set(param.GetKey(), param.GetValue()) } redirectURL.RawQuery = vals.Encode() @@ -473,12 +469,14 @@ func (m *Manager) BCLogout(ctx context.Context, logoutJWT string) error { return nil } - mapping, err := m.trustRepo.Get(ctx, session.TenantID) + trust, err := m.trust.Get(ctx, session.TenantID) if err != nil { return fmt.Errorf("getting trust mapping: %w", err) } - oidcConf, err := m.getOpenIDConfig(ctx, mapping.IssuerURL) + oidc := trust.GetOidc() + + oidcConf, err := m.getOpenIDConfig(ctx, oidc.GetIssuer()) if err != nil { return fmt.Errorf("getting oidc config: %w", err) } @@ -611,8 +609,8 @@ func (m *Manager) verifyAccessToken(accessToken, atHash string, idToken *jwt.JSO return nil } -func (m *Manager) httpClient(mapping trust.OIDCMapping) *http.Client { - creds := m.newCreds(m.getClientID(mapping)) +func (m *Manager) httpClient(oidc *oidcv1.OIDC) *http.Client { + creds := m.newCreds(m.getClientID(oidc)) transport := creds.Transport() if debugSettingSMDumpTransport.Value() == "1" { transport = debugtools.NewTransport(transport) @@ -623,25 +621,23 @@ func (m *Manager) httpClient(mapping trust.OIDCMapping) *http.Client { } } -func (m *Manager) getClientID(mapping trust.OIDCMapping) string { - if mapping.ClientID != "" { - return mapping.ClientID +func (m *Manager) getClientID(oidc *oidcv1.OIDC) string { + if clientID := oidc.GetClientId(); clientID != "" { + return clientID } return m.clientID } -func (m *Manager) exchangeCode(ctx context.Context, openidConf *oidc.Configuration, code, codeVerifier string, mapping trust.OIDCMapping) (tokenResponse, error) { +func (m *Manager) exchangeCode(ctx context.Context, openidConf *oidc.Configuration, code, codeVerifier string, oidc *oidcv1.OIDC) (tokenResponse, error) { data := url.Values{} data.Set("grant_type", "authorization_code") data.Set("code", code) data.Set("code_verifier", codeVerifier) data.Set("redirect_uri", m.callbackURL.String()) - for _, parameter := range m.queryParametersToken { - value, ok := mapping.Properties[parameter] - if ok { - data.Set(parameter, value) - } + //nolint:forcetypeassert + for _, param := range proto.GetExtension(oidc, flowv1.E_TokenAttributes).([]*flowv1.Attribute) { + data.Set(param.GetKey(), param.GetValue()) } req, err := http.NewRequestWithContext(ctx, http.MethodPost, openidConf.TokenEndpoint, strings.NewReader(data.Encode())) @@ -650,7 +646,7 @@ func (m *Manager) exchangeCode(ctx context.Context, openidConf *oidc.Configurati } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - client := m.httpClient(mapping) + client := m.httpClient(oidc) resp, err := client.Do(req) if err != nil { return tokenResponse{}, fmt.Errorf("executing request: %w", err) diff --git a/internal/session/manager_cookie_test.go b/internal/session/manager_cookie_test.go index 9927f821..5fac7436 100644 --- a/internal/session/manager_cookie_test.go +++ b/internal/session/manager_cookie_test.go @@ -7,11 +7,13 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + oidcv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/oidc/v1" + trustv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/v1" + "github.com/openkcm/session-manager/internal/config" "github.com/openkcm/session-manager/internal/session" sessionmock "github.com/openkcm/session-manager/internal/session/mock" - "github.com/openkcm/session-manager/internal/trust" - "github.com/openkcm/session-manager/internal/trust/trustmock" + mocktrust "github.com/openkcm/session-manager/modules/oidctrust/mocks" ) func TestManager_MakeSessionCookie(t *testing.T) { @@ -426,33 +428,35 @@ func TestManager_Logout(t *testing.T) { tests := []struct { name string cfg *config.SessionManager - setupOIDCRepo func(t *testing.T) *trustmock.Repository + setupOIDCRepo func(t *testing.T) *mocktrust.Repository setupSession func(*sessionmock.Repository) wantErr bool wantErrMessage string wantURL string }{ { - name: "Success - redirect to postLogoutURL when no end session endpoint", + name: "Success - redirect to postLogoutURL", cfg: &config.SessionManager{ CSRFSecretParsed: []byte(testCSRFSecret), ClientAuth: config.ClientAuth{ ClientID: testClientID, }, }, - setupOIDCRepo: func(t *testing.T) *trustmock.Repository { + setupOIDCRepo: func(t *testing.T) *mocktrust.Repository { t.Helper() // StartOIDCServer doesn't include end_session_endpoint, so it will fall back to postLogoutURL oidcServer := StartOIDCServer(t, false) t.Cleanup(oidcServer.Close) - mapping := trust.OIDCMapping{ - IssuerURL: oidcServer.URL, - JWKSURI: oidcServer.URL + "/jwks", - Audiences: []string{"test"}, - Properties: map[string]string{}, - } - return trustmock.NewInMemRepository(trustmock.WithTrust(tenantID, mapping)) + trust := trustv1.Trust_builder{ + TenantId: new(tenantID), + Oidc: oidcv1.OIDC_builder{ + Issuer: new(oidcServer.URL), + JwksUri: new(oidcServer.URL + "/jwks"), + Audiences: []string{"test"}, + }.Build(), + }.Build() + return mocktrust.NewInMemRepository(mocktrust.WithTrust(trust)) }, setupSession: func(repo *sessionmock.Repository) { //nolint:errcheck @@ -467,15 +471,14 @@ func TestManager_Logout(t *testing.T) { { name: "Error - session not found", cfg: &config.SessionManager{ - CSRFSecretParsed: []byte(testCSRFSecret), - PostLogoutRedirectURL: postLogoutURL, + CSRFSecretParsed: []byte(testCSRFSecret), ClientAuth: config.ClientAuth{ ClientID: testClientID, }, }, - setupOIDCRepo: func(t *testing.T) *trustmock.Repository { + setupOIDCRepo: func(t *testing.T) *mocktrust.Repository { t.Helper() - return trustmock.NewInMemRepository() + return mocktrust.NewInMemRepository() }, setupSession: func(repo *sessionmock.Repository) { // Don't store session - will cause error @@ -491,14 +494,9 @@ func TestManager_Logout(t *testing.T) { tt.setupSession(sessionRepo) oidcRepo := tt.setupOIDCRepo(t) + trust := newTrust(oidcRepo) - m, err := session.NewManager(ctx, - tt.cfg, - oidcRepo, - sessionRepo, - nil, - session.WithAllowHttpScheme(true), - ) + m, err := session.NewManager(ctx, tt.cfg, trust, sessionRepo, nil, session.WithAllowHttpScheme(true)) require.NoError(t, err) logoutURL, err := m.Logout(t.Context(), sessionID, postLogoutURL) diff --git a/internal/session/manager_test.go b/internal/session/manager_test.go index cae05e7b..ab79edad 100644 --- a/internal/session/manager_test.go +++ b/internal/session/manager_test.go @@ -20,14 +20,15 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + oidcv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/oidc/v1" + trustv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/v1" otlpaudit "github.com/openkcm/common-sdk/pkg/otlp/audit" "github.com/openkcm/session-manager/internal/config" "github.com/openkcm/session-manager/internal/credentials" "github.com/openkcm/session-manager/internal/session" sessionmock "github.com/openkcm/session-manager/internal/session/mock" - "github.com/openkcm/session-manager/internal/trust" - "github.com/openkcm/session-manager/internal/trust/trustmock" + mocktrust "github.com/openkcm/session-manager/modules/oidctrust/mocks" ) const ( @@ -49,20 +50,19 @@ func TestManager_Auth(t *testing.T) { auditServer := StartAuditServer(t) defer auditServer.Close() - oidcMapping := trust.OIDCMapping{ - IssuerURL: oidcServer.URL, - Blocked: false, - JWKSURI: "http://jwks.example.com", - Audiences: []string{requestURI}, - Properties: map[string]string{ - "paramAuth1": "paramAuth1", - "paramToken1": "paramToken1", - }, - } + oidcMapping := trustv1.Trust_builder{ + TenantId: new(tenantID), + Blocked: new(false), + Oidc: oidcv1.OIDC_builder{ + Issuer: new(oidcServer.URL), + JwksUri: new("http://jwks.example.com"), + Audiences: []string{requestURI}, + }.Build(), + }.Build() tests := []struct { name string - oidc *trustmock.Repository + oidc *mocktrust.Repository sessions *sessionmock.Repository requestURI string cfg *config.SessionManager @@ -70,17 +70,16 @@ func TestManager_Auth(t *testing.T) { fingerprint string wantURL string errAssert assert.ErrorAssertionFunc - mapping trust.OIDCMapping + mapping *trustv1.Trust }{ { name: "Success", - oidc: trustmock.NewInMemRepository(trustmock.WithTrust(tenantID, oidcMapping)), + oidc: mocktrust.NewInMemRepository(mocktrust.WithTrust(oidcMapping)), sessions: sessionmock.NewInMemRepository(), requestURI: requestURI, cfg: &config.SessionManager{ - SessionDuration: time.Hour, - CallbackURL: callbackURL, - AdditionalQueryParametersAuthorize: []string{"paramAuth1"}, + SessionDuration: time.Hour, + CallbackURL: callbackURL, ClientAuth: config.ClientAuth{ ClientID: testClientID, }, @@ -88,14 +87,14 @@ func TestManager_Auth(t *testing.T) { }, tenantID: tenantID, fingerprint: "fingerprint", - wantURL: oidcServer.URL + "/oauth2/authorize?client_id=my-client-id&code_challenge=someChallenge&code_challenge_method=S256&redirect_uri=" + callbackURL + "&response_type=code&scope=openid+profile+email+groups&state=someState¶mAuth1=paramAuth1", + wantURL: oidcServer.URL + "/oauth2/authorize?client_id=my-client-id&code_challenge=someChallenge&code_challenge_method=S256&redirect_uri=" + callbackURL + "&response_type=code&scope=openid+profile+email+groups&state=someState", errAssert: assert.NoError, }, { name: "Get trust mapping error", - oidc: trustmock.NewInMemRepository( - trustmock.WithTrust(tenantID, oidcMapping), - trustmock.WithGetError(errors.New("failed to get trust mapping")), + oidc: mocktrust.NewInMemRepository( + mocktrust.WithTrust(oidcMapping), + mocktrust.WithGetError(errors.New("failed to get trust mapping")), ), sessions: sessionmock.NewInMemRepository(), requestURI: requestURI, @@ -111,7 +110,7 @@ func TestManager_Auth(t *testing.T) { }, { name: "Save state error", - oidc: trustmock.NewInMemRepository(trustmock.WithTrust(tenantID, oidcMapping)), + oidc: mocktrust.NewInMemRepository(mocktrust.WithTrust(oidcMapping)), sessions: sessionmock.NewInMemRepository(sessionmock.WithStoreStateError(errors.New("failed to save state"))), requestURI: requestURI, cfg: &config.SessionManager{ @@ -143,7 +142,7 @@ func TestManager_Auth(t *testing.T) { m, err := session.NewManager(ctx, tt.cfg, - tt.oidc, + newTrust(tt.oidc), tt.sessions, auditLogger, session.WithAllowHttpScheme(true), @@ -232,7 +231,7 @@ func TestManager_FinaliseOIDCLogin(t *testing.T) { tests := []struct { name string - oidc *trustmock.Repository + oidc *mocktrust.Repository sessions *sessionmock.Repository stateID string code string @@ -246,17 +245,15 @@ func TestManager_FinaliseOIDCLogin(t *testing.T) { }{ { name: "Success", - oidc: trustmock.NewInMemRepository(), + oidc: mocktrust.NewInMemRepository(), sessions: sessionmock.NewInMemRepository(sessionmock.WithState(validState)), stateID: stateID, code: code, fingerprint: fingerprint, cfg: &config.SessionManager{ - SessionDuration: time.Hour, - CallbackURL: callbackURL, - AdditionalQueryParametersToken: []string{"queryParamToken1"}, - AdditionalAuthContextKeys: []string{"authContextKey1"}, - CSRFSecretParsed: []byte(testCSRFSecret), + SessionDuration: time.Hour, + CallbackURL: callbackURL, + CSRFSecretParsed: []byte(testCSRFSecret), }, wantSessionID: true, wantCSRFToken: true, @@ -265,7 +262,7 @@ func TestManager_FinaliseOIDCLogin(t *testing.T) { }, { name: "State load error", - oidc: trustmock.NewInMemRepository(), + oidc: mocktrust.NewInMemRepository(), sessions: sessionmock.NewInMemRepository(sessionmock.WithLoadStateError(errors.New("state not found"))), stateID: stateID, code: code, @@ -281,7 +278,7 @@ func TestManager_FinaliseOIDCLogin(t *testing.T) { }, { name: "State expired", - oidc: trustmock.NewInMemRepository(), + oidc: mocktrust.NewInMemRepository(), sessions: sessionmock.NewInMemRepository(sessionmock.WithState(expiredState)), stateID: stateID, code: code, @@ -296,7 +293,7 @@ func TestManager_FinaliseOIDCLogin(t *testing.T) { }, { name: "Fingerprint mismatch", - oidc: trustmock.NewInMemRepository(), + oidc: mocktrust.NewInMemRepository(), sessions: sessionmock.NewInMemRepository(sessionmock.WithState(mismatchState)), stateID: stateID, code: code, @@ -311,7 +308,7 @@ func TestManager_FinaliseOIDCLogin(t *testing.T) { }, { name: "Trust mapping get error", - oidc: trustmock.NewInMemRepository(trustmock.WithGetError(errors.New("trust mapping not found"))), + oidc: mocktrust.NewInMemRepository(mocktrust.WithGetError(errors.New("trust mapping not found"))), sessions: sessionmock.NewInMemRepository(sessionmock.WithState(validState)), stateID: stateID, code: code, @@ -326,7 +323,7 @@ func TestManager_FinaliseOIDCLogin(t *testing.T) { }, { name: "Token exchange error", - oidc: trustmock.NewInMemRepository(), + oidc: mocktrust.NewInMemRepository(), sessions: sessionmock.NewInMemRepository(sessionmock.WithState(validState)), stateID: stateID, code: code, @@ -357,22 +354,21 @@ func TestManager_FinaliseOIDCLogin(t *testing.T) { jwksURI, err := url.JoinPath(oidcServer.URL, "/.well-known/jwks.json") require.NoError(t, err) - localOIDCMapping := trust.OIDCMapping{ - IssuerURL: oidcServer.URL, - Blocked: false, - JWKSURI: jwksURI, - Audiences: []string{requestURI}, - Properties: map[string]string{ - "queryParamToken1": "queryParamToken1", - "authContextKey1": "authContextValue1", - }, - } + localOIDCMapping := trustv1.Trust_builder{ + TenantId: new(tenantID), + Blocked: new(false), + Oidc: oidcv1.OIDC_builder{ + Issuer: new(oidcServer.URL), + JwksUri: new(jwksURI), + Audiences: []string{requestURI}, + }.Build(), + }.Build() - tt.oidc.TAdd(tenantID, localOIDCMapping) + tt.oidc.TAdd(localOIDCMapping) m, err := session.NewManager(ctx, tt.cfg, - tt.oidc, + newTrust(tt.oidc), tt.sessions, auditLogger, session.WithAllowHttpScheme(true), @@ -464,7 +460,7 @@ func TestManager_BCLogout(t *testing.T) { name string cfg *config.SessionManager jwt string - setupMock func(*trustmock.Repository, *sessionmock.Repository) + setupMock func(*mocktrust.Repository, *sessionmock.Repository) errAssert assert.ErrorAssertionFunc }{ { @@ -478,10 +474,13 @@ func TestManager_BCLogout(t *testing.T) { Events: map[string]struct{}{"http://schemas.openid.net/event/backchannel-logout": {}}, SessionID: "sid-1", }), - setupMock: func(oidcs *trustmock.Repository, sessions *sessionmock.Repository) { - _ = oidcs.Create(context.Background(), "tid-1", trust.OIDCMapping{ - IssuerURL: jwksSrv.URL, - }) + setupMock: func(oidcs *mocktrust.Repository, sessions *sessionmock.Repository) { + _ = oidcs.Create(context.Background(), trustv1.Trust_builder{ + TenantId: new("tid-1"), + Oidc: oidcv1.OIDC_builder{ + Issuer: new(jwksSrv.URL), + }.Build(), + }.Build()) _ = sessions.StoreSession(context.Background(), session.Session{ID: "sid-1", TenantID: "tid-1"}) }, errAssert: assert.NoError, @@ -499,7 +498,8 @@ func TestManager_BCLogout(t *testing.T) { auditLogger, err := otlpaudit.NewLogger(&commoncfg.Audit{Endpoint: auditServer.URL}) require.NoError(t, err) - oidcMock := trustmock.NewInMemRepository() + oidcMock := mocktrust.NewInMemRepository() + trust := newTrust(oidcMock) sessionMock := sessionmock.NewInMemRepository() rt := localRoundTripper{ @@ -522,7 +522,7 @@ func TestManager_BCLogout(t *testing.T) { m, err := session.NewManager(ctx, tt.cfg, - oidcMock, + trust, sessionMock, auditLogger, session.WithTransportCredentials(newTCBuilder(rt)), @@ -547,13 +547,13 @@ func TestManager_LogoutEdgeCases(t *testing.T) { tests := []struct { name string sessionID string - setupMock func(*trustmock.Repository, *sessionmock.Repository) + setupMock func(*mocktrust.Repository, *sessionmock.Repository) errAssert assert.ErrorAssertionFunc }{ { name: "Session not found", sessionID: "non-existent", - setupMock: func(oidcs *trustmock.Repository, sessions *sessionmock.Repository) { + setupMock: func(oidcs *mocktrust.Repository, sessions *sessionmock.Repository) { _ = sessions.StoreSession(context.Background(), session.Session{ ID: sessionID, TenantID: tenantID, @@ -564,7 +564,7 @@ func TestManager_LogoutEdgeCases(t *testing.T) { { name: "Trust mapping not found", sessionID: sessionID, - setupMock: func(oidcs *trustmock.Repository, sessions *sessionmock.Repository) { + setupMock: func(oidcs *mocktrust.Repository, sessions *sessionmock.Repository) { _ = sessions.StoreSession(context.Background(), session.Session{ ID: sessionID, TenantID: tenantID, @@ -584,7 +584,8 @@ func TestManager_LogoutEdgeCases(t *testing.T) { auditLogger, err := otlpaudit.NewLogger(&commoncfg.Audit{Endpoint: auditServer.URL}) require.NoError(t, err) - oidcMock := trustmock.NewInMemRepository() + oidcMock := mocktrust.NewInMemRepository() + trust := newTrust(oidcMock) sessionMock := sessionmock.NewInMemRepository() tt.setupMock(oidcMock, sessionMock) @@ -596,7 +597,7 @@ func TestManager_LogoutEdgeCases(t *testing.T) { }, } - m, err := session.NewManager(ctx, cfg, oidcMock, sessionMock, auditLogger) + m, err := session.NewManager(ctx, cfg, trust, sessionMock, auditLogger) require.NoError(t, err) _, err = m.Logout(ctx, tt.sessionID, postLogoutURL) @@ -644,13 +645,13 @@ func TestManager_BCLogout_ErrorCases(t *testing.T) { tests := []struct { name string jwt string - setupMock func(*trustmock.Repository, *sessionmock.Repository) + setupMock func(*mocktrust.Repository, *sessionmock.Repository) errAssert assert.ErrorAssertionFunc }{ { name: "Invalid JWT", jwt: "invalid.jwt.token", - setupMock: func(oidcs *trustmock.Repository, sessions *sessionmock.Repository) { + setupMock: func(oidcs *mocktrust.Repository, sessions *sessionmock.Repository) { }, errAssert: assert.Error, }, @@ -663,7 +664,7 @@ func TestManager_BCLogout_ErrorCases(t *testing.T) { Events: map[string]struct{}{"http://invalid-event": {}}, SessionID: "sid-1", }), - setupMock: func(oidcs *trustmock.Repository, sessions *sessionmock.Repository) { + setupMock: func(oidcs *mocktrust.Repository, sessions *sessionmock.Repository) { }, errAssert: assert.Error, }, @@ -674,7 +675,7 @@ func TestManager_BCLogout_ErrorCases(t *testing.T) { }{ Events: map[string]struct{}{"http://schemas.openid.net/event/backchannel-logout": {}}, }), - setupMock: func(oidcs *trustmock.Repository, sessions *sessionmock.Repository) { + setupMock: func(oidcs *mocktrust.Repository, sessions *sessionmock.Repository) { }, errAssert: assert.Error, }, @@ -687,7 +688,7 @@ func TestManager_BCLogout_ErrorCases(t *testing.T) { Events: map[string]struct{}{"http://schemas.openid.net/event/backchannel-logout": {}}, SessionID: "non-existent-session", }), - setupMock: func(oidcs *trustmock.Repository, sessions *sessionmock.Repository) { + setupMock: func(oidcs *mocktrust.Repository, sessions *sessionmock.Repository) { }, errAssert: assert.NoError, }, @@ -703,7 +704,8 @@ func TestManager_BCLogout_ErrorCases(t *testing.T) { auditLogger, err := otlpaudit.NewLogger(&commoncfg.Audit{Endpoint: auditServer.URL}) require.NoError(t, err) - oidcMock := trustmock.NewInMemRepository() + oidcMock := mocktrust.NewInMemRepository() + trust := newTrust(oidcMock) sessionMock := sessionmock.NewInMemRepository() rt := localRoundTripper{ @@ -722,7 +724,7 @@ func TestManager_BCLogout_ErrorCases(t *testing.T) { CSRFSecretParsed: []byte(testCSRFSecret), } - m, err := session.NewManager(ctx, cfg, oidcMock, sessionMock, auditLogger, session.WithTransportCredentials(newTCBuilder(rt))) + m, err := session.NewManager(ctx, cfg, trust, sessionMock, auditLogger, session.WithTransportCredentials(newTCBuilder(rt))) require.NoError(t, err) err = m.BCLogout(ctx, tt.jwt) @@ -744,7 +746,9 @@ func TestManager_NewManager_Error(t *testing.T) { CSRFSecretParsed: []byte(testCSRFSecret), } - m, err := session.NewManager(ctx, cfg, trustmock.NewInMemRepository(), sessionmock.NewInMemRepository(), auditLogger) + trust := newTrust(mocktrust.NewInMemRepository()) + + m, err := session.NewManager(ctx, cfg, trust, sessionmock.NewInMemRepository(), auditLogger) assert.Error(t, err) assert.Nil(t, m) assert.Contains(t, err.Error(), "parsing callback URL") diff --git a/internal/session/mock/repository.go b/internal/session/mock/repository.go index 71c85a31..6d5989c3 100644 --- a/internal/session/mock/repository.go +++ b/internal/session/mock/repository.go @@ -4,7 +4,7 @@ import ( "context" "time" - "github.com/openkcm/session-manager/internal/serviceerr" + "github.com/openkcm/session-manager/pkg/serviceerr" "github.com/openkcm/session-manager/internal/session" ) diff --git a/internal/session/valkey/repository.go b/internal/session/valkey/repository.go index 58ef4cf5..87fa3e2f 100644 --- a/internal/session/valkey/repository.go +++ b/internal/session/valkey/repository.go @@ -10,7 +10,7 @@ import ( slogctx "github.com/veqryn/slog-context" - "github.com/openkcm/session-manager/internal/serviceerr" + "github.com/openkcm/session-manager/pkg/serviceerr" "github.com/openkcm/session-manager/internal/session" ) diff --git a/internal/session/valkey/store.go b/internal/session/valkey/store.go index 1750951b..b92622ce 100644 --- a/internal/session/valkey/store.go +++ b/internal/session/valkey/store.go @@ -11,7 +11,7 @@ import ( "github.com/valkey-io/valkey-go" - "github.com/openkcm/session-manager/internal/serviceerr" + "github.com/openkcm/session-manager/pkg/serviceerr" ) type store struct { diff --git a/internal/trust/mapping.go b/internal/trust/mapping.go deleted file mode 100644 index b273807e..00000000 --- a/internal/trust/mapping.go +++ /dev/null @@ -1,25 +0,0 @@ -package trust - -type OIDCMapping struct { - IssuerURL string - Blocked bool - JWKSURI string - Audiences []string - Properties map[string]string - - // ClientID is a client_id property used for authentication. - // It is an optional value for the trust config. If the trust's client id is not specified, - // the application-global client id is used. - ClientID string -} - -func (p *OIDCMapping) GetIntrospectParameters(keys []string) map[string]string { - params := make(map[string]string, len(keys)) - for _, parameter := range keys { - value, ok := p.Properties[parameter] - if ok { - params[parameter] = value - } - } - return params -} diff --git a/internal/trust/mapping_test.go b/internal/trust/mapping_test.go deleted file mode 100644 index 9b4c6274..00000000 --- a/internal/trust/mapping_test.go +++ /dev/null @@ -1,95 +0,0 @@ -package trust - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestOIDCMappingr_GetIntrospectParameters(t *testing.T) { - tests := []struct { - name string - oidcMapping OIDCMapping - keys []string - wantParams map[string]string - }{ - { - name: "returns matching parameters", - oidcMapping: OIDCMapping{ - Properties: map[string]string{ - "client_id": "my-client-id", - "client_secret": "my-secret", - "scope": "openid", - }, - }, - keys: []string{"client_id", "client_secret"}, - wantParams: map[string]string{ - "client_id": "my-client-id", - "client_secret": "my-secret", - }, - }, - { - name: "skips missing parameters", - oidcMapping: OIDCMapping{ - Properties: map[string]string{ - "client_id": "my-client-id", - }, - }, - keys: []string{"client_id", "missing_key"}, - wantParams: map[string]string{ - "client_id": "my-client-id", - }, - }, - { - name: "returns empty map when no keys provided", - oidcMapping: OIDCMapping{ - Properties: map[string]string{ - "client_id": "my-client-id", - }, - }, - keys: []string{}, - wantParams: map[string]string{}, - }, - { - name: "returns empty map when properties is nil", - oidcMapping: OIDCMapping{ - Properties: nil, - }, - keys: []string{"client_id"}, - wantParams: map[string]string{}, - }, - { - name: "returns empty map when no keys match", - oidcMapping: OIDCMapping{ - Properties: map[string]string{ - "client_id": "my-client-id", - }, - }, - keys: []string{"non_existent_key"}, - wantParams: map[string]string{}, - }, - { - name: "handles all keys matching", - oidcMapping: OIDCMapping{ - Properties: map[string]string{ - "key1": "value1", - "key2": "value2", - "key3": "value3", - }, - }, - keys: []string{"key1", "key2", "key3"}, - wantParams: map[string]string{ - "key1": "value1", - "key2": "value2", - "key3": "value3", - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := tt.oidcMapping.GetIntrospectParameters(tt.keys) - assert.Equal(t, tt.wantParams, got) - }) - } -} diff --git a/internal/trust/repository.go b/internal/trust/repository.go deleted file mode 100644 index a81d3472..00000000 --- a/internal/trust/repository.go +++ /dev/null @@ -1,11 +0,0 @@ -package trust - -import "context" - -// OIDCMappingRepository allows to read OIDC mapping data for a tenant stored in the context. -type OIDCMappingRepository interface { - Get(ctx context.Context, tenantID string) (OIDCMapping, error) - Create(ctx context.Context, tenantID string, mapping OIDCMapping) error - Delete(ctx context.Context, tenantID string) error - Update(ctx context.Context, tenantID string, mapping OIDCMapping) error -} diff --git a/internal/trust/service.go b/internal/trust/service.go deleted file mode 100644 index 8c4a2686..00000000 --- a/internal/trust/service.go +++ /dev/null @@ -1,95 +0,0 @@ -package trust - -import ( - "context" - "errors" - "fmt" - - "github.com/openkcm/session-manager/internal/serviceerr" -) - -type Service struct { - repository OIDCMappingRepository -} - -func NewService(repo OIDCMappingRepository) *Service { - return &Service{ - repository: repo, - } -} - -func (s *Service) ApplyMapping(ctx context.Context, tenantID string, mapping OIDCMapping) error { - _, err := s.repository.Get(ctx, tenantID) - if err != nil { - err = s.repository.Create(ctx, tenantID, mapping) - if err != nil { - return fmt.Errorf("creating mapping for tenant: %w", err) - } - } else { - err = s.repository.Update(ctx, tenantID, mapping) - if err != nil { - return fmt.Errorf("updating mapping for tenant: %w", err) - } - } - - return nil -} - -// BlockMapping sets the Blocked flag to true for the OIDC mapping associated with the given tenantID. -// If the mapping is already blocked, it does nothing. -// Returns an error if the mapping cannot be retrieved or updated. -func (s *Service) BlockMapping(ctx context.Context, tenantID string) error { - mapping, err := s.repository.Get(ctx, tenantID) - if err != nil { - if errors.Is(err, serviceerr.ErrNotFound) { - return nil - } - return fmt.Errorf("getting mapping for tenant: %w", err) - } - if mapping.Blocked { - return nil - } - mapping.Blocked = true - err = s.repository.Update(ctx, tenantID, mapping) - if err != nil { - if errors.Is(err, serviceerr.ErrNotFound) { - return nil - } - return fmt.Errorf("updating mapping for blocking tenant: %w", err) - } - return nil -} - -func (s *Service) RemoveMapping(ctx context.Context, tenantID string) error { - err := s.repository.Delete(ctx, tenantID) - if err != nil { - return fmt.Errorf("deleting mapping for tenant: %w", err) - } - - return nil -} - -// UnblockMapping sets the Blocked flag to false for the OIDC mapping associated with the given tenantID. -// If the mapping is not blocked, it does nothing. -// Returns an error if the mapping cannot be retrieved or updated. -func (s *Service) UnblockMapping(ctx context.Context, tenantID string) error { - mapping, err := s.repository.Get(ctx, tenantID) - if err != nil { - if errors.Is(err, serviceerr.ErrNotFound) { - return nil - } - return fmt.Errorf("getting mapping for tenant: %w", err) - } - if !mapping.Blocked { - return nil - } - mapping.Blocked = false - err = s.repository.Update(ctx, tenantID, mapping) - if err != nil { - if errors.Is(err, serviceerr.ErrNotFound) { - return nil - } - return fmt.Errorf("updating mapping for unblocking tenant: %w", err) - } - return nil -} diff --git a/internal/trust/service_test.go b/internal/trust/service_test.go deleted file mode 100644 index f216d22e..00000000 --- a/internal/trust/service_test.go +++ /dev/null @@ -1,540 +0,0 @@ -package trust_test - -import ( - "context" - "os" - "testing" - - "github.com/gofrs/uuid/v5" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - slogctx "github.com/veqryn/slog-context" - - "github.com/openkcm/session-manager/internal/trust" -) - -var repo trust.OIDCMappingRepository - -const ( - requestURI = "http://cmk.example.com/ui" - jwksURI = "http://jwks.example.com" -) - -func TestMain(m *testing.M) { - ctx := context.Background() - r, err := createRepo(ctx) - if err != nil { - slogctx.Error(ctx, "error while creating repo", "error", err) - } - - repo = r - - code := m.Run() - os.Exit(code) -} - -func TestService_ApplyMapping(t *testing.T) { - ctx := t.Context() - - t.Run("success if", func(t *testing.T) { - t.Run("the mapping does not exist", func(t *testing.T) { - expTenantID := uuid.Must(uuid.NewV4()).String() - expMapping := trust.OIDCMapping{ - IssuerURL: uuid.Must(uuid.NewV4()).String(), - JWKSURI: jwksURI, - Audiences: []string{requestURI}, - } - - wrapper := &RepoWrapper{Repo: repo} - subj := trust.NewService(wrapper) - - err := subj.ApplyMapping(ctx, expTenantID, expMapping) - assert.NoError(t, err) - - actMapping, err := wrapper.Repo.Get(ctx, expTenantID) - assert.NoError(t, err) - assert.Equal(t, expMapping, actMapping) - }) - - t.Run("the mapping exists", func(t *testing.T) { - expTenantID := uuid.Must(uuid.NewV4()).String() - expMapping := trust.OIDCMapping{ - IssuerURL: uuid.Must(uuid.NewV4()).String(), - JWKSURI: jwksURI, - Audiences: []string{requestURI}, - } - - wrapper := &RepoWrapper{Repo: repo} - subj := trust.NewService(wrapper) - - err := subj.ApplyMapping(ctx, expTenantID, expMapping) - assert.NoError(t, err) - - expUpdatedMapping := trust.OIDCMapping{ - IssuerURL: expMapping.IssuerURL, - JWKSURI: "http://updated-jwks.example.com", - Audiences: []string{requestURI, "http://new-aud.example.com"}, - } - - err = subj.ApplyMapping(ctx, expTenantID, expUpdatedMapping) - assert.NoError(t, err) - - actMapping, err := wrapper.Repo.Get(ctx, expTenantID) - assert.NoError(t, err) - assert.Equal(t, expUpdatedMapping, actMapping) - }) - }) - - t.Run("should return error if", func(t *testing.T) { - t.Run("Create returns an error", func(t *testing.T) { - expTenantID := uuid.Must(uuid.NewV4()).String() - expMapping := trust.OIDCMapping{ - IssuerURL: uuid.Must(uuid.NewV4()).String(), - JWKSURI: jwksURI, - Audiences: []string{requestURI}, - } - - wrapper := &RepoWrapper{Repo: repo} - noOfCalls := 0 - wrapper.MockCreate = func(ctx context.Context, tenantID string, mapping trust.OIDCMapping) error { - assert.Equal(t, expTenantID, tenantID) - assert.Equal(t, expMapping, mapping) - noOfCalls++ - return assert.AnError - } - - subj := trust.NewService(wrapper) - err := subj.ApplyMapping(ctx, expTenantID, expMapping) - - assert.ErrorIs(t, err, assert.AnError) - assert.Equal(t, 1, noOfCalls) - }) - - t.Run("Update returns an error", func(t *testing.T) { - expTenantID := uuid.Must(uuid.NewV4()).String() - expMapping := trust.OIDCMapping{ - IssuerURL: uuid.Must(uuid.NewV4()).String(), - JWKSURI: jwksURI, - Audiences: []string{requestURI}, - } - - wrapper := &RepoWrapper{Repo: repo} - noOfCalls := 0 - wrapper.MockUpdate = func(ctx context.Context, tenantID string, mapping trust.OIDCMapping) error { - assert.Equal(t, expTenantID, tenantID) - assert.Equal(t, expMapping, mapping) - noOfCalls++ - return assert.AnError - } - subj := trust.NewService(wrapper) - - err := subj.ApplyMapping(ctx, expTenantID, expMapping) - assert.NoError(t, err) - err = subj.ApplyMapping(ctx, expTenantID, expMapping) - - assert.ErrorIs(t, err, assert.AnError) - assert.Equal(t, 1, noOfCalls) - }) - }) -} - -func TestService_BlockMapping(t *testing.T) { - ctx := t.Context() - - t.Run("success if ", func(t *testing.T) { - t.Run("the mapping is unblocked", func(t *testing.T) { - // given - expTenantID := uuid.Must(uuid.NewV4()).String() - expUnblockedMapping := trust.OIDCMapping{ - IssuerURL: uuid.Must(uuid.NewV4()).String(), - Blocked: false, - JWKSURI: jwksURI, - Audiences: []string{requestURI}, - } - - wrapper := &RepoWrapper{Repo: repo} - err := wrapper.Repo.Create(ctx, expTenantID, expUnblockedMapping) - require.NoError(t, err) - subj := trust.NewService(wrapper) - - // when - err = subj.BlockMapping(ctx, expTenantID) - - // then - assert.NoError(t, err) - - actMapping, err := wrapper.Repo.Get(ctx, expTenantID) - assert.NoError(t, err) - assert.True(t, actMapping.Blocked) - assert.Equal(t, expUnblockedMapping.IssuerURL, actMapping.IssuerURL) - assert.Equal(t, expUnblockedMapping.Audiences, actMapping.Audiences) - assert.Equal(t, expUnblockedMapping.JWKSURI, actMapping.JWKSURI) - }) - - t.Run("the mapping is blocked then it should not call Update", func(t *testing.T) { - // given - expTenantID := uuid.Must(uuid.NewV4()).String() - expBlockedMapping := trust.OIDCMapping{ - IssuerURL: uuid.Must(uuid.NewV4()).String(), - Blocked: true, - JWKSURI: jwksURI, - Audiences: []string{requestURI}, - } - repoWrapper := &RepoWrapper{Repo: repo} - err := repoWrapper.Repo.Create(ctx, expTenantID, expBlockedMapping) - require.NoError(t, err) - - noOfUpdateCalls := 0 - repoWrapper.MockUpdate = func(ctx context.Context, tenantID string, mapping trust.OIDCMapping) error { - noOfUpdateCalls++ - return assert.AnError - } - subj := trust.NewService(repoWrapper) - - // when - err = subj.BlockMapping(t.Context(), expTenantID) - - // then - assert.NoError(t, err) - assert.Equal(t, 0, noOfUpdateCalls) - - actMapping, err := repoWrapper.Repo.Get(ctx, expTenantID) - assert.NoError(t, err) - assert.Equal(t, expBlockedMapping, actMapping) - }) - t.Run("the mapping is not found during the Update", func(t *testing.T) { - // given - expTenantID := uuid.Must(uuid.NewV4()).String() - expBlockedMapping := trust.OIDCMapping{ - IssuerURL: uuid.Must(uuid.NewV4()).String(), - Blocked: false, - JWKSURI: jwksURI, - Audiences: []string{requestURI}, - } - repoWrapper := &RepoWrapper{Repo: repo} - err := repoWrapper.Repo.Create(ctx, expTenantID, expBlockedMapping) - require.NoError(t, err) - - noOfUpdateCalls := 0 - repoWrapper.MockUpdate = func(ctx context.Context, tenantID string, mapping trust.OIDCMapping) error { - noOfUpdateCalls++ - // delete the mapping before updating to return an error - err := repoWrapper.Repo.Delete(ctx, expTenantID) - assert.NoError(t, err) - return nil - } - subj := trust.NewService(repoWrapper) - - // when - err = subj.BlockMapping(t.Context(), expTenantID) - - // then - assert.NoError(t, err) - assert.Equal(t, 1, noOfUpdateCalls) - }) - t.Run("the mapping is not found", func(t *testing.T) { - // given - expTenantID := uuid.Must(uuid.NewV4()).String() - repoWrapper := &RepoWrapper{Repo: repo} - - subj := trust.NewService(repoWrapper) - - // when - err := subj.BlockMapping(t.Context(), expTenantID) - - // then - assert.NoError(t, err) - }) - }) - - t.Run("should return error", func(t *testing.T) { - t.Run("if Get returns an error", func(t *testing.T) { - // given - expTenantID := uuid.Must(uuid.NewV4()).String() - repoWrapper := &RepoWrapper{Repo: repo} - - noOfGetCalls := 0 - repoWrapper.MockGet = func(ctx context.Context, tenantID string) (trust.OIDCMapping, error) { - assert.Equal(t, expTenantID, tenantID) - noOfGetCalls++ - return trust.OIDCMapping{}, assert.AnError - } - subj := trust.NewService(repoWrapper) - - // when - err := subj.BlockMapping(t.Context(), expTenantID) - - // then - assert.ErrorIs(t, err, assert.AnError) - assert.Equal(t, 1, noOfGetCalls) - }) - - t.Run("if Update returns an error", func(t *testing.T) { - // given - expTenantID := uuid.Must(uuid.NewV4()).String() - expMapping := trust.OIDCMapping{ - IssuerURL: uuid.Must(uuid.NewV4()).String(), - Blocked: false, - JWKSURI: jwksURI, - Audiences: []string{requestURI}, - } - repoWrapper := &RepoWrapper{Repo: repo} - err := repoWrapper.Repo.Create(ctx, expTenantID, expMapping) - require.NoError(t, err) - - noOfUpdateCalls := 0 - repoWrapper.MockUpdate = func(ctx context.Context, tenantID string, mapping trust.OIDCMapping) error { - assert.Equal(t, expTenantID, tenantID) - noOfUpdateCalls++ - return assert.AnError - } - subj := trust.NewService(repoWrapper) - - // when - err = subj.BlockMapping(t.Context(), expTenantID) - - // then - assert.ErrorIs(t, err, assert.AnError) - assert.Equal(t, 1, noOfUpdateCalls) - - actMapping, err := repoWrapper.Repo.Get(ctx, expTenantID) - assert.NoError(t, err) - assert.Equal(t, expMapping, actMapping) - }) - }) -} - -func TestService_UnblockMapping(t *testing.T) { - ctx := t.Context() - - t.Run("success if ", func(t *testing.T) { - t.Run("the mapping is blocked", func(t *testing.T) { - // given - expTenantID := uuid.Must(uuid.NewV4()).String() - expBlockedMapping := trust.OIDCMapping{ - IssuerURL: uuid.Must(uuid.NewV4()).String(), - Blocked: true, - JWKSURI: jwksURI, - Audiences: []string{requestURI}, - } - - wrapper := &RepoWrapper{Repo: repo} - err := wrapper.Repo.Create(ctx, expTenantID, expBlockedMapping) - require.NoError(t, err) - subj := trust.NewService(wrapper) - - // when - err = subj.UnblockMapping(t.Context(), expTenantID) - - // then - assert.NoError(t, err) - - actMapping, err := wrapper.Repo.Get(ctx, expTenantID) - assert.NoError(t, err) - assert.False(t, actMapping.Blocked) - assert.Equal(t, expBlockedMapping.IssuerURL, actMapping.IssuerURL) - assert.Equal(t, expBlockedMapping.Audiences, actMapping.Audiences) - assert.Equal(t, expBlockedMapping.JWKSURI, actMapping.JWKSURI) - }) - - t.Run("the mapping is unblocked then it should not call Update", func(t *testing.T) { - // given - expTenantID := uuid.Must(uuid.NewV4()).String() - expUnblockedMapping := trust.OIDCMapping{ - IssuerURL: uuid.Must(uuid.NewV4()).String(), - Blocked: false, - JWKSURI: jwksURI, - Audiences: []string{requestURI}, - } - repoWrapper := &RepoWrapper{Repo: repo} - err := repoWrapper.Repo.Create(ctx, expTenantID, expUnblockedMapping) - require.NoError(t, err) - - noOfUpdateCalls := 0 - repoWrapper.MockUpdate = func(ctx context.Context, tenantID string, mapping trust.OIDCMapping) error { - noOfUpdateCalls++ - return assert.AnError - } - subj := trust.NewService(repoWrapper) - - // when - err = subj.UnblockMapping(t.Context(), expTenantID) - - // then - assert.NoError(t, err) - assert.Equal(t, 0, noOfUpdateCalls) - - actMapping, err := repoWrapper.Repo.Get(ctx, expTenantID) - assert.NoError(t, err) - assert.False(t, actMapping.Blocked) - }) - t.Run("the mapping is not found during the Update", func(t *testing.T) { - // given - expTenantID := uuid.Must(uuid.NewV4()).String() - expUnblockedMapping := trust.OIDCMapping{ - IssuerURL: uuid.Must(uuid.NewV4()).String(), - Blocked: true, - JWKSURI: jwksURI, - Audiences: []string{requestURI}, - } - repoWrapper := &RepoWrapper{Repo: repo} - err := repoWrapper.Repo.Create(ctx, expTenantID, expUnblockedMapping) - require.NoError(t, err) - - noOfUpdateCalls := 0 - repoWrapper.MockUpdate = func(ctx context.Context, tenantID string, mapping trust.OIDCMapping) error { - noOfUpdateCalls++ - // delete the mapping before updating to return an error - err := repoWrapper.Repo.Delete(ctx, expTenantID) - assert.NoError(t, err) - return nil - } - subj := trust.NewService(repoWrapper) - - // when - err = subj.UnblockMapping(t.Context(), expTenantID) - - // then - assert.NoError(t, err) - assert.Equal(t, 1, noOfUpdateCalls) - }) - t.Run("the mapping is not found", func(t *testing.T) { - // given - expTenantID := uuid.Must(uuid.NewV4()).String() - repoWrapper := &RepoWrapper{Repo: repo} - - subj := trust.NewService(repoWrapper) - - // when - err := subj.UnblockMapping(t.Context(), expTenantID) - - // then - assert.NoError(t, err) - }) - }) - t.Run("should return error", func(t *testing.T) { - t.Run("if Get returns an error", func(t *testing.T) { - // given - expTenantID := uuid.Must(uuid.NewV4()).String() - mockRepo := &RepoWrapper{Repo: repo} - - noOfGetTenantCalls := 0 - mockRepo.MockGet = func(ctx context.Context, tenantID string) (trust.OIDCMapping, error) { - assert.Equal(t, expTenantID, tenantID) - noOfGetTenantCalls++ - return trust.OIDCMapping{}, assert.AnError - } - subj := trust.NewService(mockRepo) - - // when - err := subj.UnblockMapping(t.Context(), expTenantID) - - // then - assert.ErrorIs(t, err, assert.AnError) - assert.Equal(t, 1, noOfGetTenantCalls) - }) - - t.Run("if Update returns an error", func(t *testing.T) { - // given - expTenantIDtoUpdate := uuid.Must(uuid.NewV4()).String() - expBlockedMapping := trust.OIDCMapping{ - IssuerURL: uuid.Must(uuid.NewV4()).String(), - Blocked: true, - JWKSURI: jwksURI, - Audiences: []string{requestURI}, - } - repoWrapper := &RepoWrapper{Repo: repo} - err := repoWrapper.Repo.Create(ctx, expTenantIDtoUpdate, expBlockedMapping) - require.NoError(t, err) - - noOfUpdateCalls := 0 - repoWrapper.MockUpdate = func(ctx context.Context, tenantID string, mapping trust.OIDCMapping) error { - assert.Equal(t, expTenantIDtoUpdate, tenantID) - noOfUpdateCalls++ - return assert.AnError - } - subj := trust.NewService(repoWrapper) - - // when - err = subj.UnblockMapping(t.Context(), expTenantIDtoUpdate) - - // then - assert.ErrorIs(t, err, assert.AnError) - assert.Equal(t, 1, noOfUpdateCalls) - - actMapping, err := repoWrapper.Repo.Get(ctx, expTenantIDtoUpdate) - assert.NoError(t, err) - assert.Equal(t, expBlockedMapping, actMapping) - }) - }) -} - -func TestService_RemoveMapping(t *testing.T) { - ctx := t.Context() - - t.Run("success if", func(t *testing.T) { - t.Run("the mapping exists", func(t *testing.T) { - // given - expTenantID := uuid.Must(uuid.NewV4()).String() - expMapping := trust.OIDCMapping{ - IssuerURL: uuid.Must(uuid.NewV4()).String(), - JWKSURI: jwksURI, - Audiences: []string{requestURI}, - } - - wrapper := &RepoWrapper{Repo: repo} - err := wrapper.Repo.Create(ctx, expTenantID, expMapping) - require.NoError(t, err) - - subj := trust.NewService(wrapper) - - // when - err = subj.RemoveMapping(ctx, expTenantID) - - // then - assert.NoError(t, err) - - // verify the mapping was deleted - _, err = wrapper.Repo.Get(ctx, expTenantID) - assert.Error(t, err) - }) - }) - - t.Run("should return error if", func(t *testing.T) { - t.Run("the mapping does not exist", func(t *testing.T) { - // given - expTenantID := uuid.Must(uuid.NewV4()).String() - wrapper := &RepoWrapper{Repo: repo} - subj := trust.NewService(wrapper) - - // when - err := subj.RemoveMapping(ctx, expTenantID) - - // then - assert.Error(t, err) - }) - - t.Run("Delete returns an error", func(t *testing.T) { - // given - expTenantID := uuid.Must(uuid.NewV4()).String() - wrapper := &RepoWrapper{Repo: repo} - - noOfDeleteCalls := 0 - wrapper.MockDelete = func(ctx context.Context, tenantID string) error { - assert.Equal(t, expTenantID, tenantID) - noOfDeleteCalls++ - return assert.AnError - } - - subj := trust.NewService(wrapper) - - // when - err := subj.RemoveMapping(ctx, expTenantID) - - // then - assert.ErrorIs(t, err, assert.AnError) - assert.Equal(t, 1, noOfDeleteCalls) - }) - }) -} diff --git a/internal/trust/trustsql/errors.go b/internal/trust/trustsql/errors.go deleted file mode 100644 index 31eff64b..00000000 --- a/internal/trust/trustsql/errors.go +++ /dev/null @@ -1,18 +0,0 @@ -package trustsql - -import ( - "errors" - - "github.com/jackc/pgx/v5/pgconn" - - "github.com/openkcm/session-manager/internal/serviceerr" -) - -func handlePgError(err error) (error, bool) { - var pgErr *pgconn.PgError - if errors.As(err, &pgErr) && pgErr.Code == "23505" { - return serviceerr.ErrConflict, true - } - - return err, false -} diff --git a/internal/trust/trustsql/errors_test.go b/internal/trust/trustsql/errors_test.go deleted file mode 100644 index 7d339eb3..00000000 --- a/internal/trust/trustsql/errors_test.go +++ /dev/null @@ -1,47 +0,0 @@ -package trustsql - -import ( - "errors" - "testing" - - "github.com/jackc/pgx/v5/pgconn" - "github.com/stretchr/testify/assert" - - "github.com/openkcm/session-manager/internal/serviceerr" -) - -var errUnknown = errors.New("unknown error") - -func Test_handlePgError(t *testing.T) { - tests := []struct { - name string - inputErr error - errTarget error - wantOk bool - }{ - { - name: "23505 error", - inputErr: &pgconn.PgError{Code: "23505"}, - errTarget: serviceerr.ErrConflict, - wantOk: true, - }, - { - name: "Unknown error", - inputErr: errUnknown, - errTarget: errUnknown, - wantOk: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - gotErr, ok := handlePgError(tt.inputErr) - if !assert.ErrorIsf(t, gotErr, tt.errTarget, "handlePgError() error %v", gotErr) { - return - } - - if !assert.Equalf(t, tt.wantOk, ok, "handlePgError() OK = %v, want = %v", ok, tt.wantOk) { - return - } - }) - } -} diff --git a/internal/trust/trustsql/internal/queries/models.go b/internal/trust/trustsql/internal/queries/models.go deleted file mode 100644 index 8b643452..00000000 --- a/internal/trust/trustsql/internal/queries/models.go +++ /dev/null @@ -1,20 +0,0 @@ -// Code generated by sqlc. DO NOT EDIT. -// versions: -// sqlc v1.31.1 - -package queries - -import ( - "github.com/jackc/pgx/v5/pgtype" -) - -type Trust struct { - TenantID string `db:"tenant_id"` - Blocked bool `db:"blocked"` - Issuer string `db:"issuer"` - JwksUri string `db:"jwks_uri"` - Audiences []string `db:"audiences"` - Properties []byte `db:"properties"` - CreatedAt pgtype.Timestamp `db:"created_at"` - ClientID pgtype.Text `db:"client_id"` -} diff --git a/internal/trust/trustsql/repository.go b/internal/trust/trustsql/repository.go deleted file mode 100644 index b2d71625..00000000 --- a/internal/trust/trustsql/repository.go +++ /dev/null @@ -1,157 +0,0 @@ -package trustsql - -import ( - "context" - "encoding/json" - "errors" - "fmt" - - "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgxpool" - "go.opentelemetry.io/otel" - - "github.com/openkcm/session-manager/internal/serviceerr" - "github.com/openkcm/session-manager/internal/trust" - "github.com/openkcm/session-manager/internal/trust/trustsql/internal/queries" -) - -type Repository struct { - db *pgxpool.Pool - queries *queries.Queries -} - -func NewRepository(db *pgxpool.Pool) *Repository { - return &Repository{ - db: db, - queries: queries.New(db), - } -} - -func (r *Repository) Get(ctx context.Context, tenantID string) (trust.OIDCMapping, error) { - tracer := otel.GetTracerProvider() - ctx, span := tracer.Tracer("").Start(ctx, "get_oidc_mapping_sql") - defer span.End() - - row, err := r.queries.GetOIDCMapping(ctx, tenantID) - if err != nil { - span.RecordError(err) - if errors.Is(err, pgx.ErrNoRows) { - return trust.OIDCMapping{}, serviceerr.ErrNotFound - } - - return trust.OIDCMapping{}, err - } - - properties := make(map[string]string) - if len(row.Properties) > 0 { - err := json.Unmarshal(row.Properties, &properties) - if err != nil { - return trust.OIDCMapping{}, fmt.Errorf("unmarshalling properties: %w", err) - } - } - - return trust.OIDCMapping{ - IssuerURL: row.Issuer, - Blocked: row.Blocked, - JWKSURI: row.JwksUri, - Audiences: row.Audiences, - Properties: properties, - ClientID: row.ClientID.String, - }, nil -} - -func (r *Repository) Create(ctx context.Context, tenantID string, mapping trust.OIDCMapping) error { - tracer := otel.GetTracerProvider() - ctx, span := tracer.Tracer("").Start(ctx, "create_oidc_mapping_sql") - defer span.End() - - properties, err := r.marshalProperties(mapping) - if err != nil { - return fmt.Errorf("marshaling properties: %w", err) - } - - if err := r.queries.CreateOIDCMapping(ctx, queries.CreateOIDCMappingParams{ - TenantID: tenantID, - Blocked: mapping.Blocked, - Issuer: mapping.IssuerURL, - JwksUri: mapping.JWKSURI, - Audiences: mapping.Audiences, - Properties: properties, - ClientID: pgTextOrNull(mapping.ClientID), - }); err != nil { - span.RecordError(err) - if err, ok := handlePgError(err); ok { - return err - } - - return fmt.Errorf("inserting into trust: %w", err) - } - - return nil -} - -func (r *Repository) Delete(ctx context.Context, tenantID string) error { - tracer := otel.GetTracerProvider() - ctx, span := tracer.Tracer("").Start(ctx, "delete_oidc_mapping_sql") - defer span.End() - - affected, err := r.queries.DeleteOIDCMapping(ctx, tenantID) - if err != nil { - span.RecordError(err) - return fmt.Errorf("executing sql query: %w", err) - } - - if affected == 0 { - return serviceerr.ErrNotFound - } - - return nil -} - -func (r *Repository) Update(ctx context.Context, tenantID string, mapping trust.OIDCMapping) error { - tracer := otel.GetTracerProvider() - ctx, span := tracer.Tracer("").Start(ctx, "update_oidc_mapping_sql") - defer span.End() - - properties, err := r.marshalProperties(mapping) - if err != nil { - span.RecordError(err) - return err - } - - affected, err := r.queries.UpdateOIDCMapping(ctx, queries.UpdateOIDCMappingParams{ - Blocked: mapping.Blocked, - Issuer: mapping.IssuerURL, - JwksUri: mapping.JWKSURI, - Audiences: mapping.Audiences, - Properties: properties, - ClientID: pgTextOrNull(mapping.ClientID), - TenantID: tenantID, - }) - if err != nil { - span.RecordError(err) - return fmt.Errorf("updating trust: %w", err) - } - - if affected == 0 { - return serviceerr.ErrNotFound - } - - return nil -} - -func (r *Repository) marshalProperties(mapping trust.OIDCMapping) ([]byte, error) { - propsBytes, err := json.Marshal(mapping.Properties) - if err != nil { - return nil, fmt.Errorf("marshaling json: %w", err) - } - return propsBytes, nil -} - -func pgTextOrNull(s string) pgtype.Text { - return pgtype.Text{ - String: s, - Valid: s != "", - } -} diff --git a/internal/trust/trustsql/repository_test.go b/internal/trust/trustsql/repository_test.go deleted file mode 100644 index 74169969..00000000 --- a/internal/trust/trustsql/repository_test.go +++ /dev/null @@ -1,302 +0,0 @@ -package trustsql_test - -import ( - "context" - "errors" - "fmt" - "os" - "testing" - - "github.com/google/go-cmp/cmp" - "github.com/jackc/pgx/v5/pgxpool" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/openkcm/session-manager/internal/dbtest/postgrestest" - "github.com/openkcm/session-manager/internal/serviceerr" - "github.com/openkcm/session-manager/internal/trust" - "github.com/openkcm/session-manager/internal/trust/trustsql" -) - -var dbPool *pgxpool.Pool - -func TestMain(m *testing.M) { - ctx := context.Background() - - pool, _, terminate := postgrestest.Start(ctx) - defer terminate(ctx) - - dbPool = pool - - code := m.Run() - os.Exit(code) -} - -func TestRepository_Get(t *testing.T) { - tests := []struct { - name string - tenantID string - wantMapping trust.OIDCMapping - assertErr assert.ErrorAssertionFunc - }{ - { - name: "Success", - tenantID: "tenant1-id", - wantMapping: trust.OIDCMapping{ - IssuerURL: "url-one", - Blocked: false, - JWKSURI: "", - Audiences: make([]string, 0), - Properties: make(map[string]string), - }, - assertErr: assert.NoError, - }, - { - name: "Error does not exist", - tenantID: "does-not-exist", - assertErr: assert.Error, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - r := trustsql.NewRepository(dbPool) - - gotMapping, err := r.Get(t.Context(), tt.tenantID) - if !tt.assertErr(t, err, fmt.Sprintf("Repository.Get() error %v", err)) || err != nil { - assert.Zerof(t, gotMapping, "Repository.Get() extected zero value if an error is returned, got %v", gotMapping) - return - } - - assert.Equal(t, tt.wantMapping, gotMapping, "Repository.Get()") - }) - } -} - -func TestRepository_Create(t *testing.T) { - tests := []struct { - name string - tenantID string - mapping trust.OIDCMapping - assertErr assert.ErrorAssertionFunc - }{ - { - name: "Create succeeds", - tenantID: "tenant-id-create-success", - mapping: trust.OIDCMapping{ - IssuerURL: "http://oidc-success.example.com", - Blocked: false, - JWKSURI: "jwks.example.com", - Audiences: []string{"cmk.example.com"}, - Properties: map[string]string{ - "prop1": "prop1val", - }, - }, - assertErr: assert.NoError, - }, - { - name: "Duplicate", - tenantID: "tenant1-id", - mapping: trust.OIDCMapping{ - IssuerURL: "url-one", - Blocked: false, - JWKSURI: "jwks.example.com", - Audiences: []string{"cmk.example.com"}, - Properties: map[string]string{ - "prop1": "prop1val", - }, - }, - assertErr: assert.Error, - }, - { - name: "Create without JWKSURI and Audiences succeeds", - tenantID: "tenant-id-create-without-jwks-aud-success", - mapping: trust.OIDCMapping{ - IssuerURL: "http://oidc-success-2.example.com", - Blocked: false, - Audiences: []string{}, - }, - assertErr: assert.NoError, - }, - { - name: "Create without JWKSURI succeeds", - tenantID: "tenant-id-create-without-jwks-success", - mapping: trust.OIDCMapping{ - IssuerURL: "http://oidc-success-3.example.com", - Blocked: false, - Audiences: []string{"cmk.example.com"}, - }, - assertErr: assert.NoError, - }, - { - name: "Create without Audiences succeeds", - tenantID: "tenant-id-create-without-aud-success", - mapping: trust.OIDCMapping{ - IssuerURL: "http://oidc-success-4.example.com", - Blocked: false, - JWKSURI: "jwks.example.com", - Audiences: []string{}, - }, - assertErr: assert.NoError, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Given - r := trustsql.NewRepository(dbPool) - - // When - err := r.Create(t.Context(), tt.tenantID, tt.mapping) - if !tt.assertErr(t, err, fmt.Sprintf("Repository.Create() error %v", err)) || err != nil { - return - } - - // Then - mapping, err := r.Get(t.Context(), tt.tenantID) - require.NoError(t, err) - - if diff := cmp.Diff(tt.mapping, mapping); diff != "" { - t.Fatalf("Unexpected mapping in the database (-want, +got):\n%s", diff) - } - }) - } -} - -func TestRepository_Delete(t *testing.T) { - const tenantID = "tenant-id-delete-success" - - mapping := trust.OIDCMapping{ - IssuerURL: "http://oidc-to-delete.example.com", - Blocked: false, - JWKSURI: "jwks.example.com", - Audiences: []string{"cmk.example.com"}, - } - - r := trustsql.NewRepository(dbPool) - err := r.Create(t.Context(), tenantID, mapping) - require.NoError(t, err, "Inserting test data") - - tests := []struct { - name string - tenantID string - mapping trust.OIDCMapping - assertErr assert.ErrorAssertionFunc - }{ - { - name: "Delete tenant", - tenantID: tenantID, - mapping: mapping, - assertErr: assert.NoError, - }, - { - name: "Error does not exist", - tenantID: "does-not-exist", - mapping: trust.OIDCMapping{IssuerURL: "does-not-exist"}, - assertErr: assert.Error, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := r.Delete(t.Context(), tt.tenantID) - if !tt.assertErr(t, err, fmt.Sprintf("Repository.Delete() error %v", err)) || err != nil { - return - } - - gotMapping, err := r.Get(t.Context(), tt.tenantID) - if !errors.Is(err, serviceerr.ErrNotFound) { - t.Error("The mapping is expected to be deleted") - } - assert.Zero(t, gotMapping, "The mapping is expected to be deleted, instead a value is returned") - }) - } -} - -func TestRepository_Update(t *testing.T) { - const tenantID = "tenant-id-update-success" - - mapping := trust.OIDCMapping{ - IssuerURL: "http://oidc-to-update.example.com", - Blocked: false, - JWKSURI: "jwks.example.com", - Audiences: []string{"cmk.example.com"}, - } - - r := trustsql.NewRepository(dbPool) - err := r.Create(t.Context(), tenantID, mapping) - require.NoError(t, err, "Inserting test data") - - tests := []struct { - name string - tenantID string - mapping trust.OIDCMapping - assertErr assert.ErrorAssertionFunc - }{ - { - name: "Update succeeds", - tenantID: tenantID, - mapping: trust.OIDCMapping{ - IssuerURL: mapping.IssuerURL, - Blocked: true, - JWKSURI: "jwks-updated.example.com", - Audiences: append(mapping.Audiences, "new-audience.example.com"), - }, - assertErr: assert.NoError, - }, - { - name: "Does not exist", - tenantID: "does-not-exist", - mapping: trust.OIDCMapping{ - IssuerURL: "does-not-exist", - Blocked: true, - JWKSURI: "jwks-updated.example.com", - Audiences: append(mapping.Audiences, "new-audience.example.com"), - }, - assertErr: assert.Error, - }, - { - name: "Update without JWKSURI and Audiences succeeds", - tenantID: tenantID, - mapping: trust.OIDCMapping{ - IssuerURL: mapping.IssuerURL, - Blocked: true, - Audiences: []string{}, - }, - assertErr: assert.NoError, - }, - { - name: "Update without JWKSURI succeeds", - tenantID: tenantID, - mapping: trust.OIDCMapping{ - IssuerURL: mapping.IssuerURL, - Blocked: true, - Audiences: append(mapping.Audiences, "new-audience.example.com"), - }, - assertErr: assert.NoError, - }, - { - name: "Update without Audiences succeeds", - tenantID: tenantID, - mapping: trust.OIDCMapping{ - IssuerURL: mapping.IssuerURL, - Blocked: true, - JWKSURI: "jwks-updated.example.com", - Audiences: []string{}, - }, - assertErr: assert.NoError, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := r.Update(t.Context(), tt.tenantID, tt.mapping) - if !tt.assertErr(t, err, fmt.Sprintf("Repository.Update() error %v", err)) || err != nil { - return - } - - gotMapping, err := r.Get(t.Context(), tt.tenantID) - require.NoError(t, err) - - if diff := cmp.Diff(tt.mapping, gotMapping); diff != "" { - t.Fatalf("Unexpected mapping in the database (-want, +got):\n%s", diff) - } - }) - } -} diff --git a/modules.go b/modules.go new file mode 100644 index 00000000..2f3a4f25 --- /dev/null +++ b/modules.go @@ -0,0 +1,57 @@ +package sessionmanager + +import ( + "fmt" + "iter" + "maps" + "sync" +) + +var ( + modules = make(map[string]ModuleInfo) + modulesMu sync.RWMutex +) + +func RegisterModule(module Module) { + modulesMu.Lock() + defer modulesMu.Unlock() + + info := module.Module() + + if _, ok := modules[info.ID]; ok { + panic(`module "` + info.ID + `" has already been registered`) + } + + modules[info.ID] = info +} + +func GetModule(id string) (ModuleInfo, error) { + modulesMu.RLock() + defer modulesMu.RUnlock() + mod, ok := modules[id] + if !ok { + return ModuleInfo{}, fmt.Errorf("module %q is not registered", id) + } + + return mod, nil +} + +func Modules() iter.Seq[ModuleInfo] { + modulesMu.RLock() + defer modulesMu.RUnlock() + + return maps.Values(maps.Clone(modules)) +} + +type Module interface { + Module() ModuleInfo +} + +type ModuleInfo struct { + ID string + New func() Module +} + +type Provisioner interface { + Provision(ctx *Context) error +} diff --git a/modules/database/pgxpool/module.go b/modules/database/pgxpool/module.go new file mode 100644 index 00000000..f8865972 --- /dev/null +++ b/modules/database/pgxpool/module.go @@ -0,0 +1,107 @@ +package pgxpool + +import ( + "context" + "database/sql" + "fmt" + + "github.com/exaring/otelpgx" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgxpool" + "github.com/jackc/pgx/v5/stdlib" + "github.com/openkcm/common-sdk/pkg/commoncfg" + + sessionmanager "github.com/openkcm/session-manager" +) + +const moduleID = "database.module.pgxpool" + +func newModule() sessionmanager.Module { + return new(PostgresModule) +} + +func init() { + sessionmanager.RegisterModule(new(PostgresModule)) +} + +type PostgresModule struct { + Mod string `yaml:"module"` + Name string `yaml:"name"` + Port string `yaml:"port"` + Host commoncfg.SourceRef `yaml:"host"` + User commoncfg.SourceRef `yaml:"user"` + Password commoncfg.SourceRef `yaml:"password"` + + db *pgxpool.Pool +} + +func (m *PostgresModule) STDAdapter() *sql.DB { + return stdlib.OpenDBFromPool(m.db) +} + +func (m *PostgresModule) Exec(ctx context.Context, sql string, args ...any) (pgconn.CommandTag, error) { + return m.db.Exec(ctx, sql, args...) +} + +func (m *PostgresModule) Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) { + return m.db.Query(ctx, sql, args...) +} + +func (m *PostgresModule) QueryRow(ctx context.Context, sql string, args ...any) pgx.Row { + return m.db.QueryRow(ctx, sql, args...) +} + +func (m *PostgresModule) Module() sessionmanager.ModuleInfo { + return sessionmanager.ModuleInfo{ + ID: moduleID, + New: newModule, + } +} + +func (m *PostgresModule) Provision(ctx *sessionmanager.Context) error { + connStr, err := m.makeConnStr() + if err != nil { + return fmt.Errorf("making dsn from config: %w", err) + } + + pgxpoolCfg, err := pgxpool.ParseConfig(connStr) + if err != nil { + return fmt.Errorf("parsing pgxpool config: %w", err) + } + + pgxpoolCfg.ConnConfig.Tracer = otelpgx.NewTracer() + + m.db, err = pgxpool.NewWithConfig(ctx, pgxpoolCfg) + if err != nil { + return fmt.Errorf("failed to initialise pgxpool connection: %w", err) + } + + if err := otelpgx.RecordStats(m.db); err != nil { + return fmt.Errorf("recording database stat: %w", err) + } + + return nil +} + +func (m *PostgresModule) makeConnStr() (string, error) { + host, err := commoncfg.LoadValueFromSourceRef(m.Host) + if err != nil { + return "", fmt.Errorf("loading db host: %w", err) + } + + user, err := commoncfg.LoadValueFromSourceRef(m.User) + if err != nil { + return "", fmt.Errorf("loading db user: %w", err) + } + + password, err := commoncfg.LoadValueFromSourceRef(m.Password) + if err != nil { + return "", fmt.Errorf("loading db password: %w", err) + } + + return fmt.Sprintf("host=%s user=%s password=%s dbname=%s port=%s", + host, user, string(password), m.Name, m.Port), nil +} + +var _ sessionmanager.Database = (*PostgresModule)(nil) diff --git a/modules/oidctrust/export.go b/modules/oidctrust/export.go new file mode 100644 index 00000000..cb30b5cb --- /dev/null +++ b/modules/oidctrust/export.go @@ -0,0 +1,15 @@ +package oidctrust + +import ( + _ "unsafe" + + sessionmanager "github.com/openkcm/session-manager" +) + +//nolint:unused +//go:linkname newOIDCTrustModuleWithRepo +func newOIDCTrustModuleWithRepo(r TrustRepository) sessionmanager.Trust { + return &TrustModule{ + repository: r, + } +} diff --git a/modules/oidctrust/export_test.go b/modules/oidctrust/export_test.go new file mode 100644 index 00000000..c4cb26ac --- /dev/null +++ b/modules/oidctrust/export_test.go @@ -0,0 +1,7 @@ +package oidctrust + +func NewModule(repo TrustRepository) *TrustModule { + return &TrustModule{ + repository: repo, + } +} diff --git a/modules/oidctrust/internal/sql/export_test.go b/modules/oidctrust/internal/sql/export_test.go new file mode 100644 index 00000000..1ec027ca --- /dev/null +++ b/modules/oidctrust/internal/sql/export_test.go @@ -0,0 +1,15 @@ +package sqltrust + +import ( + "github.com/jackc/pgx/v5/pgtype" +) + +// PgTextOrNull exposes pgTextOrNull for testing. +func PgTextOrNull(s string) pgtype.Text { + return pgTextOrNull(s) +} + +// HandlePgError exposes handlePgError for testing. +func HandlePgError(err error) (error, bool) { + return handlePgError(err) +} diff --git a/internal/trust/trustsql/queries.sql b/modules/oidctrust/internal/sql/queries.sql similarity index 90% rename from internal/trust/trustsql/queries.sql rename to modules/oidctrust/internal/sql/queries.sql index 93e07ea6..5fdc1244 100644 --- a/internal/trust/trustsql/queries.sql +++ b/modules/oidctrust/internal/sql/queries.sql @@ -4,7 +4,6 @@ SELECT blocked, jwks_uri, audiences, - properties, client_id FROM trust WHERE tenant_id = sqlc.arg(tenant_id); @@ -16,7 +15,6 @@ INSERT INTO trust ( issuer, jwks_uri, audiences, - properties, client_id) VALUES ( sqlc.arg(tenant_id), @@ -24,7 +22,6 @@ VALUES ( sqlc.arg(issuer), sqlc.arg(jwks_uri), COALESCE(sqlc.arg(audiences)::text[], '{}'::text[]), - sqlc.arg(properties), sqlc.arg(client_id)); -- name: DeleteOIDCMapping :execrows @@ -38,7 +35,6 @@ SET issuer = sqlc.arg(issuer), jwks_uri = sqlc.arg(jwks_uri), audiences = COALESCE(sqlc.arg(audiences)::text[], '{}'::text[]), - properties = sqlc.arg(properties), client_id = sqlc.arg(client_id) WHERE tenant_id = sqlc.arg(tenant_id); diff --git a/internal/trust/trustsql/internal/queries/db.go b/modules/oidctrust/internal/sql/queries/db.go similarity index 100% rename from internal/trust/trustsql/internal/queries/db.go rename to modules/oidctrust/internal/sql/queries/db.go diff --git a/modules/oidctrust/internal/sql/queries/models.go b/modules/oidctrust/internal/sql/queries/models.go new file mode 100644 index 00000000..77a9f77e --- /dev/null +++ b/modules/oidctrust/internal/sql/queries/models.go @@ -0,0 +1,19 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 + +package queries + +import ( + "github.com/jackc/pgx/v5/pgtype" +) + +type Trust struct { + TenantID string `db:"tenant_id"` + Blocked bool `db:"blocked"` + Issuer string `db:"issuer"` + JwksUri string `db:"jwks_uri"` + Audiences []string `db:"audiences"` + CreatedAt pgtype.Timestamp `db:"created_at"` + ClientID pgtype.Text `db:"client_id"` +} diff --git a/internal/trust/trustsql/internal/queries/queries.sql.go b/modules/oidctrust/internal/sql/queries/queries.sql.go similarity index 68% rename from internal/trust/trustsql/internal/queries/queries.sql.go rename to modules/oidctrust/internal/sql/queries/queries.sql.go index 9f71b1e8..7cea03e6 100644 --- a/internal/trust/trustsql/internal/queries/queries.sql.go +++ b/modules/oidctrust/internal/sql/queries/queries.sql.go @@ -18,7 +18,6 @@ INSERT INTO trust ( issuer, jwks_uri, audiences, - properties, client_id) VALUES ( $1, @@ -26,18 +25,16 @@ VALUES ( $3, $4, COALESCE($5::text[], '{}'::text[]), - $6, - $7) + $6) ` type CreateOIDCMappingParams struct { - TenantID string `db:"tenant_id"` - Blocked bool `db:"blocked"` - Issuer string `db:"issuer"` - JwksUri string `db:"jwks_uri"` - Audiences []string `db:"audiences"` - Properties []byte `db:"properties"` - ClientID pgtype.Text `db:"client_id"` + TenantID string `db:"tenant_id"` + Blocked bool `db:"blocked"` + Issuer string `db:"issuer"` + JwksUri string `db:"jwks_uri"` + Audiences []string `db:"audiences"` + ClientID pgtype.Text `db:"client_id"` } func (q *Queries) CreateOIDCMapping(ctx context.Context, arg CreateOIDCMappingParams) error { @@ -47,7 +44,6 @@ func (q *Queries) CreateOIDCMapping(ctx context.Context, arg CreateOIDCMappingPa arg.Issuer, arg.JwksUri, arg.Audiences, - arg.Properties, arg.ClientID, ) return err @@ -72,19 +68,17 @@ SELECT blocked, jwks_uri, audiences, - properties, client_id FROM trust WHERE tenant_id = $1 ` type GetOIDCMappingRow struct { - Issuer string `db:"issuer"` - Blocked bool `db:"blocked"` - JwksUri string `db:"jwks_uri"` - Audiences []string `db:"audiences"` - Properties []byte `db:"properties"` - ClientID pgtype.Text `db:"client_id"` + Issuer string `db:"issuer"` + Blocked bool `db:"blocked"` + JwksUri string `db:"jwks_uri"` + Audiences []string `db:"audiences"` + ClientID pgtype.Text `db:"client_id"` } func (q *Queries) GetOIDCMapping(ctx context.Context, tenantID string) (GetOIDCMappingRow, error) { @@ -95,7 +89,6 @@ func (q *Queries) GetOIDCMapping(ctx context.Context, tenantID string) (GetOIDCM &i.Blocked, &i.JwksUri, &i.Audiences, - &i.Properties, &i.ClientID, ) return i, err @@ -108,20 +101,18 @@ SET issuer = $2, jwks_uri = $3, audiences = COALESCE($4::text[], '{}'::text[]), - properties = $5, - client_id = $6 + client_id = $5 WHERE - tenant_id = $7 + tenant_id = $6 ` type UpdateOIDCMappingParams struct { - Blocked bool `db:"blocked"` - Issuer string `db:"issuer"` - JwksUri string `db:"jwks_uri"` - Audiences []string `db:"audiences"` - Properties []byte `db:"properties"` - ClientID pgtype.Text `db:"client_id"` - TenantID string `db:"tenant_id"` + Blocked bool `db:"blocked"` + Issuer string `db:"issuer"` + JwksUri string `db:"jwks_uri"` + Audiences []string `db:"audiences"` + ClientID pgtype.Text `db:"client_id"` + TenantID string `db:"tenant_id"` } func (q *Queries) UpdateOIDCMapping(ctx context.Context, arg UpdateOIDCMappingParams) (int64, error) { @@ -130,7 +121,6 @@ func (q *Queries) UpdateOIDCMapping(ctx context.Context, arg UpdateOIDCMappingPa arg.Issuer, arg.JwksUri, arg.Audiences, - arg.Properties, arg.ClientID, arg.TenantID, ) diff --git a/modules/oidctrust/internal/sql/sql.go b/modules/oidctrust/internal/sql/sql.go new file mode 100644 index 00000000..8a2a184a --- /dev/null +++ b/modules/oidctrust/internal/sql/sql.go @@ -0,0 +1,154 @@ +package sqltrust + +import ( + "context" + "errors" + "fmt" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgtype" + "go.opentelemetry.io/otel" + + oidcv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/oidc/v1" + trustv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/v1" + + sessionmanager "github.com/openkcm/session-manager" + "github.com/openkcm/session-manager/pkg/serviceerr" + "github.com/openkcm/session-manager/modules/oidctrust/internal/sql/queries" +) + +type Repository struct { + queries *queries.Queries +} + +func NewRepository(db sessionmanager.Database) *Repository { + return &Repository{ + queries: queries.New(db), + } +} + +func (r *Repository) Get(ctx context.Context, tenantID string) (*trustv1.Trust, error) { + tracer := otel.GetTracerProvider() + ctx, span := tracer.Tracer("").Start(ctx, "get_oidc_mapping_sql") + defer span.End() + + row, err := r.queries.GetOIDCMapping(ctx, tenantID) + if err != nil { + span.RecordError(err) + if errors.Is(err, pgx.ErrNoRows) { + return nil, serviceerr.ErrNotFound + } + + return nil, err + } + + trust := trustv1.Trust_builder{ + TenantId: &tenantID, + Blocked: &row.Blocked, + Oidc: oidcv1.OIDC_builder{ + Audiences: row.Audiences, + }.Build(), + }.Build() + + if row.Issuer != "" { + trust.GetOidc().SetIssuer(row.Issuer) + } + + if row.JwksUri != "" { + trust.GetOidc().SetJwksUri(row.JwksUri) + } + + if row.ClientID.Valid { + trust.GetOidc().SetClientId(row.ClientID.String) + } + + return trust, nil +} + +func (r *Repository) Create(ctx context.Context, trust *trustv1.Trust) error { + tracer := otel.GetTracerProvider() + ctx, span := tracer.Tracer("").Start(ctx, "create_oidc_mapping_sql") + defer span.End() + + oidc := trust.GetOidc() + + if err := r.queries.CreateOIDCMapping(ctx, queries.CreateOIDCMappingParams{ + TenantID: trust.GetTenantId(), + Blocked: trust.GetBlocked(), + Issuer: oidc.GetIssuer(), + JwksUri: oidc.GetJwksUri(), + Audiences: oidc.GetAudiences(), + ClientID: pgTextOrNull(trust.GetOidc().GetClientId()), + }); err != nil { + span.RecordError(err) + if err, ok := handlePgError(err); ok { + return err + } + + return fmt.Errorf("inserting into trust: %w", err) + } + + return nil +} + +func (r *Repository) Delete(ctx context.Context, tenantID string) error { + tracer := otel.GetTracerProvider() + ctx, span := tracer.Tracer("").Start(ctx, "delete_oidc_mapping_sql") + defer span.End() + + affected, err := r.queries.DeleteOIDCMapping(ctx, tenantID) + if err != nil { + span.RecordError(err) + return fmt.Errorf("executing sql query: %w", err) + } + + if affected == 0 { + return serviceerr.ErrNotFound + } + + return nil +} + +func (r *Repository) Update(ctx context.Context, trust *trustv1.Trust) error { + tracer := otel.GetTracerProvider() + ctx, span := tracer.Tracer("").Start(ctx, "update_oidc_mapping_sql") + defer span.End() + + oidc := trust.GetOidc() + + affected, err := r.queries.UpdateOIDCMapping(ctx, queries.UpdateOIDCMappingParams{ + Blocked: trust.GetBlocked(), + Issuer: oidc.GetIssuer(), + JwksUri: oidc.GetJwksUri(), + Audiences: oidc.GetAudiences(), + ClientID: pgTextOrNull(oidc.GetClientId()), + TenantID: trust.GetTenantId(), + }) + if err != nil { + span.RecordError(err) + return fmt.Errorf("updating trust: %w", err) + } + + if affected == 0 { + return serviceerr.ErrNotFound + } + + return nil +} + +func pgTextOrNull(s string) pgtype.Text { + return pgtype.Text{ + String: s, + Valid: s != "", + } +} + +func handlePgError(err error) (error, bool) { + var pgErr *pgconn.PgError + if errors.As(err, &pgErr) && pgErr.Code == "23505" { + return serviceerr.ErrConflict, true + } + + return err, false +} diff --git a/modules/oidctrust/internal/sql/sql_test.go b/modules/oidctrust/internal/sql/sql_test.go new file mode 100644 index 00000000..7742818d --- /dev/null +++ b/modules/oidctrust/internal/sql/sql_test.go @@ -0,0 +1,305 @@ +package sqltrust_test + +import ( + "context" + "database/sql" + "errors" + "fmt" + "os" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxpool" + "github.com/jackc/pgx/v5/stdlib" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/testing/protocmp" + + oidcv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/oidc/v1" + trustv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/v1" + + sessionmanager "github.com/openkcm/session-manager" + "github.com/openkcm/session-manager/internal/dbtest/postgrestest" + "github.com/openkcm/session-manager/pkg/serviceerr" + sqltrust "github.com/openkcm/session-manager/modules/oidctrust/internal/sql" +) + +var dbPool sessionmanager.Database + +type pooldb struct { + *pgxpool.Pool +} + +func (p *pooldb) STDAdapter() *sql.DB { + return stdlib.OpenDBFromPool(p.Pool) +} + +func TestMain(m *testing.M) { + ctx := context.Background() + + pool, _, terminate := postgrestest.Start(ctx) + defer terminate(ctx) + + dbPool = &pooldb{pool} + + code := m.Run() + os.Exit(code) +} + +func TestRepository_Get(t *testing.T) { + tests := []struct { + name string + tenantID string + wantMapping *trustv1.Trust + assertErr assert.ErrorAssertionFunc + }{ + { + name: "Success", + tenantID: "tenant1-id", + wantMapping: trustv1.Trust_builder{TenantId: new("tenant1-id"), Blocked: new(false), Oidc: oidcv1.OIDC_builder{Issuer: new("url-one"), Audiences: make([]string, 0)}.Build()}.Build(), + assertErr: assert.NoError, + }, + { + name: "Error does not exist", + tenantID: "does-not-exist", + assertErr: assert.Error, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := sqltrust.NewRepository(dbPool) + + gotMapping, err := r.Get(t.Context(), tt.tenantID) + if !tt.assertErr(t, err, fmt.Sprintf("Repository.Get() error %v", err)) || err != nil { + assert.Zerof(t, gotMapping, "Repository.Get() extected zero value if an error is returned, got %v", gotMapping) + return + } + + if diff := cmp.Diff(tt.wantMapping, gotMapping, protocmp.Transform()); diff != "" { + t.Fatalf("mapping not equal:\n%s", diff) + } + }) + } +} + +func TestRepository_Create(t *testing.T) { + tests := []struct { + name string + mapping *trustv1.Trust + assertErr assert.ErrorAssertionFunc + }{ + { + name: "Create succeeds", + mapping: trustv1.Trust_builder{TenantId: new("tenant-id-create-success"), Blocked: new(false), Oidc: oidcv1.OIDC_builder{Issuer: new("http://oidc-success.example.com"), JwksUri: new("jwks.example.com"), Audiences: []string{"cmk.example.com"}}.Build()}.Build(), + assertErr: assert.NoError, + }, + { + name: "Duplicate", + mapping: trustv1.Trust_builder{TenantId: new("tenant1-id"), Blocked: new(false), Oidc: oidcv1.OIDC_builder{Issuer: new("url-one"), JwksUri: new("jwks.example.com"), Audiences: []string{"cmk.example.com"}}.Build()}.Build(), + assertErr: assert.Error, + }, + { + name: "Create without JWKSURI and Audiences succeeds", + mapping: trustv1.Trust_builder{TenantId: new("tenant-id-create-without-jwks-aud-success"), Blocked: new(false), Oidc: oidcv1.OIDC_builder{Issuer: new("http://oidc-success-2.example.com"), Audiences: []string{}}.Build()}.Build(), + assertErr: assert.NoError, + }, + { + name: "Create without JWKSURI succeeds", + mapping: trustv1.Trust_builder{TenantId: new("tenant-id-create-without-jwks-success"), Blocked: new(false), Oidc: oidcv1.OIDC_builder{Issuer: new("http://oidc-success-3.example.com"), Audiences: []string{"cmk.example.com"}}.Build()}.Build(), + assertErr: assert.NoError, + }, + { + name: "Create without Audiences succeeds", + mapping: trustv1.Trust_builder{TenantId: new("tenant-id-create-without-aud-success"), Blocked: new(false), Oidc: oidcv1.OIDC_builder{Issuer: new("http://oidc-success-4.example.com"), JwksUri: new("jwks.example.com"), Audiences: []string{}}.Build()}.Build(), + assertErr: assert.NoError, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Given + r := sqltrust.NewRepository(dbPool) + + // When + err := r.Create(t.Context(), tt.mapping) + if !tt.assertErr(t, err, fmt.Sprintf("Repository.Create() error %v", err)) || err != nil { + return + } + + // Then + mapping, err := r.Get(t.Context(), tt.mapping.GetTenantId()) + require.NoError(t, err) + + if diff := cmp.Diff(tt.mapping, mapping, protocmp.Transform()); diff != "" { + t.Fatalf("Unexpected mapping in the database (-want, +got):\n%s", diff) + } + }) + } +} + +func TestRepository_Delete(t *testing.T) { + const tenantID = "tenant-id-delete-success" + mapping := trustv1.Trust_builder{TenantId: new(tenantID), Blocked: new(false), Oidc: oidcv1.OIDC_builder{Issuer: new("http://oidc-to-delete.example.com"), JwksUri: new("jwks.example.com"), Audiences: []string{"cmk.example.com"}}.Build()}.Build() + r := sqltrust.NewRepository(dbPool) + err := r.Create(t.Context(), mapping) + require.NoError(t, err, "Inserting test data") + + tests := []struct { + name string + tenantID string + assertErr assert.ErrorAssertionFunc + }{ + { + name: "Delete tenant", + tenantID: tenantID, + assertErr: assert.NoError, + }, + { + name: "Error does not exist", + tenantID: "does-not-exist", + assertErr: assert.Error, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := r.Delete(t.Context(), tt.tenantID) + if !tt.assertErr(t, err, fmt.Sprintf("Repository.Delete() error %v", err)) || err != nil { + return + } + + gotMapping, err := r.Get(t.Context(), tt.tenantID) + if !errors.Is(err, serviceerr.ErrNotFound) { + t.Error("The mapping is expected to be deleted") + } + assert.Zero(t, gotMapping, "The mapping is expected to be deleted, instead a value is returned") + }) + } +} + +func TestRepository_Update(t *testing.T) { + const tenantID = "tenant-id-update-success" + mapping := trustv1.Trust_builder{TenantId: new(tenantID), Blocked: new(false), Oidc: oidcv1.OIDC_builder{Issuer: new("http://oidc-to-update.example.com"), JwksUri: new("jwks.example.com"), Audiences: []string{"cmk.example.com"}}.Build()}.Build() + r := sqltrust.NewRepository(dbPool) + err := r.Create(t.Context(), mapping) + require.NoError(t, err, "Inserting test data") + + tests := []struct { + name string + mapping *trustv1.Trust + assertErr assert.ErrorAssertionFunc + }{ + { + name: "Update succeeds", + mapping: trustv1.Trust_builder{TenantId: new(tenantID), Blocked: new(true), Oidc: oidcv1.OIDC_builder{Issuer: new(mapping.GetOidc().GetIssuer()), JwksUri: new("jwks-updated.example.com"), Audiences: mapping.GetOidc().GetAudiences()}.Build()}.Build(), + assertErr: assert.NoError, + }, + { + name: "Does not exist", + mapping: trustv1.Trust_builder{TenantId: new("does-not-exist"), Blocked: new(true), Oidc: oidcv1.OIDC_builder{Issuer: new("does-not-exist"), JwksUri: new("jwks-updated.example.com"), Audiences: mapping.GetOidc().GetAudiences()}.Build()}.Build(), + assertErr: assert.Error, + }, + { + name: "Update without JWKSURI and Audiences succeeds", + mapping: trustv1.Trust_builder{TenantId: new(tenantID), Blocked: new(true), Oidc: oidcv1.OIDC_builder{Issuer: new(mapping.GetOidc().GetIssuer()), Audiences: []string{}}.Build()}.Build(), + assertErr: assert.NoError, + }, + { + name: "Update without JWKSURI succeeds", + mapping: trustv1.Trust_builder{TenantId: new(tenantID), Blocked: new(true), Oidc: oidcv1.OIDC_builder{Issuer: new(mapping.GetOidc().GetIssuer()), Audiences: mapping.GetOidc().GetAudiences()}.Build()}.Build(), + assertErr: assert.NoError, + }, + { + name: "Update without Audiences succeeds", + mapping: trustv1.Trust_builder{TenantId: new(tenantID), Blocked: new(true), Oidc: oidcv1.OIDC_builder{Issuer: new(mapping.GetOidc().GetIssuer()), JwksUri: new("jwks-updated.example.com"), Audiences: []string{}}.Build()}.Build(), + assertErr: assert.NoError, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := r.Update(t.Context(), tt.mapping) + if !tt.assertErr(t, err, fmt.Sprintf("Repository.Update() error %v", err)) || err != nil { + return + } + + gotMapping, err := r.Get(t.Context(), tt.mapping.GetTenantId()) + require.NoError(t, err) + + if diff := cmp.Diff(tt.mapping, gotMapping, protocmp.Transform()); diff != "" { + t.Fatalf("Unexpected mapping in the database (-want, +got):\n%s", diff) + } + }) + } +} + +func TestPgTextOrNull(t *testing.T) { + tests := []struct { + name string + input string + want pgtype.Text + }{ + { + name: "empty string returns invalid (null)", + input: "", + want: pgtype.Text{String: "", Valid: false}, + }, + { + name: "non-empty string returns valid text", + input: "hello", + want: pgtype.Text{String: "hello", Valid: true}, + }, + { + name: "whitespace-only string returns valid text", + input: " ", + want: pgtype.Text{String: " ", Valid: true}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := sqltrust.PgTextOrNull(tt.input) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestHandlePgError(t *testing.T) { + otherPgErr := &pgconn.PgError{Code: "42P01"} // undefined_table + sentinel := errors.New("some other error") + + tests := []struct { + name string + err error + wantErr error + wantHandled bool + }{ + { + name: "duplicate key violation (23505) returns ErrConflict", + err: &pgconn.PgError{Code: "23505"}, + wantErr: serviceerr.ErrConflict, + wantHandled: true, + }, + { + name: "other pg error code returns original error", + err: otherPgErr, + wantErr: otherPgErr, + wantHandled: false, + }, + { + name: "non-pg error returns original error", + err: sentinel, + wantErr: sentinel, + wantHandled: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, handled := sqltrust.HandlePgError(tt.err) + if handled != tt.wantHandled { + t.Errorf("handled = %v, want %v", handled, tt.wantHandled) + } + if !errors.Is(got, tt.wantErr) { + t.Errorf("err = %v, want %v", got, tt.wantErr) + } + }) + } +} diff --git a/modules/oidctrust/mapping.go b/modules/oidctrust/mapping.go new file mode 100644 index 00000000..abc020e1 --- /dev/null +++ b/modules/oidctrust/mapping.go @@ -0,0 +1,110 @@ +package oidctrust + +import ( + "context" + "errors" + "fmt" + + oidcv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/oidc/v1" + trustv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/v1" + + "github.com/openkcm/session-manager/pkg/serviceerr" +) + +// ApplyMapping applies and stores the provided Trust. +func (m *TrustModule) ApplyMapping(ctx context.Context, trust *trustv1.Trust) error { + if _, err := m.repository.Get(ctx, trust.GetTenantId()); err != nil { + err = m.repository.Create(ctx, trust) + if err != nil { + return fmt.Errorf("creating mapping for tenant: %w", err) + } + } else { + err = m.repository.Update(ctx, trust) + if err != nil { + return fmt.Errorf("updating mapping for tenant: %w", err) + } + } + + return nil +} + +// BlockMapping sets the Blocked flag to true for the OIDC mapping associated with the given tenantID. +// If the mapping is already blocked, it does nothing. +// Returns an error if the mapping cannot be retrieved or updated. +func (m *TrustModule) BlockMapping(ctx context.Context, tenantID string) error { + trust, err := m.repository.Get(ctx, tenantID) + if err != nil { + if errors.Is(err, serviceerr.ErrNotFound) { + return nil + } + return fmt.Errorf("getting mapping for tenant: %w", err) + } + if trust.GetBlocked() { + return nil + } + + trust.SetBlocked(true) + if err = m.repository.Update(ctx, trust); err != nil { + if errors.Is(err, serviceerr.ErrNotFound) { + return nil + } + return fmt.Errorf("updating mapping for blocking tenant: %w", err) + } + return nil +} + +func (m *TrustModule) RemoveMapping(ctx context.Context, tenantID string) error { + err := m.repository.Delete(ctx, tenantID) + if err != nil { + return fmt.Errorf("deleting mapping for tenant: %w", err) + } + + return nil +} + +// UnblockMapping sets the Blocked flag to false for the OIDC mapping associated with the given tenantID. +// If the mapping is not blocked, it does nothing. +// Returns an error if the mapping cannot be retrieved or updated. +func (m *TrustModule) UnblockMapping(ctx context.Context, tenantID string) error { + trust, err := m.repository.Get(ctx, tenantID) + if err != nil { + if errors.Is(err, serviceerr.ErrNotFound) { + return nil + } + return fmt.Errorf("getting mapping for tenant: %w", err) + } + if !trust.GetBlocked() { + return nil + } + trust.SetBlocked(false) + if err = m.repository.Update(ctx, trust); err != nil { + if errors.Is(err, serviceerr.ErrNotFound) { + return nil + } + return fmt.Errorf("updating mapping for unblocking tenant: %w", err) + } + return nil +} + +// Get returns a trust message with optional extensions set. +func (m *TrustModule) Get(ctx context.Context, tenantID string) (*trustv1.Trust, error) { + trust, err := m.repository.Get(ctx, tenantID) + if err != nil { + return nil, fmt.Errorf("getting trust from repository: %w", err) + } + + m.resolveExtensions(trust) + return trust, nil +} + +// resolveExtensions sets optional extensions to the Trust message and its details if configured. +func (m *TrustModule) resolveExtensions(trust *trustv1.Trust) { + switch trust.WhichDetails() { + case trustv1.Trust_Oidc_case: + m.resolveOIDCExtensions(trust.GetOidc()) + } +} + +// resolveOIDCExtensions sets optional extensions to the ODIC message if configured. +func (m *TrustModule) resolveOIDCExtensions(oidc *oidcv1.OIDC) { +} diff --git a/modules/oidctrust/mapping_test.go b/modules/oidctrust/mapping_test.go new file mode 100644 index 00000000..76d2b19c --- /dev/null +++ b/modules/oidctrust/mapping_test.go @@ -0,0 +1,677 @@ +package oidctrust_test + +import ( + "context" + "errors" + "os" + "testing" + + "github.com/gofrs/uuid/v5" + "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/testing/protocmp" + + oidcv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/oidc/v1" + trustv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/v1" + slogctx "github.com/veqryn/slog-context" + + "github.com/openkcm/session-manager/modules/oidctrust" + mocktrust "github.com/openkcm/session-manager/modules/oidctrust/mocks" +) + +var repo oidctrust.TrustRepository + +const ( + requestURI = "http://cmk.example.com/ui" + jwksURI = "http://jwks.example.com" +) + +func TestMain(m *testing.M) { + ctx := context.Background() + r, err := createRepo(ctx) + if err != nil { + slogctx.Error(ctx, "error while creating repo", "error", err) + } + + repo = r + + code := m.Run() + os.Exit(code) +} + +func TestService_ApplyMapping(t *testing.T) { + ctx := t.Context() + + t.Run("success if", func(t *testing.T) { + t.Run("the trust does not exist", func(t *testing.T) { + expTenantID := uuid.Must(uuid.NewV4()).String() + expMapping := trustv1.Trust_builder{ + TenantId: new(expTenantID), + Blocked: new(false), + Oidc: oidcv1.OIDC_builder{ + Issuer: new(uuid.Must(uuid.NewV4()).String()), + JwksUri: new(jwksURI), + Audiences: []string{requestURI}, + }.Build(), + }.Build() + + wrapper := &RepoWrapper{Repo: repo} + subj := oidctrust.NewModule(wrapper) + + err := subj.ApplyMapping(ctx, expMapping) + assert.NoError(t, err) + + actMapping, err := wrapper.Repo.Get(ctx, expTenantID) + assert.NoError(t, err) + if diff := cmp.Diff(expMapping, actMapping, protocmp.Transform()); diff != "" { + t.Fatalf("mapping not equal:\n%s", diff) + } + }) + + t.Run("the trust exists", func(t *testing.T) { + expTenantID := uuid.Must(uuid.NewV4()).String() + expMapping := trustv1.Trust_builder{ + TenantId: new(expTenantID), + Blocked: new(false), + Oidc: oidcv1.OIDC_builder{ + Issuer: new(uuid.Must(uuid.NewV4()).String()), + JwksUri: new(jwksURI), + Audiences: []string{requestURI}, + }.Build(), + }.Build() + + wrapper := &RepoWrapper{Repo: repo} + subj := oidctrust.NewModule(wrapper) + + err := subj.ApplyMapping(ctx, expMapping) + assert.NoError(t, err) + + expUpdatedMapping := trustv1.Trust_builder{ + TenantId: new(expTenantID), + Blocked: new(false), + Oidc: oidcv1.OIDC_builder{ + Issuer: new(expMapping.GetOidc().GetIssuer()), + JwksUri: new("http://updated-jwks.example.com"), + Audiences: []string{requestURI, "http://new-aud.example.com"}, + }.Build(), + }.Build() + + err = subj.ApplyMapping(ctx, expUpdatedMapping) + assert.NoError(t, err) + + actMapping, err := wrapper.Repo.Get(ctx, expTenantID) + assert.NoError(t, err) + if diff := cmp.Diff(expUpdatedMapping, actMapping, protocmp.Transform()); diff != "" { + t.Fatalf("mapping not equal:\n%s", diff) + } + }) + }) + + t.Run("should return error if", func(t *testing.T) { + t.Run("Create returns an error", func(t *testing.T) { + expTenantID := uuid.Must(uuid.NewV4()).String() + expMapping := trustv1.Trust_builder{ + TenantId: new(expTenantID), + Oidc: oidcv1.OIDC_builder{ + Issuer: new(uuid.Must(uuid.NewV4()).String()), + JwksUri: new(jwksURI), + Audiences: []string{requestURI}, + }.Build(), + }.Build() + + wrapper := &RepoWrapper{Repo: repo} + noOfCalls := 0 + wrapper.MockCreate = func(ctx context.Context, trust *trustv1.Trust) error { + assert.Equal(t, expTenantID, trust.GetTenantId()) + assert.Equal(t, expMapping, trust) + noOfCalls++ + return assert.AnError + } + + subj := oidctrust.NewModule(wrapper) + err := subj.ApplyMapping(ctx, expMapping) + + assert.ErrorIs(t, err, assert.AnError) + assert.Equal(t, 1, noOfCalls) + }) + + t.Run("Update returns an error", func(t *testing.T) { + expTenantID := uuid.Must(uuid.NewV4()).String() + expMapping := trustv1.Trust_builder{ + TenantId: new(expTenantID), + Oidc: oidcv1.OIDC_builder{ + Issuer: new(uuid.Must(uuid.NewV4()).String()), + JwksUri: new(jwksURI), + Audiences: []string{requestURI}, + }.Build(), + }.Build() + + wrapper := &RepoWrapper{Repo: repo} + noOfCalls := 0 + wrapper.MockUpdate = func(ctx context.Context, trust *trustv1.Trust) error { + assert.Equal(t, expTenantID, trust.GetTenantId()) + assert.Equal(t, expMapping, trust) + noOfCalls++ + return assert.AnError + } + subj := oidctrust.NewModule(wrapper) + + err := subj.ApplyMapping(ctx, expMapping) + assert.NoError(t, err) + err = subj.ApplyMapping(ctx, expMapping) + + assert.ErrorIs(t, err, assert.AnError) + assert.Equal(t, 1, noOfCalls) + }) + }) +} + +func TestService_BlockMapping(t *testing.T) { + ctx := t.Context() + + t.Run("success if ", func(t *testing.T) { + t.Run("the trust is unblocked", func(t *testing.T) { + // given + expTenantID := uuid.Must(uuid.NewV4()).String() + expUnblockedMapping := trustv1.Trust_builder{ + TenantId: new(expTenantID), + Blocked: new(false), + Oidc: oidcv1.OIDC_builder{ + Issuer: new(uuid.Must(uuid.NewV4()).String()), + JwksUri: new(jwksURI), + Audiences: []string{requestURI}, + }.Build(), + }.Build() + + wrapper := &RepoWrapper{Repo: repo} + err := wrapper.Repo.Create(ctx, expUnblockedMapping) + require.NoError(t, err) + subj := oidctrust.NewModule(wrapper) + + // when + err = subj.BlockMapping(ctx, expTenantID) + + // then + assert.NoError(t, err) + + actMapping, err := wrapper.Repo.Get(ctx, expTenantID) + assert.NoError(t, err) + assert.True(t, actMapping.GetBlocked()) + assert.Equal(t, expUnblockedMapping.GetOidc().GetIssuer(), actMapping.GetOidc().GetIssuer()) + assert.Equal(t, expUnblockedMapping.GetOidc().GetAudiences(), actMapping.GetOidc().GetAudiences()) + assert.Equal(t, expUnblockedMapping.GetOidc().GetJwksUri(), actMapping.GetOidc().GetJwksUri()) + }) + + t.Run("the trust is blocked then it should not call Update", func(t *testing.T) { + // given + expTenantID := uuid.Must(uuid.NewV4()).String() + expBlockedMapping := trustv1.Trust_builder{ + TenantId: new(expTenantID), + Blocked: new(true), + Oidc: oidcv1.OIDC_builder{ + Issuer: new(uuid.Must(uuid.NewV4()).String()), + JwksUri: new(jwksURI), + Audiences: []string{requestURI}, + }.Build(), + }.Build() + repoWrapper := &RepoWrapper{Repo: repo} + err := repoWrapper.Repo.Create(ctx, expBlockedMapping) + require.NoError(t, err) + + noOfUpdateCalls := 0 + repoWrapper.MockUpdate = func(ctx context.Context, trust *trustv1.Trust) error { + noOfUpdateCalls++ + return assert.AnError + } + subj := oidctrust.NewModule(repoWrapper) + + // when + err = subj.BlockMapping(t.Context(), expTenantID) + + // then + assert.NoError(t, err) + assert.Equal(t, 0, noOfUpdateCalls) + + actMapping, err := repoWrapper.Repo.Get(ctx, expTenantID) + assert.NoError(t, err) + assert.Equal(t, expBlockedMapping, actMapping) + }) + t.Run("the trust is not found during the Update", func(t *testing.T) { + // given + expTenantID := uuid.Must(uuid.NewV4()).String() + expBlockedMapping := trustv1.Trust_builder{ + TenantId: new(expTenantID), + Blocked: new(false), + Oidc: oidcv1.OIDC_builder{ + Issuer: new(uuid.Must(uuid.NewV4()).String()), + JwksUri: new(jwksURI), + Audiences: []string{requestURI}, + }.Build(), + }.Build() + repoWrapper := &RepoWrapper{Repo: repo} + err := repoWrapper.Repo.Create(ctx, expBlockedMapping) + require.NoError(t, err) + + noOfUpdateCalls := 0 + repoWrapper.MockUpdate = func(ctx context.Context, trust *trustv1.Trust) error { + noOfUpdateCalls++ + // delete the trust before updating to return an error + err := repoWrapper.Repo.Delete(ctx, expTenantID) + assert.NoError(t, err) + return nil + } + subj := oidctrust.NewModule(repoWrapper) + + // when + err = subj.BlockMapping(t.Context(), expTenantID) + + // then + assert.NoError(t, err) + assert.Equal(t, 1, noOfUpdateCalls) + }) + t.Run("the trust is not found", func(t *testing.T) { + // given + expTenantID := uuid.Must(uuid.NewV4()).String() + repoWrapper := &RepoWrapper{Repo: repo} + + subj := oidctrust.NewModule(repoWrapper) + + // when + err := subj.BlockMapping(t.Context(), expTenantID) + + // then + assert.NoError(t, err) + }) + }) + + t.Run("should return error", func(t *testing.T) { + t.Run("if Get returns an error", func(t *testing.T) { + // given + expTenantID := uuid.Must(uuid.NewV4()).String() + repoWrapper := &RepoWrapper{Repo: repo} + + noOfGetCalls := 0 + repoWrapper.MockGet = func(ctx context.Context, tenantID string) (*trustv1.Trust, error) { + assert.Equal(t, expTenantID, tenantID) + noOfGetCalls++ + return trustv1.Trust_builder{ + Oidc: oidcv1.OIDC_builder{}.Build(), + }.Build(), assert.AnError + } + subj := oidctrust.NewModule(repoWrapper) + + // when + err := subj.BlockMapping(t.Context(), expTenantID) + + // then + assert.ErrorIs(t, err, assert.AnError) + assert.Equal(t, 1, noOfGetCalls) + }) + + t.Run("if Update returns an error", func(t *testing.T) { + // given + expTenantID := uuid.Must(uuid.NewV4()).String() + expMapping := trustv1.Trust_builder{ + TenantId: new(expTenantID), + Blocked: new(false), + Oidc: oidcv1.OIDC_builder{ + Issuer: new(uuid.Must(uuid.NewV4()).String()), + JwksUri: new(jwksURI), + Audiences: []string{requestURI}, + }.Build(), + }.Build() + repoWrapper := &RepoWrapper{Repo: repo} + err := repoWrapper.Repo.Create(ctx, expMapping) + require.NoError(t, err) + + noOfUpdateCalls := 0 + repoWrapper.MockUpdate = func(ctx context.Context, trust *trustv1.Trust) error { + assert.Equal(t, expTenantID, trust.GetTenantId()) + noOfUpdateCalls++ + return assert.AnError + } + subj := oidctrust.NewModule(repoWrapper) + + // when + err = subj.BlockMapping(t.Context(), expTenantID) + + // then + assert.ErrorIs(t, err, assert.AnError) + assert.Equal(t, 1, noOfUpdateCalls) + + actMapping, err := repoWrapper.Repo.Get(ctx, expTenantID) + assert.NoError(t, err) + assert.Equal(t, expMapping, actMapping) + }) + }) +} + +func TestService_UnblockMapping(t *testing.T) { + ctx := t.Context() + + t.Run("success if ", func(t *testing.T) { + t.Run("the trust is blocked", func(t *testing.T) { + // given + expTenantID := uuid.Must(uuid.NewV4()).String() + expBlockedMapping := trustv1.Trust_builder{ + TenantId: new(expTenantID), + Blocked: new(true), + Oidc: oidcv1.OIDC_builder{ + Issuer: new(uuid.Must(uuid.NewV4()).String()), + JwksUri: new(jwksURI), + Audiences: []string{requestURI}, + }.Build(), + }.Build() + + wrapper := &RepoWrapper{Repo: repo} + err := wrapper.Repo.Create(ctx, expBlockedMapping) + require.NoError(t, err) + subj := oidctrust.NewModule(wrapper) + + // when + err = subj.UnblockMapping(t.Context(), expTenantID) + + // then + assert.NoError(t, err) + + actMapping, err := wrapper.Repo.Get(ctx, expTenantID) + assert.NoError(t, err) + assert.False(t, actMapping.GetBlocked()) + assert.Equal(t, expBlockedMapping.GetOidc().GetIssuer(), actMapping.GetOidc().GetIssuer()) + assert.Equal(t, expBlockedMapping.GetOidc().GetAudiences(), actMapping.GetOidc().GetAudiences()) + assert.Equal(t, expBlockedMapping.GetOidc().GetJwksUri(), actMapping.GetOidc().GetJwksUri()) + }) + + t.Run("the trust is unblocked then it should not call Update", func(t *testing.T) { + // given + expTenantID := uuid.Must(uuid.NewV4()).String() + expUnblockedMapping := trustv1.Trust_builder{ + TenantId: new(expTenantID), + Blocked: new(false), + Oidc: oidcv1.OIDC_builder{ + Issuer: new(uuid.Must(uuid.NewV4()).String()), + JwksUri: new(jwksURI), + Audiences: []string{requestURI}, + }.Build(), + }.Build() + repoWrapper := &RepoWrapper{Repo: repo} + err := repoWrapper.Repo.Create(ctx, expUnblockedMapping) + require.NoError(t, err) + + noOfUpdateCalls := 0 + repoWrapper.MockUpdate = func(ctx context.Context, trust *trustv1.Trust) error { + noOfUpdateCalls++ + return assert.AnError + } + subj := oidctrust.NewModule(repoWrapper) + + // when + err = subj.UnblockMapping(t.Context(), expTenantID) + + // then + assert.NoError(t, err) + assert.Equal(t, 0, noOfUpdateCalls) + + actMapping, err := repoWrapper.Repo.Get(ctx, expTenantID) + assert.NoError(t, err) + assert.False(t, actMapping.GetBlocked()) + }) + t.Run("the trust is not found during the Update", func(t *testing.T) { + // given + expTenantID := uuid.Must(uuid.NewV4()).String() + expUnblockedMapping := trustv1.Trust_builder{ + TenantId: new(expTenantID), + Blocked: new(true), + Oidc: oidcv1.OIDC_builder{ + Issuer: new(uuid.Must(uuid.NewV4()).String()), + JwksUri: new(jwksURI), + Audiences: []string{requestURI}, + }.Build(), + }.Build() + repoWrapper := &RepoWrapper{Repo: repo} + err := repoWrapper.Repo.Create(ctx, expUnblockedMapping) + require.NoError(t, err) + + noOfUpdateCalls := 0 + repoWrapper.MockUpdate = func(ctx context.Context, trust *trustv1.Trust) error { + noOfUpdateCalls++ + // delete the trust before updating to return an error + err := repoWrapper.Repo.Delete(ctx, expTenantID) + assert.NoError(t, err) + return nil + } + subj := oidctrust.NewModule(repoWrapper) + + // when + err = subj.UnblockMapping(t.Context(), expTenantID) + + // then + assert.NoError(t, err) + assert.Equal(t, 1, noOfUpdateCalls) + }) + t.Run("the trust is not found", func(t *testing.T) { + // given + expTenantID := uuid.Must(uuid.NewV4()).String() + repoWrapper := &RepoWrapper{Repo: repo} + + subj := oidctrust.NewModule(repoWrapper) + + // when + err := subj.UnblockMapping(t.Context(), expTenantID) + + // then + assert.NoError(t, err) + }) + }) + t.Run("should return error", func(t *testing.T) { + t.Run("if Get returns an error", func(t *testing.T) { + // given + expTenantID := uuid.Must(uuid.NewV4()).String() + mockRepo := &RepoWrapper{Repo: repo} + + noOfGetTenantCalls := 0 + mockRepo.MockGet = func(ctx context.Context, tenantID string) (*trustv1.Trust, error) { + assert.Equal(t, expTenantID, tenantID) + noOfGetTenantCalls++ + return trustv1.Trust_builder{ + Oidc: oidcv1.OIDC_builder{}.Build(), + }.Build(), assert.AnError + } + subj := oidctrust.NewModule(mockRepo) + + // when + err := subj.UnblockMapping(t.Context(), expTenantID) + + // then + assert.ErrorIs(t, err, assert.AnError) + assert.Equal(t, 1, noOfGetTenantCalls) + }) + + t.Run("if Update returns an error", func(t *testing.T) { + // given + expTenantIDtoUpdate := uuid.Must(uuid.NewV4()).String() + expBlockedMapping := trustv1.Trust_builder{ + TenantId: new(expTenantIDtoUpdate), + Blocked: new(true), + Oidc: oidcv1.OIDC_builder{ + Issuer: new(uuid.Must(uuid.NewV4()).String()), + JwksUri: new(jwksURI), + Audiences: []string{requestURI}, + }.Build(), + }.Build() + repoWrapper := &RepoWrapper{Repo: repo} + err := repoWrapper.Repo.Create(ctx, expBlockedMapping) + require.NoError(t, err) + + noOfUpdateCalls := 0 + repoWrapper.MockUpdate = func(ctx context.Context, trust *trustv1.Trust) error { + assert.Equal(t, expTenantIDtoUpdate, trust.GetTenantId()) + noOfUpdateCalls++ + return assert.AnError + } + subj := oidctrust.NewModule(repoWrapper) + + // when + err = subj.UnblockMapping(t.Context(), expTenantIDtoUpdate) + + // then + assert.ErrorIs(t, err, assert.AnError) + assert.Equal(t, 1, noOfUpdateCalls) + + actMapping, err := repoWrapper.Repo.Get(ctx, expTenantIDtoUpdate) + assert.NoError(t, err) + assert.Equal(t, expBlockedMapping, actMapping) + }) + }) +} + +func TestService_RemoveMapping(t *testing.T) { + ctx := t.Context() + + t.Run("success if", func(t *testing.T) { + t.Run("the trust exists", func(t *testing.T) { + // given + expTenantID := uuid.Must(uuid.NewV4()).String() + expMapping := trustv1.Trust_builder{ + TenantId: new(expTenantID), + Oidc: oidcv1.OIDC_builder{ + Issuer: new(uuid.Must(uuid.NewV4()).String()), + JwksUri: new(jwksURI), + Audiences: []string{requestURI}, + }.Build(), + }.Build() + + wrapper := &RepoWrapper{Repo: repo} + err := wrapper.Repo.Create(ctx, expMapping) + require.NoError(t, err) + + subj := oidctrust.NewModule(wrapper) + + // when + err = subj.RemoveMapping(ctx, expTenantID) + + // then + assert.NoError(t, err) + + // verify the trust was deleted + _, err = wrapper.Repo.Get(ctx, expTenantID) + assert.Error(t, err) + }) + }) + + t.Run("should return error if", func(t *testing.T) { + t.Run("the trust does not exist", func(t *testing.T) { + // given + expTenantID := uuid.Must(uuid.NewV4()).String() + wrapper := &RepoWrapper{Repo: repo} + subj := oidctrust.NewModule(wrapper) + + // when + err := subj.RemoveMapping(ctx, expTenantID) + + // then + assert.Error(t, err) + }) + + t.Run("Delete returns an error", func(t *testing.T) { + // given + expTenantID := uuid.Must(uuid.NewV4()).String() + wrapper := &RepoWrapper{Repo: repo} + + noOfDeleteCalls := 0 + wrapper.MockDelete = func(ctx context.Context, tenantID string) error { + assert.Equal(t, expTenantID, tenantID) + noOfDeleteCalls++ + return assert.AnError + } + + subj := oidctrust.NewModule(wrapper) + + // when + err := subj.RemoveMapping(ctx, expTenantID) + + // then + assert.ErrorIs(t, err, assert.AnError) + assert.Equal(t, 1, noOfDeleteCalls) + }) + }) +} + +func TestService_Get(t *testing.T) { + ctx := t.Context() + + repoErr := errors.New("repository error") + + tests := []struct { + name string + trust *trustv1.Trust + repoErr error + wantErr bool + wantErrIs error + }{ + { + name: "returns trust", + trust: trustv1.Trust_builder{ + TenantId: new(uuid.Must(uuid.NewV4()).String()), + Blocked: new(false), + Oidc: oidcv1.OIDC_builder{ + Issuer: new(uuid.Must(uuid.NewV4()).String()), + JwksUri: new(jwksURI), + Audiences: []string{requestURI}, + }.Build(), + }.Build(), + }, + { + name: "returns error when trust does not exist", + wantErr: true, + }, + { + name: "wraps repository error", + repoErr: repoErr, + wantErr: true, + wantErrIs: repoErr, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var opts []mocktrust.RepositoryOption + if tt.trust != nil { + opts = append(opts, mocktrust.WithTrust(tt.trust)) + } + if tt.repoErr != nil { + opts = append(opts, mocktrust.WithGetError(tt.repoErr)) + } + + mockRepo := mocktrust.NewInMemRepository(opts...) + subj := oidctrust.NewModule(mockRepo) + + var tenantID string + if tt.trust != nil { + tenantID = tt.trust.GetTenantId() + } else { + tenantID = uuid.Must(uuid.NewV4()).String() + } + + got, err := subj.Get(ctx, tenantID) + + if tt.wantErr { + if err == nil { + t.Fatal("expected error, got nil") + } + if tt.wantErrIs != nil && !errors.Is(err, tt.wantErrIs) { + t.Fatalf("error = %v, want to wrap %v", err, tt.wantErrIs) + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if diff := cmp.Diff(tt.trust, got, protocmp.Transform()); diff != "" { + t.Fatalf("trust not equal:\n%s", diff) + } + }) + } +} diff --git a/sql/00001_init.sql b/modules/oidctrust/migrations/00001_init.sql similarity index 100% rename from sql/00001_init.sql rename to modules/oidctrust/migrations/00001_init.sql diff --git a/sql/00002_add_properties_to_providers.sql b/modules/oidctrust/migrations/00002_add_properties_to_providers.sql similarity index 100% rename from sql/00002_add_properties_to_providers.sql rename to modules/oidctrust/migrations/00002_add_properties_to_providers.sql diff --git a/sql/00003_tenant_trust.sql b/modules/oidctrust/migrations/00003_tenant_trust.sql similarity index 100% rename from sql/00003_tenant_trust.sql rename to modules/oidctrust/migrations/00003_tenant_trust.sql diff --git a/sql/00004_single_tenant.sql b/modules/oidctrust/migrations/00004_single_tenant.sql similarity index 100% rename from sql/00004_single_tenant.sql rename to modules/oidctrust/migrations/00004_single_tenant.sql diff --git a/modules/oidctrust/migrations/00005_remove_properties.sql b/modules/oidctrust/migrations/00005_remove_properties.sql new file mode 100644 index 00000000..76d2041f --- /dev/null +++ b/modules/oidctrust/migrations/00005_remove_properties.sql @@ -0,0 +1,11 @@ +-- +goose Up +-- +goose StatementBegin +ALTER TABLE trust + DROP COLUMN properties; +-- +goose StatementEnd + +-- +goose Down +-- +goose StatementBegin +ALTER TABLE trust + ADD COLUMN properties JSONB NOT NULL; +-- +goose StatementEnd diff --git a/modules/oidctrust/migrations/migration.go b/modules/oidctrust/migrations/migration.go new file mode 100644 index 00000000..e378c782 --- /dev/null +++ b/modules/oidctrust/migrations/migration.go @@ -0,0 +1,62 @@ +package migrations + +import ( + "context" + "embed" + "fmt" + + "github.com/pressly/goose/v3" + + sessionmanager "github.com/openkcm/session-manager" +) + +//go:embed *.sql +var FS embed.FS + +const moduleID = "trust.migration.module.oidc" + +func newModule() sessionmanager.Module { + return new(MigrationModule) +} + +func init() { + sessionmanager.RegisterModule(new(MigrationModule)) +} + +type MigrationModule struct { + DBModule string `yaml:"dbModule" default:"database.module.pgxpool"` + + db sessionmanager.Database +} + +func (m *MigrationModule) Migrate(ctx context.Context) error { + goose.SetBaseFS(FS) + + if err := goose.SetDialect("postgres"); err != nil { + return fmt.Errorf("setting goose dialect: %w", err) + } + + if err := goose.UpContext(ctx, m.db.STDAdapter(), "."); err != nil { + return fmt.Errorf("applying migrations: %w", err) + } + + return nil +} + +func (m *MigrationModule) Module() sessionmanager.ModuleInfo { + return sessionmanager.ModuleInfo{ + ID: moduleID, + New: newModule, + } +} + +func (m *MigrationModule) Provision(ctx *sessionmanager.Context) error { + mod, err := ctx.GetModule(m.DBModule) + if err != nil { + return fmt.Errorf("getting postgres module: %w", err) + } + + //nolint:forcetypeassert + m.db = mod.(sessionmanager.Database) + return nil +} diff --git a/internal/trust/trustmock/repository.go b/modules/oidctrust/mocks/repository.go similarity index 56% rename from internal/trust/trustmock/repository.go rename to modules/oidctrust/mocks/repository.go index d64d6cca..d200b8eb 100644 --- a/internal/trust/trustmock/repository.go +++ b/modules/oidctrust/mocks/repository.go @@ -1,22 +1,24 @@ -package trustmock +package mocktrust import ( "context" - "github.com/openkcm/session-manager/internal/serviceerr" - "github.com/openkcm/session-manager/internal/trust" + trustv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/v1" + + "github.com/openkcm/session-manager/pkg/serviceerr" + "github.com/openkcm/session-manager/modules/oidctrust" ) type RepositoryOption func(*Repository) type Repository struct { - tenantTrust map[string]trust.OIDCMapping + tenantTrust map[string]*trustv1.Trust getErr, createErr, deleteErr, updateErr error } -func WithTrust(tenantID string, mapping trust.OIDCMapping) RepositoryOption { - return func(r *Repository) { r.tenantTrust[tenantID] = mapping } +func WithTrust(mapping *trustv1.Trust) RepositoryOption { + return func(r *Repository) { r.tenantTrust[mapping.GetTenantId()] = mapping } } func WithGetError(err error) RepositoryOption { return func(r *Repository) { r.getErr = err } @@ -31,11 +33,11 @@ func WithUpdateError(err error) RepositoryOption { return func(r *Repository) { r.updateErr = err } } -var _ = trust.OIDCMappingRepository(&Repository{}) +var _ oidctrust.TrustRepository = (*Repository)(nil) func NewInMemRepository(opts ...RepositoryOption) *Repository { r := &Repository{ - tenantTrust: make(map[string]trust.OIDCMapping), + tenantTrust: make(map[string]*trustv1.Trust), } for _, opt := range opts { if opt != nil { @@ -46,30 +48,30 @@ func NewInMemRepository(opts ...RepositoryOption) *Repository { } // TAdd is a helper method for tests to add a trust relationship. -func (r *Repository) TAdd(tenantID string, mapping trust.OIDCMapping) { - r.tenantTrust[tenantID] = mapping +func (r *Repository) TAdd(mapping *trustv1.Trust) { + r.tenantTrust[mapping.GetTenantId()] = mapping } // TGet is a helper method for tests to get a trust relationship. -func (r *Repository) TGet(tenantID string) trust.OIDCMapping { +func (r *Repository) TGet(tenantID string) *trustv1.Trust { return r.tenantTrust[tenantID] } -func (r *Repository) Get(_ context.Context, tenantID string) (trust.OIDCMapping, error) { +func (r *Repository) Get(_ context.Context, tenantID string) (*trustv1.Trust, error) { if r.getErr != nil { - return trust.OIDCMapping{}, r.getErr + return nil, r.getErr } if mapping, ok := r.tenantTrust[tenantID]; ok { return mapping, nil } - return trust.OIDCMapping{}, serviceerr.ErrNotFound + return nil, serviceerr.ErrNotFound } -func (r *Repository) Create(_ context.Context, tenantID string, mapping trust.OIDCMapping) error { +func (r *Repository) Create(_ context.Context, mapping *trustv1.Trust) error { if r.createErr != nil { return r.createErr } - r.tenantTrust[tenantID] = mapping + r.tenantTrust[mapping.GetTenantId()] = mapping return nil } @@ -84,10 +86,10 @@ func (r *Repository) Delete(_ context.Context, tenantID string) error { return nil } -func (r *Repository) Update(_ context.Context, tenantID string, mapping trust.OIDCMapping) error { +func (r *Repository) Update(_ context.Context, mapping *trustv1.Trust) error { if r.updateErr != nil { return r.updateErr } - r.tenantTrust[tenantID] = mapping + r.tenantTrust[mapping.GetTenantId()] = mapping return nil } diff --git a/modules/oidctrust/module.go b/modules/oidctrust/module.go new file mode 100644 index 00000000..010575fa --- /dev/null +++ b/modules/oidctrust/module.go @@ -0,0 +1,48 @@ +package oidctrust + +import ( + "fmt" + + sessionmanager "github.com/openkcm/session-manager" + sqltrust "github.com/openkcm/session-manager/modules/oidctrust/internal/sql" +) + +const moduleID = "trust.module.oidc" + +func newModule() sessionmanager.Module { + return new(TrustModule) +} + +func init() { + sessionmanager.RegisterModule(new(TrustModule)) +} + +// TrustModule is a module that implements sessionmanager.Trust interface. It's using a database providede by the +// [dbModule] module which implements sessionmanager.DBModule. +type TrustModule struct { + DBModule string `yaml:"dbModule" default:"database.module.pgxpool"` + + repository TrustRepository +} + +func (m *TrustModule) Module() sessionmanager.ModuleInfo { + return sessionmanager.ModuleInfo{ + ID: moduleID, + New: newModule, + } +} + +func (m *TrustModule) Provision(ctx *sessionmanager.Context) error { + dbMod, err := ctx.GetModule(m.DBModule) + if err != nil { + return fmt.Errorf("getting db module: %w", err) + } + + //nolint:forcetypeassert + db := dbMod.(sessionmanager.Database) + m.repository = sqltrust.NewRepository(db) + + return nil +} + +var _ sessionmanager.Trust = (*TrustModule)(nil) diff --git a/modules/oidctrust/repository.go b/modules/oidctrust/repository.go new file mode 100644 index 00000000..a114b37a --- /dev/null +++ b/modules/oidctrust/repository.go @@ -0,0 +1,15 @@ +package oidctrust + +import ( + "context" + + trustv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/v1" +) + +// TrustRepository allows to read OIDC mapping data for a tenant stored in the context. +type TrustRepository interface { + Get(ctx context.Context, tenantID string) (*trustv1.Trust, error) + Create(ctx context.Context, trust *trustv1.Trust) error + Delete(ctx context.Context, tenantID string) error + Update(ctx context.Context, trust *trustv1.Trust) error +} diff --git a/internal/trust/repository_test.go b/modules/oidctrust/repository_test.go similarity index 57% rename from internal/trust/repository_test.go rename to modules/oidctrust/repository_test.go index 549e42d3..0155db24 100644 --- a/internal/trust/repository_test.go +++ b/modules/oidctrust/repository_test.go @@ -1,4 +1,4 @@ -package trust_test +package oidctrust_test import ( "context" @@ -6,14 +6,15 @@ import ( "fmt" "github.com/jackc/pgx/v5/pgxpool" + "github.com/jackc/pgx/v5/stdlib" "github.com/pressly/goose/v3" "github.com/testcontainers/testcontainers-go/modules/postgres" - _ "github.com/jackc/pgx/v5/stdlib" + trustv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/v1" - "github.com/openkcm/session-manager/internal/trust" - "github.com/openkcm/session-manager/internal/trust/trustsql" - migrations "github.com/openkcm/session-manager/sql" + "github.com/openkcm/session-manager/modules/oidctrust" + sqltrust "github.com/openkcm/session-manager/modules/oidctrust/internal/sql" + migrations "github.com/openkcm/session-manager/modules/oidctrust/migrations" ) const ( @@ -25,28 +26,28 @@ const ( ) type RepoWrapper struct { - Repo trust.OIDCMappingRepository - MockGet func(ctx context.Context, tenantID string) (trust.OIDCMapping, error) - MockCreate func(ctx context.Context, tenantID string, mapping trust.OIDCMapping) error - MockUpdate func(ctx context.Context, tenantID string, mapping trust.OIDCMapping) error + Repo oidctrust.TrustRepository + MockGet func(ctx context.Context, tenantID string) (*trustv1.Trust, error) + MockCreate func(ctx context.Context, trust *trustv1.Trust) error MockDelete func(ctx context.Context, tenantID string) error + MockUpdate func(ctx context.Context, trust *trustv1.Trust) error } -var _ trust.OIDCMappingRepository = &RepoWrapper{} +var _ oidctrust.TrustRepository = &RepoWrapper{} -// Create implements oidc.OIDCMappingRepository. -func (m *RepoWrapper) Create(ctx context.Context, tenantID string, mapping trust.OIDCMapping) error { +// Create implements oidc.OIDCTrustRepository. +func (m *RepoWrapper) Create(ctx context.Context, trust *trustv1.Trust) error { if m.MockCreate != nil { - err := m.MockCreate(ctx, tenantID, mapping) + err := m.MockCreate(ctx, trust) if err != nil { return err } } - return m.Repo.Create(ctx, tenantID, mapping) + return m.Repo.Create(ctx, trust) } -// Delete implements oidc.OIDCMappingRepository. +// Delete implements oidc.OIDCTrustRepository. func (m *RepoWrapper) Delete(ctx context.Context, tenantID string) error { if m.MockDelete != nil { err := m.MockDelete(ctx, tenantID) @@ -58,29 +59,29 @@ func (m *RepoWrapper) Delete(ctx context.Context, tenantID string) error { return m.Repo.Delete(ctx, tenantID) } -// Get implements oidc.OIDCMappingRepository. -func (m *RepoWrapper) Get(ctx context.Context, tenantID string) (trust.OIDCMapping, error) { +// Get implements oidc.OIDCTrustRepository. +func (m *RepoWrapper) Get(ctx context.Context, tenantID string) (*trustv1.Trust, error) { if m.MockGet != nil { _, err := m.MockGet(ctx, tenantID) if err != nil { - return trust.OIDCMapping{}, err + return nil, err } } return m.Repo.Get(ctx, tenantID) } -// Update implements oidc.OIDCMappingRepository. -func (m *RepoWrapper) Update(ctx context.Context, tenantID string, mapping trust.OIDCMapping) error { +// Update implements oidc.OIDCTrustRepository. +func (m *RepoWrapper) Update(ctx context.Context, trust *trustv1.Trust) error { if m.MockUpdate != nil { - err := m.MockUpdate(ctx, tenantID, mapping) + err := m.MockUpdate(ctx, trust) if err != nil { return err } } - return m.Repo.Update(ctx, tenantID, mapping) + return m.Repo.Update(ctx, trust) } -func createRepo(ctx context.Context) (trust.OIDCMappingRepository, error) { +func createRepo(ctx context.Context) (oidctrust.TrustRepository, error) { pgContainer, err := postgres.Run( ctx, "postgres:17-alpine", @@ -110,7 +111,7 @@ func createRepo(ctx context.Context) (trust.OIDCMappingRepository, error) { return nil, err } - return trustsql.NewRepository(dbPool), nil + return sqltrust.NewRepository(&dbWrapper{dbPool}), nil } func migrateDB(ctx context.Context, connStr string) error { @@ -133,3 +134,11 @@ func migrateDB(ctx context.Context, connStr string) error { } return nil } + +type dbWrapper struct { + *pgxpool.Pool +} + +func (w *dbWrapper) STDAdapter() *sql.DB { + return stdlib.OpenDBFromPool(w.Pool) +} diff --git a/modules/standard/imports.go b/modules/standard/imports.go new file mode 100644 index 00000000..d21957ef --- /dev/null +++ b/modules/standard/imports.go @@ -0,0 +1,7 @@ +package standard + +import ( + _ "github.com/openkcm/session-manager/modules/database/pgxpool" + _ "github.com/openkcm/session-manager/modules/oidctrust" + _ "github.com/openkcm/session-manager/modules/oidctrust/migrations" +) diff --git a/modules_test.go b/modules_test.go new file mode 100644 index 00000000..c5f2b4cc --- /dev/null +++ b/modules_test.go @@ -0,0 +1,82 @@ +package sessionmanager_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + sessionmanager "github.com/openkcm/session-manager" +) + +// stubModule is a minimal Module used across tests in this file. +type stubModule struct{ id string } + +func (s *stubModule) Module() sessionmanager.ModuleInfo { + return sessionmanager.ModuleInfo{ + ID: s.id, + New: func() sessionmanager.Module { return &stubModule{id: s.id} }, + } +} + +// newRegistry resets the global module registry by registering a fresh set of +// modules under unique IDs so parallel/serial tests don't interfere with each other. +// It returns a cleanup function that can be deferred. +// +// Because the global map is package-level state, each test that touches the +// registry must use IDs that are unique within the whole test binary run. +func uniqueID(t *testing.T, suffix string) string { + t.Helper() + return t.Name() + "/" + suffix +} + +func TestRegisterModule_Success(t *testing.T) { + id := uniqueID(t, "mod") + sessionmanager.RegisterModule(&stubModule{id: id}) + + info, err := sessionmanager.GetModule(id) + require.NoError(t, err) + assert.Equal(t, id, info.ID) +} + +func TestRegisterModule_DuplicatePanics(t *testing.T) { + id := uniqueID(t, "mod") + sessionmanager.RegisterModule(&stubModule{id: id}) + + assert.Panics(t, func() { + sessionmanager.RegisterModule(&stubModule{id: id}) + }) +} + +func TestGetModule_NotRegistered(t *testing.T) { + _, err := sessionmanager.GetModule("module-that-does-not-exist") + require.Error(t, err) + assert.Contains(t, err.Error(), "not registered") +} + +func TestModules_ContainsRegistered(t *testing.T) { + id := uniqueID(t, "mod") + sessionmanager.RegisterModule(&stubModule{id: id}) + + found := false + for info := range sessionmanager.Modules() { + if info.ID == id { + found = true + break + } + } + assert.True(t, found, "registered module should appear in Modules()") +} + +func TestModuleInfo_New(t *testing.T) { + id := uniqueID(t, "mod") + sessionmanager.RegisterModule(&stubModule{id: id}) + + info, err := sessionmanager.GetModule(id) + require.NoError(t, err) + require.NotNil(t, info.New) + + instance := info.New() + require.NotNil(t, instance) + assert.Equal(t, id, instance.Module().ID) +} diff --git a/internal/serviceerr/errors.go b/pkg/serviceerr/errors.go similarity index 100% rename from internal/serviceerr/errors.go rename to pkg/serviceerr/errors.go diff --git a/internal/serviceerr/errors_test.go b/pkg/serviceerr/errors_test.go similarity index 99% rename from internal/serviceerr/errors_test.go rename to pkg/serviceerr/errors_test.go index 72c055c9..92371e7a 100644 --- a/internal/serviceerr/errors_test.go +++ b/pkg/serviceerr/errors_test.go @@ -6,7 +6,7 @@ import ( "github.com/stretchr/testify/assert" - "github.com/openkcm/session-manager/internal/serviceerr" + "github.com/openkcm/session-manager/pkg/serviceerr" ) func TestError_Error(t *testing.T) { diff --git a/sql/fs.go b/sql/fs.go deleted file mode 100644 index 91cca1c3..00000000 --- a/sql/fs.go +++ /dev/null @@ -1,6 +0,0 @@ -package migrations - -import "embed" - -//go:embed *.sql -var FS embed.FS diff --git a/sqlc.yaml b/sqlc.yaml index f1f46504..78319835 100644 --- a/sqlc.yaml +++ b/sqlc.yaml @@ -1,12 +1,12 @@ version: "2" sql: - engine: postgresql - queries: ./internal/trust/trustsql/queries.sql - schema: ./sql + queries: ./modules/oidctrust/internal/sql/queries.sql + schema: ./modules/oidctrust/migrations gen: go: package: queries - out: ./internal/trust/trustsql/internal/queries + out: ./modules/oidctrust/internal/sql/queries sql_package: pgx/v5 sql_driver: github.com/jackc/pgx/v5 emit_db_tags: true diff --git a/trust.go b/trust.go new file mode 100644 index 00000000..4ed79335 --- /dev/null +++ b/trust.go @@ -0,0 +1,15 @@ +package sessionmanager + +import ( + "context" + + trustv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/v1" +) + +type Trust interface { + ApplyMapping(ctx context.Context, trust *trustv1.Trust) error + BlockMapping(ctx context.Context, tenantID string) error + RemoveMapping(ctx context.Context, tenantID string) error + UnblockMapping(ctx context.Context, tenantID string) error + Get(ctx context.Context, tenantID string) (*trustv1.Trust, error) +} From 480a8eb095b7c6dde2367a343ce16799b658bb0a Mon Sep 17 00:00:00 2001 From: Danylo Shevchenko Date: Tue, 26 May 2026 08:59:48 +0200 Subject: [PATCH 2/5] feat: application lifecycle Signed-off-by: Danylo Shevchenko --- context.go | 69 ++++- context_test.go | 121 ++++++++ go.mod | 2 +- go.sum | 4 +- internal/business/apps.go | 103 +++++++ internal/business/apps_test.go | 198 +++++++++++++ internal/business/business.go | 129 +-------- internal/business/business_test.go | 270 ------------------ internal/business/housekeeper.go | 3 +- internal/business/housekeeper_test.go | 2 +- internal/config/config.go | 33 +++ internal/config/config_test.go | 72 +++++ internal/config/context.go | 26 ++ internal/config/load.go | 17 +- .../{oidcmapping_test.go => trust_test.go} | 0 internal/sessionwiring/sessionwiring.go | 130 +++++++++ internal/sessionwiring/sessionwiring_test.go | 147 ++++++++++ modules.go | 5 + modules/oidctrust/{mapping.go => trust.go} | 13 +- .../{mapping_test.go => trust_test.go} | 0 trust.go | 9 + 21 files changed, 947 insertions(+), 406 deletions(-) create mode 100644 internal/business/apps.go create mode 100644 internal/business/apps_test.go create mode 100644 internal/config/config_test.go create mode 100644 internal/config/context.go rename internal/grpc/{oidcmapping_test.go => trust_test.go} (100%) create mode 100644 internal/sessionwiring/sessionwiring.go create mode 100644 internal/sessionwiring/sessionwiring_test.go rename modules/oidctrust/{mapping.go => trust.go} (82%) rename modules/oidctrust/{mapping_test.go => trust_test.go} (100%) diff --git a/context.go b/context.go index fe70fa32..9e150bf9 100644 --- a/context.go +++ b/context.go @@ -15,12 +15,14 @@ type Context struct { context.Context mods map[string]Module + apps map[string]App } func (c *Context) cloneWithParent(parent context.Context) *Context { return &Context{ Context: parent, mods: c.mods, + apps: c.apps, } } @@ -30,7 +32,11 @@ func (c *Context) WithValue(key, val any) *Context { func NewContext(ctx context.Context) (*Context, context.CancelCauseFunc) { ctx, cancelCause := context.WithCancelCause(ctx) - c := &Context{Context: ctx, mods: make(map[string]Module)} + c := &Context{ + Context: ctx, + mods: make(map[string]Module), + apps: make(map[string]App), + } return c, func(cause error) { cancelCause(cause) for name, mod := range c.mods { @@ -40,6 +46,13 @@ func NewContext(ctx context.Context) (*Context, context.CancelCauseFunc) { } } } + for name, app := range c.apps { + if closer, ok := app.(io.Closer); ok { + if err := closer.Close(); err != nil { + slogctx.Error(c, "failed to close an app", "app", name, "error", err) + } + } + } } } @@ -56,23 +69,65 @@ func (c *Context) GetModule(id string) (Module, error) { return nil, errors.New("module is not loaded") } +func (c *Context) GetApp(id string) (App, error) { + if app, ok := c.apps[id]; ok { + return app, nil + } + + return nil, errors.New("app is not loaded") +} + func (c *Context) LoadModule(cfg ExtensionConfig) (Module, error) { - modInfo, err := GetModule(cfg.Module()) + mod, modInfo, err := c.instantiate(cfg) if err != nil { - return nil, fmt.Errorf("getting module %q: %w", reflect.TypeOf(cfg), err) + return nil, err } if _, ok := c.mods[modInfo.ID]; ok { return nil, errors.New("module has already been loaded") } + c.mods[modInfo.ID] = mod + + return mod, nil +} + +func (c *Context) LoadApp(cfg ExtensionConfig) (App, error) { + mod, modInfo, err := c.instantiate(cfg) + if err != nil { + return nil, err + } + + app, ok := mod.(App) + if !ok { + return nil, fmt.Errorf("module %q does not implement the App interface", modInfo.ID) + } + + if _, ok := c.apps[modInfo.ID]; ok { + return nil, errors.New("app has already been loaded") + } + + c.apps[modInfo.ID] = app + + return app, nil +} + +// instantiate resolves cfg.Module(), calls New(), unmarshals the extension, and +// runs Provision if the resulting instance is a Provisioner. It is shared by +// LoadModule and LoadApp. +func (c *Context) instantiate(cfg ExtensionConfig) (Module, ModuleInfo, error) { + modInfo, err := GetModule(cfg.Module()) + if err != nil { + return nil, ModuleInfo{}, fmt.Errorf("getting module %q: %w", reflect.TypeOf(cfg), err) + } + slogctx.Debug(c, "loading module", "module", modInfo.ID) mod := modInfo.New() rv := reflect.ValueOf(mod) if rv.Kind() == reflect.Pointer && rv.Elem().Kind() == reflect.Struct { if err := cfg.UnmarshalExtension(mod); err != nil { - return nil, fmt.Errorf("unmarshaling extension %s: %w", modInfo.ID, err) + return nil, ModuleInfo{}, fmt.Errorf("unmarshaling extension %s: %w", modInfo.ID, err) } } @@ -80,13 +135,11 @@ func (c *Context) LoadModule(cfg ExtensionConfig) (Module, error) { if provisioner, ok := mod.(Provisioner); ok { if err := provisioner.Provision(c); err != nil { - return nil, fmt.Errorf("provisioning module: %w", err) + return nil, ModuleInfo{}, fmt.Errorf("provisioning module: %w", err) } slogctx.Debug(c, "provisioned module", "module", modInfo.ID) } - c.mods[modInfo.ID] = mod - - return mod, nil + return mod, modInfo, nil } diff --git a/context_test.go b/context_test.go index 745d25e7..c3ba86ca 100644 --- a/context_test.go +++ b/context_test.go @@ -234,6 +234,124 @@ func TestGetModule_NotLoaded(t *testing.T) { assert.Contains(t, err.Error(), "not loaded") } +// appModule is a Module that also satisfies the App interface. +type appModule struct { + stubModule + + started bool + stopped bool +} + +func (a *appModule) Start() error { + a.started = true + return nil +} + +func (a *appModule) Stop() error { + a.stopped = true + return nil +} + +// closableAppModule is an App that also satisfies io.Closer. +type closableAppModule struct { + appModule + + closed bool +} + +func (a *closableAppModule) Close() error { + a.closed = true + return nil +} + +func TestLoadApp_Success(t *testing.T) { + id := uniqueID(t, "app") + am := &appModule{stubModule: stubModule{id: id}} + + sessionmanager.RegisterModule(&customNewModule{ + id: id, + newFn: func() sessionmanager.Module { return am }, + }) + + ctx, cancel := sessionmanager.NewContext(t.Context()) + defer cancel(nil) + + app, err := ctx.LoadApp(&simpleExtensionConfig{moduleID: id}) + require.NoError(t, err) + require.NotNil(t, app) + + got, err := ctx.GetApp(id) + require.NoError(t, err) + assert.Same(t, app, got) +} + +func TestLoadApp_MissingAppInterface(t *testing.T) { + id := uniqueID(t, "notapp") + sessionmanager.RegisterModule(&customNewModule{ + id: id, + newFn: func() sessionmanager.Module { return &stubModule{id: id} }, + }) + + ctx, cancel := sessionmanager.NewContext(t.Context()) + defer cancel(nil) + + _, err := ctx.LoadApp(&simpleExtensionConfig{moduleID: id}) + require.Error(t, err) + assert.Contains(t, err.Error(), "App interface") +} + +func TestLoadApp_UnknownModule(t *testing.T) { + ctx, cancel := sessionmanager.NewContext(t.Context()) + defer cancel(nil) + + _, err := ctx.LoadApp(&simpleExtensionConfig{moduleID: "no-such-app-module"}) + require.Error(t, err) +} + +func TestLoadApp_DuplicateReturnsError(t *testing.T) { + id := uniqueID(t, "dupapp") + sessionmanager.RegisterModule(&customNewModule{ + id: id, + newFn: func() sessionmanager.Module { return &appModule{stubModule: stubModule{id: id}} }, + }) + + ctx, cancel := sessionmanager.NewContext(t.Context()) + defer cancel(nil) + + _, err := ctx.LoadApp(&simpleExtensionConfig{moduleID: id}) + require.NoError(t, err) + + _, err = ctx.LoadApp(&simpleExtensionConfig{moduleID: id}) + require.Error(t, err) + assert.Contains(t, err.Error(), "already been loaded") +} + +func TestGetApp_NotLoaded(t *testing.T) { + ctx, cancel := sessionmanager.NewContext(t.Context()) + defer cancel(nil) + + _, err := ctx.GetApp("never-loaded") + require.Error(t, err) + assert.Contains(t, err.Error(), "not loaded") +} + +func TestNewContext_CancelClosesApps(t *testing.T) { + id := uniqueID(t, "closableapp") + cam := &closableAppModule{appModule: appModule{stubModule: stubModule{id: id}}} + + sessionmanager.RegisterModule(&customNewModule{ + id: id, + newFn: func() sessionmanager.Module { return cam }, + }) + + ctx, cancel := sessionmanager.NewContext(t.Context()) + _, err := ctx.LoadApp(&simpleExtensionConfig{moduleID: id}) + require.NoError(t, err) + + cancel(nil) + assert.True(t, cam.closed, "Close() should be called on apps when context is cancelled") +} + // Ensure stubModule satisfies the Module interface at compile time. var _ sessionmanager.Module = (*stubModule)(nil) @@ -242,3 +360,6 @@ var _ sessionmanager.Provisioner = (*provisionableModule)(nil) // Ensure closableModule satisfies io.Closer at compile time. var _ io.Closer = (*closableModule)(nil) + +// Ensure appModule satisfies App at compile time. +var _ sessionmanager.App = (*appModule)(nil) diff --git a/go.mod b/go.mod index ec1e57be..1661dec5 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/knadh/koanf/v2 v2.3.4 github.com/moby/moby/api v1.54.2 github.com/oapi-codegen/runtime v1.4.0 - github.com/openkcm/api-sdk v0.17.1-0.20260518093831-a872a7e182ca + github.com/openkcm/api-sdk v0.17.1-0.20260522173704-546d9188a096 github.com/openkcm/common-sdk v1.16.0 github.com/pressly/goose/v3 v3.27.1 github.com/samber/oops v1.21.0 diff --git a/go.sum b/go.sum index 2fcf5957..8e0f2f20 100644 --- a/go.sum +++ b/go.sum @@ -272,8 +272,8 @@ github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8 github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040= github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M= -github.com/openkcm/api-sdk v0.17.1-0.20260518093831-a872a7e182ca h1:XsZ2DB1EpXnT5XQaHT29acprdvH1kRFXsE/FHVisHx4= -github.com/openkcm/api-sdk v0.17.1-0.20260518093831-a872a7e182ca/go.mod h1:DeG8HQLN6QjzCpluI3B0xZCXqXEHv+0eSFg1+R5BQPo= +github.com/openkcm/api-sdk v0.17.1-0.20260522173704-546d9188a096 h1:k814id04b74JgTxdKzQl2+9+Th+jIzvoc3sd0tYilRE= +github.com/openkcm/api-sdk v0.17.1-0.20260522173704-546d9188a096/go.mod h1:DeG8HQLN6QjzCpluI3B0xZCXqXEHv+0eSFg1+R5BQPo= github.com/openkcm/common-sdk v1.16.0 h1:pmLXRHvjqg+8ATEyzXarCRiRghw/8pXGn2OtoYuMEIU= github.com/openkcm/common-sdk v1.16.0/go.mod h1:4umveCyatAaTi6dSQgwaBg1O/wqHr4sjzuMIQhEuX1o= github.com/pborman/getopt v0.0.0-20170112200414-7148bc3a4c30/go.mod h1:85jBQOZwpVEaDAr341tbn15RS4fCAsIst0qp7i8ex1o= diff --git a/internal/business/apps.go b/internal/business/apps.go new file mode 100644 index 00000000..45313a2a --- /dev/null +++ b/internal/business/apps.go @@ -0,0 +1,103 @@ +package business + +import ( + "errors" + "fmt" + "slices" + + slogctx "github.com/veqryn/slog-context" + + sessionmanager "github.com/openkcm/session-manager" + "github.com/openkcm/session-manager/internal/config" +) + +// startedApp pairs a configured name with its loaded handle so we can stop +// apps in reverse start order and emit meaningful logs. +type startedApp struct { + name string + app sessionmanager.App +} + +// startApps loads, provisions and starts every app declared under the +// top-level apps: section in cfg, in the order given by cfg.AppsOrder (with +// any apps not listed there appended in cfg.Apps map iteration order). +// +// If any Start() returns a non-nil error, every previously-started app is +// stopped in reverse order before the original error is returned. On +// success, the returned stopAll closure stops every started app in reverse +// order and joins any Stop() errors via errors.Join. +func startApps(ctx *sessionmanager.Context, cfg *config.Config) (stopAll func() error, _ error) { + order, err := appsStartOrder(cfg) + if err != nil { + return nil, err + } + + started := make([]startedApp, 0, len(order)) + + rollback := func() { + for _, sa := range slices.Backward(started) { + if stopErr := sa.app.Stop(); stopErr != nil { + slogctx.Error(ctx, "stopping app during rollback", "app", sa.name, "error", stopErr) + } + } + } + + for _, name := range order { + appCfg := cfg.Apps[name] + + slogctx.Info(ctx, "loading app", "app", name, "module", appCfg.Module()) + app, err := ctx.LoadApp(appCfg) + if err != nil { + rollback() + return nil, fmt.Errorf("loading app %q: %w", name, err) + } + + slogctx.Info(ctx, "starting app", "app", name) + if err := app.Start(); err != nil { + rollback() + return nil, fmt.Errorf("starting app %q: %w", name, err) + } + + started = append(started, startedApp{name: name, app: app}) + } + + return func() error { + var errs []error + for _, sa := range slices.Backward(started) { + slogctx.Info(ctx, "stopping app", "app", sa.name) + if err := sa.app.Stop(); err != nil { + slogctx.Error(ctx, "stopping app", "app", sa.name, "error", err) + errs = append(errs, fmt.Errorf("stopping app %q: %w", sa.name, err)) + } + } + return errors.Join(errs...) + }, nil +} + +// appsStartOrder returns the names of configured apps in start order. Names +// listed in cfg.AppsOrder come first, in the given order; the remainder are +// appended in cfg.Apps map iteration order. Names in AppsOrder that are not +// present in cfg.Apps surface as a configuration error. +func appsStartOrder(cfg *config.Config) ([]string, error) { + seen := make(map[string]bool, len(cfg.Apps)) + order := make([]string, 0, len(cfg.Apps)) + + for _, name := range cfg.AppsOrder { + if _, ok := cfg.Apps[name]; !ok { + return nil, fmt.Errorf("appsOrder references unknown app %q", name) + } + if !seen[name] { + seen[name] = true + order = append(order, name) + } + } + + for name := range cfg.Apps { + if !seen[name] { + seen[name] = true + order = append(order, name) + } + } + + return order, nil +} diff --git a/internal/business/apps_test.go b/internal/business/apps_test.go new file mode 100644 index 00000000..bd46afba --- /dev/null +++ b/internal/business/apps_test.go @@ -0,0 +1,198 @@ +package business + +import ( + "errors" + "fmt" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + sessionmanager "github.com/openkcm/session-manager" + "github.com/openkcm/session-manager/internal/config" +) + +// fakeApp is an App whose Start/Stop record their relative order across all +// fakeApp instances sharing a counter, so tests can assert ordering. +type fakeApp struct { + id string + counter *atomic.Int64 + startErr error + stopErr error + startOrder int64 + stopOrder int64 + startCalled bool + stopCalled bool +} + +func (a *fakeApp) Module() sessionmanager.ModuleInfo { + return sessionmanager.ModuleInfo{ + ID: a.id, + New: func() sessionmanager.Module { return a }, + } +} + +func (a *fakeApp) Start() error { + a.startCalled = true + a.startOrder = a.counter.Add(1) + return a.startErr +} + +func (a *fakeApp) Stop() error { + a.stopCalled = true + a.stopOrder = a.counter.Add(1) + return a.stopErr +} + +// registerFakeApps registers the given apps under unique module IDs scoped to +// the test name and returns a config.Config wired to load them in the order +// supplied. The returned cfg.AppsOrder pins the start order so tests don't +// rely on Go's map iteration order. +func registerFakeApps(t *testing.T, apps ...*fakeApp) *config.Config { + t.Helper() + + cfg := &config.Config{ + Apps: make(map[string]*config.App, len(apps)), + AppsOrder: make([]string, 0, len(apps)), + } + + for i, app := range apps { + modID := fmt.Sprintf("test.app.%s.%d", t.Name(), i) + app.id = modID + + // Capture for closure — RegisterModule's New() must hand back this + // instance so the test can assert against it. + captured := app + sessionmanager.RegisterModule(&moduleInfoStub{ + id: modID, + new: func() sessionmanager.Module { return captured }, + }) + + name := fmt.Sprintf("app-%d", i) + cfg.Apps[name] = &config.App{Mod: modID} + cfg.AppsOrder = append(cfg.AppsOrder, name) + } + + return cfg +} + +// moduleInfoStub adapts a (id, new) pair into the Module + ModuleInfo +// registration surface required by sessionmanager.RegisterModule. +type moduleInfoStub struct { + id string + new func() sessionmanager.Module +} + +func (s *moduleInfoStub) Module() sessionmanager.ModuleInfo { + return sessionmanager.ModuleInfo{ID: s.id, New: s.new} +} + +func TestStartApps_OrderAndReverseStop(t *testing.T) { + var counter atomic.Int64 + a := &fakeApp{counter: &counter} + b := &fakeApp{counter: &counter} + + cfg := registerFakeApps(t, a, b) + + ctx, cancel := sessionmanager.NewContext(t.Context()) + defer cancel(nil) + + stop, err := startApps(ctx, cfg) + require.NoError(t, err) + require.True(t, a.startCalled) + require.True(t, b.startCalled) + assert.Less(t, a.startOrder, b.startOrder, "A.Start must be called before B.Start") + + require.NoError(t, stop()) + assert.Less(t, b.stopOrder, a.stopOrder, "B.Stop must be called before A.Stop") +} + +func TestStartApps_StartFailureRollsBack(t *testing.T) { + var counter atomic.Int64 + wantErr := errors.New("boom") + a := &fakeApp{counter: &counter} + b := &fakeApp{counter: &counter, startErr: wantErr} + + cfg := registerFakeApps(t, a, b) + + ctx, cancel := sessionmanager.NewContext(t.Context()) + defer cancel(nil) + + stop, err := startApps(ctx, cfg) + require.Error(t, err) + require.ErrorIs(t, err, wantErr) + assert.Nil(t, stop) + + assert.True(t, a.startCalled) + assert.True(t, a.stopCalled, "successfully-started app must be rolled back on Start failure") + assert.True(t, b.startCalled) + assert.False(t, b.stopCalled, "failed-to-start app must not have Stop called") +} + +func TestStartApps_FirstAppFailsNoStop(t *testing.T) { + var counter atomic.Int64 + wantErr := errors.New("boom") + a := &fakeApp{counter: &counter, startErr: wantErr} + b := &fakeApp{counter: &counter} + + cfg := registerFakeApps(t, a, b) + + ctx, cancel := sessionmanager.NewContext(t.Context()) + defer cancel(nil) + + _, err := startApps(ctx, cfg) + require.ErrorIs(t, err, wantErr) + assert.False(t, a.stopCalled) + assert.False(t, b.startCalled) + assert.False(t, b.stopCalled) +} + +func TestStartApps_StopErrorsAggregated(t *testing.T) { + var counter atomic.Int64 + stopErrA := errors.New("stop-a") + stopErrB := errors.New("stop-b") + a := &fakeApp{counter: &counter, stopErr: stopErrA} + b := &fakeApp{counter: &counter, stopErr: stopErrB} + + cfg := registerFakeApps(t, a, b) + + ctx, cancel := sessionmanager.NewContext(t.Context()) + defer cancel(nil) + + stop, err := startApps(ctx, cfg) + require.NoError(t, err) + + err = stop() + require.Error(t, err) + assert.ErrorIs(t, err, stopErrA) + assert.ErrorIs(t, err, stopErrB) +} + +func TestAppsStartOrder_AppsOrderTakesPrecedence(t *testing.T) { + cfg := &config.Config{ + Apps: map[string]*config.App{ + "a": {}, "b": {}, "c": {}, + }, + AppsOrder: []string{"c", "a"}, + } + + got, err := appsStartOrder(cfg) + require.NoError(t, err) + require.Len(t, got, 3) + assert.Equal(t, "c", got[0]) + assert.Equal(t, "a", got[1]) + // "b" wasn't listed, so it appears last in map iteration order. + assert.Equal(t, "b", got[2]) +} + +func TestAppsStartOrder_RejectsUnknownAppName(t *testing.T) { + cfg := &config.Config{ + Apps: map[string]*config.App{"a": {}}, + AppsOrder: []string{"missing"}, + } + + _, err := appsStartOrder(cfg) + require.Error(t, err) + assert.Contains(t, err.Error(), "missing") +} diff --git a/internal/business/business.go b/internal/business/business.go index b08d17eb..0501811b 100644 --- a/internal/business/business.go +++ b/internal/business/business.go @@ -4,29 +4,17 @@ import ( "context" "errors" "fmt" - "log/slog" "sync" "github.com/openkcm/common-sdk/pkg/commoncfg" - "github.com/valkey-io/valkey-go" - - otlpaudit "github.com/openkcm/common-sdk/pkg/otlp/audit" slogctx "github.com/veqryn/slog-context" sessionmanager "github.com/openkcm/session-manager" "github.com/openkcm/session-manager/internal/business/server" "github.com/openkcm/session-manager/internal/config" - "github.com/openkcm/session-manager/internal/credentials" "github.com/openkcm/session-manager/internal/grpc" - "github.com/openkcm/session-manager/internal/session" sessionvalkey "github.com/openkcm/session-manager/internal/session/valkey" -) - -const ( - insecure = "insecure" - mtls = "mtls" - clientSecret = "client_secret" // An alias to clientSecretPost. Prefer using clientSecretPost. - clientSecretPost = "client_secret_post" + "github.com/openkcm/session-manager/internal/sessionwiring" ) // Main starts both API servers @@ -34,6 +22,8 @@ func Main(ctx context.Context, cfg *config.Config) error { c, cancelCause := sessionmanager.NewContext(ctx) defer cancelCause(nil) + c = config.WithContext(c, cfg) + if _, err := c.LoadModule(&cfg.Database); err != nil { return fmt.Errorf("loading database module: %w", err) } @@ -42,6 +32,11 @@ func Main(ctx context.Context, cfg *config.Config) error { return fmt.Errorf("loading trust module: %w", err) } + stopApps, err := startApps(c, cfg) + if err != nil { + return fmt.Errorf("starting apps: %w", err) + } + // errChan is used to capture the first error and shutdown the servers. errChan := make(chan error, 1) @@ -59,16 +54,18 @@ func Main(ctx context.Context, cfg *config.Config) error { }) // wait for any error to initiate the shutdown - err := <-errChan + err = <-errChan if err != nil { slogctx.Error(ctx, "Shutting down servers", "error", err) } + + stopErr := stopApps() cancelCause(err) // wait for all servers to shutdown wg.Wait() - return err + return errors.Join(err, stopErr) } // publicMain starts the HTTP REST public API server. @@ -91,7 +88,7 @@ func publicMain(ctx *sessionmanager.Context, cfg *config.Config) error { //nolint:forcetypeassert trust := trustMod.(sessionmanager.Trust) - sessionManager, closeFn, err := initSessionManager(ctx, cfg, trust) + sessionManager, closeFn, err := sessionwiring.InitSessionManager(ctx, cfg, trust) if err != nil { return fmt.Errorf("failed to initialise the session manager: %w", err) } @@ -104,14 +101,14 @@ func publicMain(ctx *sessionmanager.Context, cfg *config.Config) error { // internalMain starts the gRPC private API server. func internalMain(ctx *sessionmanager.Context, cfg *config.Config) error { // Create session repository - valkeyClient, err := valkeyClientFromConfig(cfg) + valkeyClient, err := sessionwiring.ValkeyClient(cfg) if err != nil { return fmt.Errorf("failed to create valkey client: %w", err) } defer valkeyClient.Close() sessionRepo := sessionvalkey.NewRepository(valkeyClient, cfg.ValKey.Prefix) - credsBuilder, err := newCredsBuilder(cfg) + credsBuilder, err := sessionwiring.CredsBuilder(cfg) if err != nil { return fmt.Errorf("failed to create a credentials builder: %w", err) } @@ -136,99 +133,3 @@ func internalMain(ctx *sessionmanager.Context, cfg *config.Config) error { return server.StartGRPCServer(ctx, cfg, oidcmappingsrv, sessionsrv) } - -func initSessionManager(ctx context.Context, cfg *config.Config, trust sessionmanager.Trust) (_ *session.Manager, closeFn func(), _ error) { - // Create session repository - valkeyClient, err := valkeyClientFromConfig(cfg) - if err != nil { - return nil, nil, fmt.Errorf("failed to create valkey client: %w", err) - } - sessionRepo := sessionvalkey.NewRepository(valkeyClient, cfg.ValKey.Prefix) - - credsBuilder, err := newCredsBuilder(cfg) - if err != nil { - return nil, nil, fmt.Errorf("failed to load http client: %w", err) - } - - auditLogger, err := otlpaudit.NewLogger(&cfg.Audit) - if err != nil { - return nil, nil, fmt.Errorf("failed to create audit logger: %w", err) - } - - sessManager, err := session.NewManager(ctx, - &cfg.SessionManager, - trust, - sessionRepo, - auditLogger, - session.WithTransportCredentials(credsBuilder), - ) - if err != nil { - return nil, nil, fmt.Errorf("failed to create session manager: %w", err) - } - - return sessManager, valkeyClient.Close, nil -} - -func valkeyClientFromConfig(cfg *config.Config) (valkey.Client, error) { - valkeyHost, err := commoncfg.LoadValueFromSourceRef(cfg.ValKey.Host) - if err != nil { - return nil, fmt.Errorf("failed to load valkey host: %w", err) - } - - valkeyUsername, err := commoncfg.LoadValueFromSourceRef(cfg.ValKey.User) - if err != nil { - return nil, fmt.Errorf("failed to load valkey username: %w", err) - } - - valkeyPassword, err := commoncfg.LoadValueFromSourceRef(cfg.ValKey.Password) - if err != nil { - return nil, fmt.Errorf("failed to load valkey password: %w", err) - } - - valkeyOpts := valkey.ClientOption{ - InitAddress: []string{string(valkeyHost)}, - Username: string(valkeyUsername), - Password: string(valkeyPassword), - } - - if cfg.ValKey.SecretRef.Type == commoncfg.MTLSSecretType { - tlsConfig, err := commoncfg.LoadMTLSConfig(&cfg.ValKey.SecretRef.MTLS) - if err != nil { - return nil, fmt.Errorf("failed to load valkey mTLS config from secret ref: %w", err) - } - - valkeyOpts.TLSConfig = tlsConfig - } - - valkeyClient, err := valkey.NewClient(valkeyOpts) - if err != nil { - return nil, fmt.Errorf("failed to create a new valkey client: %w", err) - } - return valkeyClient, nil -} - -func newCredsBuilder(cfg *config.Config) (credentials.Builder, error) { - switch cfg.SessionManager.ClientAuth.Type { - case mtls: - tlsConfig, err := commoncfg.LoadMTLSConfig(cfg.SessionManager.ClientAuth.MTLS) - if err != nil { - return nil, fmt.Errorf("failed to load mTLS config: %w", err) - } - - return func(clientID string) credentials.TransportCredentials { return credentials.NewTLS(clientID, tlsConfig) }, nil - case clientSecretPost, clientSecret: - secret, err := commoncfg.LoadValueFromSourceRef(cfg.SessionManager.ClientAuth.ClientSecret) - if err != nil { - return nil, fmt.Errorf("failed to load client secret: %w", err) - } - - return func(clientID string) credentials.TransportCredentials { - return credentials.NewClientSecretPost(clientID, string(secret)) - }, nil - case insecure: - slog.Warn("insecure credentials are used. Do not use this in production") - return func(clientID string) credentials.TransportCredentials { return credentials.NewInsecure(clientID) }, nil - default: - return nil, errors.New("unknown Client Auth type") - } -} diff --git a/internal/business/business_test.go b/internal/business/business_test.go index cb14c057..de0292d2 100644 --- a/internal/business/business_test.go +++ b/internal/business/business_test.go @@ -1,283 +1,15 @@ package business import ( - "io" - "net/http" - "net/http/httptest" - "net/url" - "strings" "testing" "github.com/openkcm/common-sdk/pkg/commoncfg" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" sessionmanager "github.com/openkcm/session-manager" "github.com/openkcm/session-manager/internal/config" - "github.com/openkcm/session-manager/internal/credentials" ) -func TestLoadHTTPClient_MTLS(t *testing.T) { - cfg := &config.Config{ - SessionManager: config.SessionManager{ - ClientAuth: config.ClientAuth{ - Type: "mtls", - ClientID: "test-client", - MTLS: &commoncfg.MTLS{ - Cert: commoncfg.SourceRef{File: commoncfg.CredentialFile{Path: "/nonexistent/cert.pem"}}, - CertKey: commoncfg.SourceRef{File: commoncfg.CredentialFile{Path: "/nonexistent/key.pem"}}, - }, - }, - }, - } - - // This will fail without actual cert files, but tests the logic path - _, err := newCredsBuilder(cfg) - // We expect an error since we don't have real cert files - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to load mTLS config") -} - -func TestLoadHTTPClient_ClientSecret(t *testing.T) { - cfg := &config.Config{ - SessionManager: config.SessionManager{ - ClientAuth: config.ClientAuth{ - Type: "client_secret", - ClientID: "test-client", - ClientSecret: commoncfg.SourceRef{Source: "embedded", Value: "test-secret"}, - }, - }, - } - - builder, err := newCredsBuilder(cfg) - require.NoError(t, err) - require.NotNil(t, builder) - - // Verify it's using our custom transport - creds := builder(cfg.SessionManager.ClientAuth.ClientID) - clientSecretCreds, ok := creds.(*credentials.ClientSecretPost) - require.True(t, ok) - - assert.Equal(t, "test-client", clientSecretCreds.ClientID) - assert.Equal(t, "test-secret", clientSecretCreds.ClientSecret) -} - -func TestLoadHTTPClient_Insecure(t *testing.T) { - cfg := &config.Config{ - SessionManager: config.SessionManager{ - ClientAuth: config.ClientAuth{ - Type: insecure, - ClientID: "test-client", - }, - }, - } - - builder, err := newCredsBuilder(cfg) - require.NoError(t, err) - assert.IsType(t, &credentials.Insecure{}, builder("")) -} - -func TestLoadHTTPClient_UnknownType(t *testing.T) { - cfg := &config.Config{ - SessionManager: config.SessionManager{ - ClientAuth: config.ClientAuth{ - Type: "unknown", - ClientID: "test-client", - }, - }, - } - - _, err := newCredsBuilder(cfg) - assert.Error(t, err) - assert.Contains(t, err.Error(), "unknown Client Auth type") -} - -func TestClientAuthRoundTripper_RoundTrip(t *testing.T) { - tests := []struct { - name string - clientID string - clientSecret string - requestURL string - expectedClientID string - expectedHasSecret bool - expectedSecretVal string - body io.Reader - }{ - { - name: "With client secret", - clientID: "my-client", - clientSecret: "my-secret", - requestURL: "https://example.com/token", - expectedClientID: "my-client", - expectedHasSecret: true, - expectedSecretVal: "my-secret", - }, - { - name: "Without client secret", - clientID: "my-client", - clientSecret: "", - requestURL: "https://example.com/token", - expectedClientID: "my-client", - expectedHasSecret: false, - }, - { - name: "With existing query params", - clientID: "my-client", - clientSecret: "my-secret", - requestURL: "https://example.com/token", - expectedClientID: "my-client", - expectedHasSecret: true, - expectedSecretVal: "my-secret", - body: strings.NewReader("foo=bar"), - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Helper() - - // Create a test server that captures the request - var capturedReq *http.Request - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _ = r.ParseForm() - capturedReq = r - w.WriteHeader(http.StatusOK) - })) - defer server.Close() - - // Create the round tripper - creds := credentials.NewClientSecretPost(tt.clientID, tt.clientSecret) - - // Parse the test URL - reqURL, err := url.Parse(tt.requestURL) - require.NoError(t, err) - - // Update URL to point to test server - reqURL.Scheme = "http" - reqURL.Host = server.Listener.Addr().String() - - // Create and execute request - req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, reqURL.String(), tt.body) - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - require.NoError(t, err) - - resp, err := creds.Transport().RoundTrip(req) - require.NoError(t, err) - require.NotNil(t, resp) - defer resp.Body.Close() - - // Verify the captured request has correct query params - require.NotNil(t, capturedReq) - - assert.Equal(t, tt.expectedClientID, capturedReq.FormValue("client_id")) - - if tt.expectedHasSecret { - assert.Equal(t, tt.expectedSecretVal, capturedReq.FormValue("client_secret")) - } else { - assert.Empty(t, capturedReq.FormValue("client_secret")) - } - - // Verify original query params are preserved - if tt.body != nil { - b, _ := io.ReadAll(tt.body) - q, _ := url.ParseQuery(string(b)) - - for k, v := range q { - assert.Equal(t, v, capturedReq.FormValue(k)) - } - } - }) - } -} - -func TestClientAuthRoundTripper_RoundTrip_PreservesExistingParams(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, "my-client", r.FormValue("client_id")) - assert.Equal(t, "my-secret", r.FormValue("client_secret")) - assert.Equal(t, "bar", r.FormValue("foo")) - assert.Equal(t, "baz", r.FormValue("param2")) - w.WriteHeader(http.StatusOK) - })) - defer server.Close() - - creds := credentials.NewClientSecretPost("my-client", "my-secret") - - reqURL := server.URL - req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, reqURL, strings.NewReader("foo=bar¶m2=baz")) - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - require.NoError(t, err) - - resp, err := creds.Transport().RoundTrip(req) - require.NoError(t, err) - require.NotNil(t, resp) - defer resp.Body.Close() - - assert.Equal(t, http.StatusOK, resp.StatusCode) -} - -func TestValkeyClientFromConfig_InvalidHostRef(t *testing.T) { - cfg := &config.Config{ - ValKey: config.ValKey{ - Host: commoncfg.SourceRef{Source: "file", File: commoncfg.CredentialFile{Path: "/nonexistent/file"}}, - User: commoncfg.SourceRef{Source: "embedded", Value: "user"}, - Password: commoncfg.SourceRef{Source: "embedded", Value: "pass"}, - }, - } - - _, err := valkeyClientFromConfig(cfg) - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to load valkey host") -} - -func TestValkeyClientFromConfig_InvalidUserRef(t *testing.T) { - cfg := &config.Config{ - ValKey: config.ValKey{ - Host: commoncfg.SourceRef{Source: "embedded", Value: "localhost:6379"}, - User: commoncfg.SourceRef{Source: "file", File: commoncfg.CredentialFile{Path: "/nonexistent/file"}}, - Password: commoncfg.SourceRef{Source: "embedded", Value: "pass"}, - }, - } - - _, err := valkeyClientFromConfig(cfg) - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to load valkey username") -} - -func TestValkeyClientFromConfig_InvalidPasswordRef(t *testing.T) { - cfg := &config.Config{ - ValKey: config.ValKey{ - Host: commoncfg.SourceRef{Source: "embedded", Value: "localhost:6379"}, - User: commoncfg.SourceRef{Source: "embedded", Value: "user"}, - Password: commoncfg.SourceRef{Source: "file", File: commoncfg.CredentialFile{Path: "/nonexistent/file"}}, - }, - } - - _, err := valkeyClientFromConfig(cfg) - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to load valkey password") -} - -func TestValkeyClientFromConfig_WithMTLS(t *testing.T) { - cfg := &config.Config{ - ValKey: config.ValKey{ - Host: commoncfg.SourceRef{Source: "embedded", Value: "localhost:6379"}, - User: commoncfg.SourceRef{Source: "embedded", Value: "user"}, - Password: commoncfg.SourceRef{Source: "embedded", Value: "pass"}, - SecretRef: commoncfg.SecretRef{ - Type: commoncfg.MTLSSecretType, - MTLS: commoncfg.MTLS{ - Cert: commoncfg.SourceRef{Source: "file", File: commoncfg.CredentialFile{Path: "/nonexistent/cert.pem"}}, - CertKey: commoncfg.SourceRef{Source: "file", File: commoncfg.CredentialFile{Path: "/nonexistent/key.pem"}}, - }, - }, - }, - } - - _, err := valkeyClientFromConfig(cfg) - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to load valkey mTLS config from secret ref") -} - func TestPublicMain_InvalidCSRFSecret(t *testing.T) { cfg := &config.Config{ SessionManager: config.SessionManager{ @@ -323,8 +55,6 @@ func TestInternalMain_InvalidValkeyConfig(t *testing.T) { err := internalMain(ctx, cfg) assert.Error(t, err) assert.Contains(t, err.Error(), "failed to create valkey client") - // Could fail on OIDC (DB connection) or valkey - // Error details depend on which step fails } func TestMain_InvalidCSRFSecret(t *testing.T) { diff --git a/internal/business/housekeeper.go b/internal/business/housekeeper.go index 18fe69d3..d33a9dc1 100644 --- a/internal/business/housekeeper.go +++ b/internal/business/housekeeper.go @@ -9,6 +9,7 @@ import ( sessionmanager "github.com/openkcm/session-manager" "github.com/openkcm/session-manager/internal/config" + "github.com/openkcm/session-manager/internal/sessionwiring" ) // HousekeeperMain starts the house keeping jobs @@ -29,7 +30,7 @@ func HousekeeperMain(ctx context.Context, cfg *config.Config) error { //nolint:forcetypeassert trust := trustMod.(sessionmanager.Trust) - sessionManager, closeFn, err := initSessionManager(ctx, cfg, trust) + sessionManager, closeFn, err := sessionwiring.InitSessionManager(ctx, cfg, trust) if err != nil { return fmt.Errorf("failed to initialise the session manager: %w", err) } diff --git a/internal/business/housekeeper_test.go b/internal/business/housekeeper_test.go index cc960403..984837ed 100644 --- a/internal/business/housekeeper_test.go +++ b/internal/business/housekeeper_test.go @@ -19,7 +19,7 @@ func TestHousekeeperMain_CancelledContext(t *testing.T) { }, SessionManager: config.SessionManager{ ClientAuth: config.ClientAuth{ - Type: insecure, + Type: "insecure", }, }, } diff --git a/internal/config/config.go b/internal/config/config.go index 5442c96b..876f4cf5 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -26,6 +26,39 @@ type Config struct { SessionManager SessionManager `yaml:"sessionManager"` Housekeeper Housekeeper `yaml:"housekeeper"` Trust Trust `yaml:"trust"` + + // Apps configures long-running components that satisfy the sessionmanager.App + // interface. The map key is an operator-chosen name. Each entry MUST set + // "module:" to the registered module ID; remaining fields are passed to the + // module via UnmarshalExtension. + Apps map[string]*App `yaml:"apps"` + // AppsOrder optionally overrides the start order of apps. Apps not listed + // here are started in parser-defined order after the listed ones. At + // shutdown, apps are stopped in the reverse of the order in which they were + // successfully started. + AppsOrder []string `yaml:"appsOrder"` +} + +// App is the per-entry configuration under the top-level apps: section. It +// implements sessionmanager.ExtensionConfig so it can be passed to LoadApp. +type App struct { + Mod string `yaml:"module"` + koanf *koanf.Koanf +} + +func (c *App) setKoanf(ko *koanf.Koanf) { + c.koanf = ko +} + +func (c *App) Module() string { + return c.Mod +} + +func (c *App) UnmarshalExtension(into sessionmanager.Module) error { + if c.koanf == nil { + return nil + } + return unmarshalExtension(into, c.koanf) } type Trust struct { diff --git a/internal/config/config_test.go b/internal/config/config_test.go new file mode 100644 index 00000000..6520aaee --- /dev/null +++ b/internal/config/config_test.go @@ -0,0 +1,72 @@ +package config + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + sessionmanager "github.com/openkcm/session-manager" +) + +// fakeAppModule is used to verify that App.UnmarshalExtension routes +// per-app YAML fields into the target module struct. +type fakeAppModule struct { + TriggerInterval string `yaml:"triggerInterval"` + Endpoint string `yaml:"endpoint"` +} + +func (*fakeAppModule) Module() sessionmanager.ModuleInfo { + return sessionmanager.ModuleInfo{ID: "config_test.fake.app"} +} + +func TestLoad_AppsSection(t *testing.T) { + yaml := ` +apps: + housekeeper: + module: app.module.housekeeper + triggerInterval: 5m + audit-shipper: + module: app.module.audit-shipper + endpoint: https://example.invalid/audit +appsOrder: + - housekeeper + - audit-shipper +` + + dir := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(dir, configFile), []byte(yaml), 0o600)) + + cfg, err := Load("", dir) + require.NoError(t, err) + + require.Len(t, cfg.Apps, 2) + require.Contains(t, cfg.Apps, "housekeeper") + require.Contains(t, cfg.Apps, "audit-shipper") + + assert.Equal(t, "app.module.housekeeper", cfg.Apps["housekeeper"].Module()) + assert.Equal(t, "app.module.audit-shipper", cfg.Apps["audit-shipper"].Module()) + assert.Equal(t, []string{"housekeeper", "audit-shipper"}, cfg.AppsOrder) + + // Per-app fields must be reachable via UnmarshalExtension into the target + // module type — confirms the koanf subtree is wired up per entry. + hk := &fakeAppModule{} + require.NoError(t, cfg.Apps["housekeeper"].UnmarshalExtension(hk)) + assert.Equal(t, "5m", hk.TriggerInterval) + + as := &fakeAppModule{} + require.NoError(t, cfg.Apps["audit-shipper"].UnmarshalExtension(as)) + assert.Equal(t, "https://example.invalid/audit", as.Endpoint) +} + +func TestLoad_AppsAbsentIsNoop(t *testing.T) { + dir := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(dir, configFile), []byte("# empty\n"), 0o600)) + + cfg, err := Load("", dir) + require.NoError(t, err) + assert.Empty(t, cfg.Apps) + assert.Empty(t, cfg.AppsOrder) +} diff --git a/internal/config/context.go b/internal/config/context.go new file mode 100644 index 00000000..b86e7f7e --- /dev/null +++ b/internal/config/context.go @@ -0,0 +1,26 @@ +package config + +import ( + "context" + + sessionmanager "github.com/openkcm/session-manager" +) + +// configCtxKey is a private type so config attached via WithContext can only +// be retrieved by callers that share this package. +type configCtxKey struct{} + +// WithContext returns a sessionmanager.Context that carries cfg. Apps' +// Provision methods retrieve it via FromContext when they need top-level +// configuration that is not part of their own per-app config block (e.g. +// valkey credentials, audit endpoints). +func WithContext(ctx *sessionmanager.Context, cfg *Config) *sessionmanager.Context { + return ctx.WithValue(configCtxKey{}, cfg) +} + +// FromContext returns the *Config previously attached via WithContext. The +// boolean is false when no config has been attached. +func FromContext(ctx context.Context) (*Config, bool) { + cfg, ok := ctx.Value(configCtxKey{}).(*Config) + return cfg, ok +} diff --git a/internal/config/load.go b/internal/config/load.go index 8f4fc893..ef68e0f8 100644 --- a/internal/config/load.go +++ b/internal/config/load.go @@ -87,7 +87,8 @@ func setKoanf(v reflect.Value, ko *koanf.Koanf) { } elem := reflect.Indirect(v) - if elem.Kind() == reflect.Struct { + switch elem.Kind() { + case reflect.Struct: for field, val := range elem.Fields() { name, _, _ := strings.Cut(field.Tag.Get(koanfUnmarshalConf.Tag), ",") if name == "" { @@ -102,5 +103,19 @@ func setKoanf(v reflect.Value, ko *koanf.Koanf) { setKoanf(val, ko.Cut(name)) } + case reflect.Map: + // Recurse into string-keyed map values so each entry receives its own + // koanf subtree. Map values returned by MapRange are not addressable, + // so we only descend when the value type is already a pointer. + if elem.Type().Key().Kind() != reflect.String { + return + } + if elem.Type().Elem().Kind() != reflect.Pointer { + return + } + iter := elem.MapRange() + for iter.Next() { + setKoanf(iter.Value(), ko.Cut(iter.Key().String())) + } } } diff --git a/internal/grpc/oidcmapping_test.go b/internal/grpc/trust_test.go similarity index 100% rename from internal/grpc/oidcmapping_test.go rename to internal/grpc/trust_test.go diff --git a/internal/sessionwiring/sessionwiring.go b/internal/sessionwiring/sessionwiring.go new file mode 100644 index 00000000..94d036c7 --- /dev/null +++ b/internal/sessionwiring/sessionwiring.go @@ -0,0 +1,130 @@ +// Package sessionwiring centralises the construction of long-lived +// session-manager dependencies (valkey client, credentials builder, the +// session.Manager itself) so callers in cmd/, internal/business, and apps +// configured via the apps: lifecycle loop can build them identically. +package sessionwiring + +import ( + "context" + "errors" + "fmt" + "log/slog" + + "github.com/openkcm/common-sdk/pkg/commoncfg" + "github.com/valkey-io/valkey-go" + + otlpaudit "github.com/openkcm/common-sdk/pkg/otlp/audit" + + sessionmanager "github.com/openkcm/session-manager" + "github.com/openkcm/session-manager/internal/config" + "github.com/openkcm/session-manager/internal/credentials" + "github.com/openkcm/session-manager/internal/session" + sessionvalkey "github.com/openkcm/session-manager/internal/session/valkey" +) + +const ( + insecure = "insecure" + mtls = "mtls" + clientSecret = "client_secret" // Alias for clientSecretPost. + clientSecretPost = "client_secret_post" +) + +// InitSessionManager builds a session.Manager from the supplied config and +// trust module. The returned closeFn must be invoked once the manager is no +// longer in use to release the underlying valkey client. +func InitSessionManager(ctx context.Context, cfg *config.Config, trust sessionmanager.Trust) (_ *session.Manager, closeFn func(), _ error) { + valkeyClient, err := ValkeyClient(cfg) + if err != nil { + return nil, nil, fmt.Errorf("failed to create valkey client: %w", err) + } + sessionRepo := sessionvalkey.NewRepository(valkeyClient, cfg.ValKey.Prefix) + + credsBuilder, err := CredsBuilder(cfg) + if err != nil { + return nil, nil, fmt.Errorf("failed to load http client: %w", err) + } + + auditLogger, err := otlpaudit.NewLogger(&cfg.Audit) + if err != nil { + return nil, nil, fmt.Errorf("failed to create audit logger: %w", err) + } + + sessManager, err := session.NewManager(ctx, + &cfg.SessionManager, + trust, + sessionRepo, + auditLogger, + session.WithTransportCredentials(credsBuilder), + ) + if err != nil { + return nil, nil, fmt.Errorf("failed to create session manager: %w", err) + } + + return sessManager, valkeyClient.Close, nil +} + +// ValkeyClient creates a valkey client from the valkey-related fields on cfg. +func ValkeyClient(cfg *config.Config) (valkey.Client, error) { + valkeyHost, err := commoncfg.LoadValueFromSourceRef(cfg.ValKey.Host) + if err != nil { + return nil, fmt.Errorf("failed to load valkey host: %w", err) + } + + valkeyUsername, err := commoncfg.LoadValueFromSourceRef(cfg.ValKey.User) + if err != nil { + return nil, fmt.Errorf("failed to load valkey username: %w", err) + } + + valkeyPassword, err := commoncfg.LoadValueFromSourceRef(cfg.ValKey.Password) + if err != nil { + return nil, fmt.Errorf("failed to load valkey password: %w", err) + } + + valkeyOpts := valkey.ClientOption{ + InitAddress: []string{string(valkeyHost)}, + Username: string(valkeyUsername), + Password: string(valkeyPassword), + } + + if cfg.ValKey.SecretRef.Type == commoncfg.MTLSSecretType { + tlsConfig, err := commoncfg.LoadMTLSConfig(&cfg.ValKey.SecretRef.MTLS) + if err != nil { + return nil, fmt.Errorf("failed to load valkey mTLS config from secret ref: %w", err) + } + valkeyOpts.TLSConfig = tlsConfig + } + + valkeyClient, err := valkey.NewClient(valkeyOpts) + if err != nil { + return nil, fmt.Errorf("failed to create a new valkey client: %w", err) + } + return valkeyClient, nil +} + +// CredsBuilder returns a credentials.Builder that matches the configured +// client-auth strategy. +func CredsBuilder(cfg *config.Config) (credentials.Builder, error) { + switch cfg.SessionManager.ClientAuth.Type { + case mtls: + tlsConfig, err := commoncfg.LoadMTLSConfig(cfg.SessionManager.ClientAuth.MTLS) + if err != nil { + return nil, fmt.Errorf("failed to load mTLS config: %w", err) + } + + return func(clientID string) credentials.TransportCredentials { return credentials.NewTLS(clientID, tlsConfig) }, nil + case clientSecretPost, clientSecret: + secret, err := commoncfg.LoadValueFromSourceRef(cfg.SessionManager.ClientAuth.ClientSecret) + if err != nil { + return nil, fmt.Errorf("failed to load client secret: %w", err) + } + + return func(clientID string) credentials.TransportCredentials { + return credentials.NewClientSecretPost(clientID, string(secret)) + }, nil + case insecure: + slog.Warn("insecure credentials are used. Do not use this in production") + return func(clientID string) credentials.TransportCredentials { return credentials.NewInsecure(clientID) }, nil + default: + return nil, errors.New("unknown Client Auth type") + } +} diff --git a/internal/sessionwiring/sessionwiring_test.go b/internal/sessionwiring/sessionwiring_test.go new file mode 100644 index 00000000..794f75b3 --- /dev/null +++ b/internal/sessionwiring/sessionwiring_test.go @@ -0,0 +1,147 @@ +package sessionwiring + +import ( + "testing" + + "github.com/openkcm/common-sdk/pkg/commoncfg" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/openkcm/session-manager/internal/config" + "github.com/openkcm/session-manager/internal/credentials" +) + +func TestCredsBuilder_MTLS(t *testing.T) { + cfg := &config.Config{ + SessionManager: config.SessionManager{ + ClientAuth: config.ClientAuth{ + Type: "mtls", + ClientID: "test-client", + MTLS: &commoncfg.MTLS{ + Cert: commoncfg.SourceRef{File: commoncfg.CredentialFile{Path: "/nonexistent/cert.pem"}}, + CertKey: commoncfg.SourceRef{File: commoncfg.CredentialFile{Path: "/nonexistent/key.pem"}}, + }, + }, + }, + } + + _, err := CredsBuilder(cfg) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to load mTLS config") +} + +func TestCredsBuilder_ClientSecret(t *testing.T) { + cfg := &config.Config{ + SessionManager: config.SessionManager{ + ClientAuth: config.ClientAuth{ + Type: "client_secret", + ClientID: "test-client", + ClientSecret: commoncfg.SourceRef{Source: "embedded", Value: "test-secret"}, + }, + }, + } + + builder, err := CredsBuilder(cfg) + require.NoError(t, err) + require.NotNil(t, builder) + + creds := builder(cfg.SessionManager.ClientAuth.ClientID) + clientSecretCreds, ok := creds.(*credentials.ClientSecretPost) + require.True(t, ok) + + assert.Equal(t, "test-client", clientSecretCreds.ClientID) + assert.Equal(t, "test-secret", clientSecretCreds.ClientSecret) +} + +func TestCredsBuilder_Insecure(t *testing.T) { + cfg := &config.Config{ + SessionManager: config.SessionManager{ + ClientAuth: config.ClientAuth{ + Type: "insecure", + ClientID: "test-client", + }, + }, + } + + builder, err := CredsBuilder(cfg) + require.NoError(t, err) + assert.IsType(t, &credentials.Insecure{}, builder("")) +} + +func TestCredsBuilder_UnknownType(t *testing.T) { + cfg := &config.Config{ + SessionManager: config.SessionManager{ + ClientAuth: config.ClientAuth{ + Type: "unknown", + ClientID: "test-client", + }, + }, + } + + _, err := CredsBuilder(cfg) + assert.Error(t, err) + assert.Contains(t, err.Error(), "unknown Client Auth type") +} + +func TestValkeyClient_InvalidHostRef(t *testing.T) { + cfg := &config.Config{ + ValKey: config.ValKey{ + Host: commoncfg.SourceRef{Source: "file", File: commoncfg.CredentialFile{Path: "/nonexistent/file"}}, + User: commoncfg.SourceRef{Source: "embedded", Value: "user"}, + Password: commoncfg.SourceRef{Source: "embedded", Value: "pass"}, + }, + } + + _, err := ValkeyClient(cfg) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to load valkey host") +} + +func TestValkeyClient_InvalidUserRef(t *testing.T) { + cfg := &config.Config{ + ValKey: config.ValKey{ + Host: commoncfg.SourceRef{Source: "embedded", Value: "localhost:6379"}, + User: commoncfg.SourceRef{Source: "file", File: commoncfg.CredentialFile{Path: "/nonexistent/file"}}, + Password: commoncfg.SourceRef{Source: "embedded", Value: "pass"}, + }, + } + + _, err := ValkeyClient(cfg) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to load valkey username") +} + +func TestValkeyClient_InvalidPasswordRef(t *testing.T) { + cfg := &config.Config{ + ValKey: config.ValKey{ + Host: commoncfg.SourceRef{Source: "embedded", Value: "localhost:6379"}, + User: commoncfg.SourceRef{Source: "embedded", Value: "user"}, + Password: commoncfg.SourceRef{Source: "file", File: commoncfg.CredentialFile{Path: "/nonexistent/file"}}, + }, + } + + _, err := ValkeyClient(cfg) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to load valkey password") +} + +func TestValkeyClient_WithMTLS(t *testing.T) { + cfg := &config.Config{ + ValKey: config.ValKey{ + Host: commoncfg.SourceRef{Source: "embedded", Value: "localhost:6379"}, + User: commoncfg.SourceRef{Source: "embedded", Value: "user"}, + Password: commoncfg.SourceRef{Source: "embedded", Value: "pass"}, + SecretRef: commoncfg.SecretRef{ + Type: commoncfg.MTLSSecretType, + MTLS: commoncfg.MTLS{ + Cert: commoncfg.SourceRef{Source: "file", File: commoncfg.CredentialFile{Path: "/nonexistent/cert.pem"}}, + CertKey: commoncfg.SourceRef{Source: "file", File: commoncfg.CredentialFile{Path: "/nonexistent/key.pem"}}, + }, + }, + }, + } + + _, err := ValkeyClient(cfg) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to load valkey mTLS config from secret ref") +} diff --git a/modules.go b/modules.go index 2f3a4f25..280652e3 100644 --- a/modules.go +++ b/modules.go @@ -55,3 +55,8 @@ type ModuleInfo struct { type Provisioner interface { Provision(ctx *Context) error } + +type App interface { + Start() error + Stop() error +} diff --git a/modules/oidctrust/mapping.go b/modules/oidctrust/trust.go similarity index 82% rename from modules/oidctrust/mapping.go rename to modules/oidctrust/trust.go index abc020e1..fe6d2e9b 100644 --- a/modules/oidctrust/mapping.go +++ b/modules/oidctrust/trust.go @@ -11,7 +11,7 @@ import ( "github.com/openkcm/session-manager/pkg/serviceerr" ) -// ApplyMapping applies and stores the provided Trust. +// ApplyMapping implements [sessionmanager.Trust]. func (m *TrustModule) ApplyMapping(ctx context.Context, trust *trustv1.Trust) error { if _, err := m.repository.Get(ctx, trust.GetTenantId()); err != nil { err = m.repository.Create(ctx, trust) @@ -28,9 +28,7 @@ func (m *TrustModule) ApplyMapping(ctx context.Context, trust *trustv1.Trust) er return nil } -// BlockMapping sets the Blocked flag to true for the OIDC mapping associated with the given tenantID. -// If the mapping is already blocked, it does nothing. -// Returns an error if the mapping cannot be retrieved or updated. +// BlockMapping implements [sessionmanager.Trust]. func (m *TrustModule) BlockMapping(ctx context.Context, tenantID string) error { trust, err := m.repository.Get(ctx, tenantID) if err != nil { @@ -53,6 +51,7 @@ func (m *TrustModule) BlockMapping(ctx context.Context, tenantID string) error { return nil } +// RemoveMapping implements [sessionmanager.Trust]. func (m *TrustModule) RemoveMapping(ctx context.Context, tenantID string) error { err := m.repository.Delete(ctx, tenantID) if err != nil { @@ -62,9 +61,7 @@ func (m *TrustModule) RemoveMapping(ctx context.Context, tenantID string) error return nil } -// UnblockMapping sets the Blocked flag to false for the OIDC mapping associated with the given tenantID. -// If the mapping is not blocked, it does nothing. -// Returns an error if the mapping cannot be retrieved or updated. +// UnblockMapping implements [sessionmanager.Trust]. func (m *TrustModule) UnblockMapping(ctx context.Context, tenantID string) error { trust, err := m.repository.Get(ctx, tenantID) if err != nil { @@ -86,7 +83,7 @@ func (m *TrustModule) UnblockMapping(ctx context.Context, tenantID string) error return nil } -// Get returns a trust message with optional extensions set. +// Get implements [sessionmanager.Trust]. func (m *TrustModule) Get(ctx context.Context, tenantID string) (*trustv1.Trust, error) { trust, err := m.repository.Get(ctx, tenantID) if err != nil { diff --git a/modules/oidctrust/mapping_test.go b/modules/oidctrust/trust_test.go similarity index 100% rename from modules/oidctrust/mapping_test.go rename to modules/oidctrust/trust_test.go diff --git a/trust.go b/trust.go index 4ed79335..7d69d1e2 100644 --- a/trust.go +++ b/trust.go @@ -7,9 +7,18 @@ import ( ) type Trust interface { + // ApplyMapping applies and stores the provided Trust. ApplyMapping(ctx context.Context, trust *trustv1.Trust) error + // BlockMapping sets the Blocked flag to true for the OIDC mapping associated with the given tenantID. + // If the mapping is already blocked, it does nothing. + // Returns an error if the mapping cannot be retrieved or updated. BlockMapping(ctx context.Context, tenantID string) error + // RemoveMapping removes the specified mapping from the trust. RemoveMapping(ctx context.Context, tenantID string) error + // UnblockMapping sets the Blocked flag to false for the OIDC mapping associated with the given tenantID. + // If the mapping is not blocked, it does nothing. + // Returns an error if the mapping cannot be retrieved or updated. UnblockMapping(ctx context.Context, tenantID string) error + // Get returns a trust message with optional extensions set. Get(ctx context.Context, tenantID string) (*trustv1.Trust, error) } From 5d059c40095cc9696355de999ffc6be54691589a Mon Sep 17 00:00:00 2001 From: Danylo Shevchenko Date: Tue, 26 May 2026 09:01:28 +0200 Subject: [PATCH 3/5] refactor: rename mapping into trust Signed-off-by: Danylo Shevchenko --- integration/grpc_test.go | 4 +- integration/session_grpc_test.go | 4 +- internal/business/business.go | 5 +- internal/business/server/grpc_server.go | 6 +- internal/business/server/grpc_server_test.go | 4 +- internal/business/server/openapi.go | 2 +- internal/business/server/openapi_test.go | 2 +- internal/grpc/session.go | 4 +- internal/grpc/session_test.go | 50 +++--- internal/grpc/trust_test.go | 52 +++--- internal/grpc/trustmapping.go | 36 ++--- internal/session/housekeeper.go | 2 +- internal/session/housekeeper_test.go | 20 +-- internal/session/manager.go | 12 +- internal/session/manager_test.go | 26 +-- internal/session/mock/repository.go | 2 +- internal/session/valkey/repository.go | 2 +- modules/oidctrust/internal/sql/queries.sql | 8 +- modules/oidctrust/internal/sql/queries/db.go | 2 +- .../internal/sql/queries/queries.sql.go | 34 ++-- modules/oidctrust/internal/sql/sql.go | 18 +-- modules/oidctrust/internal/sql/sql_test.go | 80 +++++----- modules/oidctrust/mocks/repository.go | 22 +-- modules/oidctrust/repository.go | 2 +- modules/oidctrust/trust.go | 30 ++-- modules/oidctrust/trust_test.go | 150 +++++++++--------- trust.go | 24 +-- 27 files changed, 302 insertions(+), 301 deletions(-) diff --git a/integration/grpc_test.go b/integration/grpc_test.go index d482c1bd..d40e7550 100644 --- a/integration/grpc_test.go +++ b/integration/grpc_test.go @@ -192,7 +192,7 @@ func TestGRPCServer(t *testing.T) { assert.True(t, applyResp.GetSuccess()) }) - t.Run("ApplyTrustMapping idempotent - applying same mapping twice", func(t *testing.T) { + t.Run("ApplyTrustMapping idempotent - applying same trust twice", func(t *testing.T) { expJwks := "jks-idempotent" expTenantID := uuid.Must(uuid.NewV4()).String() expIssuer := uuid.Must(uuid.NewV4()).String() @@ -304,7 +304,7 @@ func TestGRPCServer(t *testing.T) { expTenantID := uuid.Must(uuid.NewV4()).String() expIssuer := uuid.Must(uuid.NewV4()).String() - // Apply mapping + // Apply trust applyRes, err := trust.ApplyTrustMapping(ctx, trustmappingv1.ApplyTrustMappingRequest_builder{ TenantId: &expTenantID, Oidc: trustmappingv1.ApplyTrustMappingRequest_ApplyOIDCTrust_builder{ diff --git a/integration/session_grpc_test.go b/integration/session_grpc_test.go index d5658adf..f9f92258 100644 --- a/integration/session_grpc_test.go +++ b/integration/session_grpc_test.go @@ -152,7 +152,7 @@ func TestSessionGRPC(t *testing.T) { err = sessionRepo.BumpActive(ctx, sess.ID, 1*time.Hour) require.NoError(t, err) - // Note: This test will fail validation because there's no trust mapping configured + // Note: This test will fail validation because there's no trust configured // but it tests the session retrieval path resp, err := sessionClient.GetSession(ctx, &sessionv1.GetSessionRequest{ SessionId: sess.ID, @@ -161,7 +161,7 @@ func TestSessionGRPC(t *testing.T) { }) assert.NoError(t, err) assert.NotNil(t, resp) - // Will be false because trust mapping is not configured, but tests the flow + // Will be false because trust is not configured, but tests the flow assert.False(t, resp.GetValid()) }) diff --git a/internal/business/business.go b/internal/business/business.go index 0501811b..3b25105a 100644 --- a/internal/business/business.go +++ b/internal/business/business.go @@ -7,6 +7,7 @@ import ( "sync" "github.com/openkcm/common-sdk/pkg/commoncfg" + slogctx "github.com/veqryn/slog-context" sessionmanager "github.com/openkcm/session-manager" @@ -122,7 +123,7 @@ func internalMain(ctx *sessionmanager.Context, cfg *config.Config) error { trust := trustMod.(sessionmanager.Trust) // Initialize the gRPC servers. - oidcmappingsrv := grpc.NewTrustMappingServer(trust) + trustsrv := grpc.NewTrustMappingServer(trust) sessionsrv := grpc.NewSessionServer(ctx, sessionRepo, trust, @@ -131,5 +132,5 @@ func internalMain(ctx *sessionmanager.Context, cfg *config.Config) error { grpc.WithTransportCredentials(credsBuilder), ) - return server.StartGRPCServer(ctx, cfg, oidcmappingsrv, sessionsrv) + return server.StartGRPCServer(ctx, cfg, trustsrv, sessionsrv) } diff --git a/internal/business/server/grpc_server.go b/internal/business/server/grpc_server.go index 596bbc20..b541bec3 100644 --- a/internal/business/server/grpc_server.go +++ b/internal/business/server/grpc_server.go @@ -16,13 +16,13 @@ import ( ) func StartGRPCServer(ctx context.Context, cfg *config.Config, - oidcmappingsrv *grpc.TrustMappingServer, + trustsrv *grpc.TrustMappingServer, sessionsrv *grpc.SessionServer, ) error { grpcServer := commongrpc.NewServer(ctx, &cfg.GRPC.GRPCServer) - // Register OIDC mapping server for the regional tenant manager - trustmappingv1.RegisterServiceServer(grpcServer, oidcmappingsrv) + // Register Trust server for the regional tenant manager + trustmappingv1.RegisterServiceServer(grpcServer, trustsrv) // Register Session server for ExtAuthZ sessionv1.RegisterServiceServer(grpcServer, sessionsrv) diff --git a/internal/business/server/grpc_server_test.go b/internal/business/server/grpc_server_test.go index 01cb928d..be0fb226 100644 --- a/internal/business/server/grpc_server_test.go +++ b/internal/business/server/grpc_server_test.go @@ -26,13 +26,13 @@ func TestStartGRPCServer_ContextCancellation(t *testing.T) { } // Create minimal server instances - oidcmappingsrv := grpc.NewTrustMappingServer(nil) + trustsrv := grpc.NewTrustMappingServer(nil) sessionsrv := grpc.NewSessionServer(ctx, nil, nil, 0, "") // Start the server in a goroutine errChan := make(chan error, 1) go func() { - errChan <- StartGRPCServer(ctx, cfg, oidcmappingsrv, sessionsrv) + errChan <- StartGRPCServer(ctx, cfg, trustsrv, sessionsrv) }() // Give the server a moment to start diff --git a/internal/business/server/openapi.go b/internal/business/server/openapi.go index d9695139..4ea4b6db 100644 --- a/internal/business/server/openapi.go +++ b/internal/business/server/openapi.go @@ -16,8 +16,8 @@ import ( "github.com/openkcm/session-manager/internal/middleware" "github.com/openkcm/session-manager/internal/openapi" - "github.com/openkcm/session-manager/pkg/serviceerr" "github.com/openkcm/session-manager/internal/session" + "github.com/openkcm/session-manager/pkg/serviceerr" ) // sessionManager defines the interface for session management operations diff --git a/internal/business/server/openapi_test.go b/internal/business/server/openapi_test.go index 59e9b207..79d94706 100644 --- a/internal/business/server/openapi_test.go +++ b/internal/business/server/openapi_test.go @@ -15,8 +15,8 @@ import ( "github.com/openkcm/session-manager/internal/middleware" "github.com/openkcm/session-manager/internal/openapi" - "github.com/openkcm/session-manager/pkg/serviceerr" "github.com/openkcm/session-manager/internal/session" + "github.com/openkcm/session-manager/pkg/serviceerr" ) const ( diff --git a/internal/grpc/session.go b/internal/grpc/session.go index 10728e6b..a71ad1e9 100644 --- a/internal/grpc/session.go +++ b/internal/grpc/session.go @@ -113,8 +113,8 @@ func (s *SessionServer) GetSession(ctx context.Context, req *sessionv1.GetSessio trust, err := s.trust.Get(ctx, req.GetTenantId()) if err != nil { span.RecordError(err) - span.SetStatus(codes.Error, "failed to get an oidc mapping") - slogctx.Warn(ctx, "Is this an attack? Could not get trust mapping", "issuer", sess.Issuer, "error", err) + span.SetStatus(codes.Error, "failed to get trust") + slogctx.Warn(ctx, "Is this an attack? Could not get trust", "issuer", sess.Issuer, "error", err) return &sessionv1.GetSessionResponse{Valid: false}, nil } if trust.GetBlocked() { diff --git a/internal/grpc/session_test.go b/internal/grpc/session_test.go index 69335f25..04885d28 100644 --- a/internal/grpc/session_test.go +++ b/internal/grpc/session_test.go @@ -112,7 +112,7 @@ func TestGetSession(t *testing.T) { AuthContext: map[string]string{"key": "value"}, } - mapping := trustv1.Trust_builder{ + trustData := trustv1.Trust_builder{ TenantId: new(sess.TenantID), Blocked: new(false), Oidc: oidcv1.OIDC_builder{ @@ -126,7 +126,7 @@ func TestGetSession(t *testing.T) { // Mark session as active _ = sessionRepo.BumpActive(ctx, sess.ID, 1*time.Hour) - trustRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(mapping)) + trustRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(trustData)) trust := newTrust(trustRepo) server := grpc.NewSessionServer(ctx, sessionRepo, trust, 90*time.Minute, "", grpc.WithAllowHttpScheme(true), @@ -184,7 +184,7 @@ func TestGetSession(t *testing.T) { }, } - mapping := trustv1.Trust_builder{ + trustData := trustv1.Trust_builder{ TenantId: new(sess.TenantID), Blocked: new(false), Oidc: oidcv1.OIDC_builder{ @@ -197,7 +197,7 @@ func TestGetSession(t *testing.T) { ) _ = sessionRepo.BumpActive(ctx, sess.ID, 1*time.Hour) - trustRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(mapping)) + trustRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(trustData)) trust := newTrust(trustRepo) server := grpc.NewSessionServer(ctx, sessionRepo, trust, 90*time.Minute, "", grpc.WithAllowHttpScheme(true), @@ -241,7 +241,7 @@ func TestGetSession(t *testing.T) { }, } - mapping := trustv1.Trust_builder{ + trustData := trustv1.Trust_builder{ TenantId: new(sess.TenantID), Blocked: new(false), Oidc: oidcv1.OIDC_builder{ @@ -254,7 +254,7 @@ func TestGetSession(t *testing.T) { ) _ = sessionRepo.BumpActive(ctx, sess.ID, 1*time.Hour) - trustRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(mapping)) + trustRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(trustData)) trust := newTrust(trustRepo) server := grpc.NewSessionServer(ctx, sessionRepo, trust, 90*time.Minute, "", grpc.WithAllowHttpScheme(true), @@ -351,7 +351,7 @@ func TestGetSession(t *testing.T) { assert.False(t, resp.GetValid()) }) - t.Run("invalid - trust mapping not found", func(t *testing.T) { + t.Run("invalid - trust not found", func(t *testing.T) { sess := session.Session{ ID: "session-no-provider", TenantID: "tenant-no-provider", @@ -364,7 +364,7 @@ func TestGetSession(t *testing.T) { ) _ = sessionRepo.BumpActive(ctx, sess.ID, 1*time.Hour) - // No mapping added to repo + // No trust added to repo trustRepo := mocktrust.NewInMemRepository() trust := newTrust(trustRepo) server := grpc.NewSessionServer(ctx, sessionRepo, trust, 90*time.Minute, "") @@ -382,7 +382,7 @@ func TestGetSession(t *testing.T) { assert.False(t, resp.GetValid()) }) - t.Run("invalid - trust mapping is blocked", func(t *testing.T) { + t.Run("invalid - trust is blocked", func(t *testing.T) { sess := session.Session{ ID: "session-blocked", TenantID: "tenant-blocked", @@ -395,14 +395,14 @@ func TestGetSession(t *testing.T) { ) _ = sessionRepo.BumpActive(ctx, sess.ID, 1*time.Hour) - mapping := trustv1.Trust_builder{ + trustData := trustv1.Trust_builder{ TenantId: new(sess.TenantID), Blocked: new(true), Oidc: oidcv1.OIDC_builder{ Issuer: new("https://issuer.example.com"), }.Build(), }.Build() - trustRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(mapping)) + trustRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(trustData)) trust := newTrust(trustRepo) server := grpc.NewSessionServer(ctx, sessionRepo, trust, 90*time.Minute, "") @@ -445,14 +445,14 @@ func TestGetSession(t *testing.T) { ) _ = sessionRepo.BumpActive(ctx, sess.ID, 1*time.Hour) - mapping := trustv1.Trust_builder{ + trustData := trustv1.Trust_builder{ TenantId: new(sess.TenantID), Blocked: new(false), Oidc: oidcv1.OIDC_builder{ Issuer: new("https://issuer.example.com"), }.Build(), }.Build() - trustRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(mapping)) + trustRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(trustData)) trust := newTrust(trustRepo) server := grpc.NewSessionServer(ctx, sessionRepo, trust, 90*time.Minute, "") @@ -483,14 +483,14 @@ func TestGetSession(t *testing.T) { ) _ = sessionRepo.BumpActive(ctx, sess.ID, 1*time.Hour) - mapping := trustv1.Trust_builder{ + trustData := trustv1.Trust_builder{ TenantId: new("wrong-tenant"), Blocked: new(false), Oidc: oidcv1.OIDC_builder{ Issuer: new("https://issuer.example.com"), }.Build(), }.Build() - trustRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(mapping)) + trustRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(trustData)) trust := newTrust(trustRepo) server := grpc.NewSessionServer(ctx, sessionRepo, trust, 90*time.Minute, "") @@ -521,14 +521,14 @@ func TestGetSession(t *testing.T) { ) _ = sessionRepo.BumpActive(ctx, sess.ID, 1*time.Hour) - mapping := trustv1.Trust_builder{ + trustData := trustv1.Trust_builder{ TenantId: new(sess.TenantID), Blocked: new(false), Oidc: oidcv1.OIDC_builder{ Issuer: new("https://invalid-issuer-no-server.example.com"), }.Build(), }.Build() - trustRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(mapping)) + trustRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(trustData)) trust := newTrust(trustRepo) server := grpc.NewSessionServer(ctx, sessionRepo, trust, 90*time.Minute, "") @@ -575,14 +575,14 @@ func TestGetSession(t *testing.T) { ) _ = sessionRepo.BumpActive(ctx, sess.ID, 1*time.Hour) - mapping := trustv1.Trust_builder{ + trustData := trustv1.Trust_builder{ TenantId: new(sess.TenantID), Blocked: new(false), Oidc: oidcv1.OIDC_builder{ Issuer: new(testServer.URL), }.Build(), }.Build() - trustRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(mapping)) + trustRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(trustData)) trust := newTrust(trustRepo) server := grpc.NewSessionServer(ctx, sessionRepo, trust, 90*time.Minute, "", @@ -633,14 +633,14 @@ func TestGetSession(t *testing.T) { ) _ = sessionRepo.BumpActive(ctx, sess.ID, 1*time.Hour) - mapping := trustv1.Trust_builder{ + trustData := trustv1.Trust_builder{ TenantId: new(sess.TenantID), Blocked: new(false), Oidc: oidcv1.OIDC_builder{ Issuer: new(testServer.URL), }.Build(), }.Build() - trustRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(mapping)) + trustRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(trustData)) trust := newTrust(trustRepo) server := grpc.NewSessionServer(ctx, sessionRepo, trust, 90*time.Minute, "", @@ -684,14 +684,14 @@ func TestGetSession(t *testing.T) { ) _ = sessionRepo.BumpActive(ctx, sess.ID, 1*time.Hour) - mapping := trustv1.Trust_builder{ + trustData := trustv1.Trust_builder{ TenantId: new(sess.TenantID), Blocked: new(false), Oidc: oidcv1.OIDC_builder{ Issuer: new(testServer.URL), }.Build(), }.Build() - trustRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(mapping)) + trustRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(trustData)) trust := newTrust(trustRepo) server := grpc.NewSessionServer(ctx, sessionRepo, trust, 90*time.Minute, "", @@ -735,7 +735,7 @@ func TestGetOIDCProvider(t *testing.T) { ctx := t.Context() t.Run("success - returns OIDC provider", func(t *testing.T) { - mapping := trustv1.Trust_builder{ + trustData := trustv1.Trust_builder{ TenantId: new("tenant-123"), Blocked: new(false), Oidc: oidcv1.OIDC_builder{ @@ -744,7 +744,7 @@ func TestGetOIDCProvider(t *testing.T) { Audiences: []string{"audience1", "audience2"}, }.Build(), }.Build() - trustRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(mapping)) + trustRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(trustData)) trust := newTrust(trustRepo) sessionRepo := sessionmock.NewInMemRepository() diff --git a/internal/grpc/trust_test.go b/internal/grpc/trust_test.go index f17796ed..37f55e08 100644 --- a/internal/grpc/trust_test.go +++ b/internal/grpc/trust_test.go @@ -14,8 +14,8 @@ import ( trustv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/v1" "github.com/openkcm/session-manager/internal/grpc" - "github.com/openkcm/session-manager/pkg/serviceerr" mocktrust "github.com/openkcm/session-manager/modules/oidctrust/mocks" + "github.com/openkcm/session-manager/pkg/serviceerr" ) func TestNewTrustMappingServer(t *testing.T) { @@ -31,7 +31,7 @@ func TestNewTrustMappingServer(t *testing.T) { func TestApplyTrustMapping(t *testing.T) { ctx := t.Context() - t.Run("success - creates new mapping", func(t *testing.T) { + t.Run("success - creates new trust", func(t *testing.T) { repo := mocktrust.NewInMemRepository() svc := newTrust(repo) server := grpc.NewTrustMappingServer(svc) @@ -54,8 +54,8 @@ func TestApplyTrustMapping(t *testing.T) { assert.Empty(t, resp.GetMessage()) }) - t.Run("success - updates existing mapping", func(t *testing.T) { - existingMapping := trustv1.Trust_builder{ + t.Run("success - updates existing trust", func(t *testing.T) { + existingTrust := trustv1.Trust_builder{ TenantId: new("tenant-123"), Oidc: oidcv1.OIDC_builder{ Issuer: new("https://old-issuer.example.com"), @@ -64,7 +64,7 @@ func TestApplyTrustMapping(t *testing.T) { }.Build(), }.Build() repo := mocktrust.NewInMemRepository( - mocktrust.WithTrust(existingMapping), + mocktrust.WithTrust(existingTrust), ) svc := newTrust(repo) server := grpc.NewTrustMappingServer(svc) @@ -136,11 +136,11 @@ func TestApplyTrustMapping(t *testing.T) { st, ok := status.FromError(err) require.True(t, ok) assert.Equal(t, codes.Internal, st.Code()) - assert.Contains(t, st.Message(), "failed to apply Trust mapping") + assert.Contains(t, st.Message(), "failed to apply trust") }) t.Run("update error - returns grpc error", func(t *testing.T) { - existingMapping := trustv1.Trust_builder{ + existingTrust := trustv1.Trust_builder{ TenantId: new("tenant-123"), Oidc: oidcv1.OIDC_builder{ Issuer: new("https://issuer.example.com"), @@ -148,7 +148,7 @@ func TestApplyTrustMapping(t *testing.T) { }.Build() updateErr := errors.New("update failed") repo := mocktrust.NewInMemRepository( - mocktrust.WithTrust(existingMapping), + mocktrust.WithTrust(existingTrust), mocktrust.WithUpdateError(updateErr), ) svc := newTrust(repo) @@ -177,8 +177,8 @@ func TestApplyTrustMapping(t *testing.T) { func TestBlockTrustMapping(t *testing.T) { ctx := t.Context() - t.Run("success - blocks existing mapping", func(t *testing.T) { - existingMapping := trustv1.Trust_builder{ + t.Run("success - blocks existing trust", func(t *testing.T) { + existingTrust := trustv1.Trust_builder{ TenantId: new("tenant-123"), Blocked: new(false), Oidc: oidcv1.OIDC_builder{ @@ -186,7 +186,7 @@ func TestBlockTrustMapping(t *testing.T) { }.Build(), }.Build() repo := mocktrust.NewInMemRepository( - mocktrust.WithTrust(existingMapping), + mocktrust.WithTrust(existingTrust), ) svc := newTrust(repo) server := grpc.NewTrustMappingServer(svc) @@ -204,7 +204,7 @@ func TestBlockTrustMapping(t *testing.T) { }) t.Run("success - already blocked", func(t *testing.T) { - existingMapping := trustv1.Trust_builder{ + existingTrust := trustv1.Trust_builder{ TenantId: new("tenant-123"), Blocked: new(true), Oidc: oidcv1.OIDC_builder{ @@ -212,7 +212,7 @@ func TestBlockTrustMapping(t *testing.T) { }.Build(), }.Build() repo := mocktrust.NewInMemRepository( - mocktrust.WithTrust(existingMapping), + mocktrust.WithTrust(existingTrust), ) svc := newTrust(repo) server := grpc.NewTrustMappingServer(svc) @@ -268,22 +268,22 @@ func TestBlockTrustMapping(t *testing.T) { st, ok := status.FromError(err) require.True(t, ok) assert.Equal(t, codes.Internal, st.Code()) - assert.Contains(t, st.Message(), "failed to block Trust mapping") + assert.Contains(t, st.Message(), "failed to block trust") }) } func TestRemoveTrustMapping(t *testing.T) { ctx := t.Context() - t.Run("success - removes existing mapping", func(t *testing.T) { - existingMapping := trustv1.Trust_builder{ + t.Run("success - removes existing trust", func(t *testing.T) { + existingTrust := trustv1.Trust_builder{ TenantId: new("tenant-123"), Oidc: oidcv1.OIDC_builder{ Issuer: new("https://issuer.example.com"), }.Build(), }.Build() repo := mocktrust.NewInMemRepository( - mocktrust.WithTrust(existingMapping), + mocktrust.WithTrust(existingTrust), ) svc := newTrust(repo) server := grpc.NewTrustMappingServer(svc) @@ -322,7 +322,7 @@ func TestRemoveTrustMapping(t *testing.T) { st, ok := status.FromError(err) require.True(t, ok) assert.Equal(t, codes.Internal, st.Code()) - assert.Contains(t, st.Message(), "failed to remove Trust mapping") + assert.Contains(t, st.Message(), "failed to remove trust") }) t.Run("error - delete is indempotent", func(t *testing.T) { @@ -348,8 +348,8 @@ func TestRemoveTrustMapping(t *testing.T) { func TestUnblockTrustMapping(t *testing.T) { ctx := t.Context() - t.Run("success - unblocks blocked mapping", func(t *testing.T) { - existingMapping := trustv1.Trust_builder{ + t.Run("success - unblocks blocked trust", func(t *testing.T) { + existingTrust := trustv1.Trust_builder{ TenantId: new("tenant-123"), Blocked: new(true), Oidc: oidcv1.OIDC_builder{ @@ -357,7 +357,7 @@ func TestUnblockTrustMapping(t *testing.T) { }.Build(), }.Build() repo := mocktrust.NewInMemRepository( - mocktrust.WithTrust(existingMapping), + mocktrust.WithTrust(existingTrust), ) svc := newTrust(repo) server := grpc.NewTrustMappingServer(svc) @@ -375,7 +375,7 @@ func TestUnblockTrustMapping(t *testing.T) { }) t.Run("success - already unblocked", func(t *testing.T) { - existingMapping := trustv1.Trust_builder{ + existingTrust := trustv1.Trust_builder{ TenantId: new("tenant-123"), Blocked: new(false), Oidc: oidcv1.OIDC_builder{ @@ -383,7 +383,7 @@ func TestUnblockTrustMapping(t *testing.T) { }.Build(), }.Build() repo := mocktrust.NewInMemRepository( - mocktrust.WithTrust(existingMapping), + mocktrust.WithTrust(existingTrust), ) svc := newTrust(repo) server := grpc.NewTrustMappingServer(svc) @@ -419,7 +419,7 @@ func TestUnblockTrustMapping(t *testing.T) { t.Run("error - returns grpc error with message", func(t *testing.T) { internalErr := errors.New("update failed") - existingMapping := trustv1.Trust_builder{ + existingTrust := trustv1.Trust_builder{ TenantId: new("tenant-123"), Blocked: new(true), Oidc: oidcv1.OIDC_builder{ @@ -427,7 +427,7 @@ func TestUnblockTrustMapping(t *testing.T) { }.Build(), }.Build() repo := mocktrust.NewInMemRepository( - mocktrust.WithTrust(existingMapping), + mocktrust.WithTrust(existingTrust), mocktrust.WithUpdateError(internalErr), ) svc := newTrust(repo) @@ -447,6 +447,6 @@ func TestUnblockTrustMapping(t *testing.T) { st, ok := status.FromError(err) require.True(t, ok) assert.Equal(t, codes.Internal, st.Code()) - assert.Contains(t, st.Message(), "failed to unblock Trust mapping") + assert.Contains(t, st.Message(), "failed to unblock trust") }) } diff --git a/internal/grpc/trustmapping.go b/internal/grpc/trustmapping.go index 54b4a936..a691b4bf 100644 --- a/internal/grpc/trustmapping.go +++ b/internal/grpc/trustmapping.go @@ -57,15 +57,15 @@ func (srv *TrustMappingServer) ApplyTrustMapping(ctx context.Context, in *trustm response := trustmappingv1.ApplyTrustMappingResponse_builder{}.Build() - if err := srv.trust.ApplyMapping(ctx, trust); err != nil { - slogctx.Error(ctx, "Could not apply Trust mapping", "error", err) + if err := srv.trust.Apply(ctx, trust); err != nil { + slogctx.Error(ctx, "Could not apply trust", "error", err) if errors.Is(err, serviceerr.ErrNotFound) { msg := serviceerr.ErrNotFound.Error() response.SetMessage(msg) return response, nil } - return nil, status.Errorf(codes.Internal, "failed to apply Trust mapping: %v", err) + return nil, status.Errorf(codes.Internal, "failed to apply trust: %v", err) } response.SetSuccess(true) @@ -73,42 +73,42 @@ func (srv *TrustMappingServer) ApplyTrustMapping(ctx context.Context, in *trustm return response, nil } -// BlockTrustMapping blocks the Trust mapping for the specified tenant. -// It calls the underlying service to set the mapping as blocked. +// BlockTrustMapping blocks the trust for the specified tenant. +// It calls the underlying service to set the trust as blocked. // Returns a response containing an optional error message if blocking fails. func (srv *TrustMappingServer) BlockTrustMapping(ctx context.Context, req *trustmappingv1.BlockTrustMappingRequest) (*trustmappingv1.BlockTrustMappingResponse, error) { ctx = slogctx.With(ctx, "tenantId", req.GetTenantId()) slogctx.Debug(ctx, "BlockTrustMapping called") resp := trustmappingv1.BlockTrustMappingResponse_builder{}.Build() - err := srv.trust.BlockMapping(ctx, req.GetTenantId()) + err := srv.trust.Block(ctx, req.GetTenantId()) if err != nil { - slogctx.Error(ctx, "Could not block Trust mapping", "error", err) + slogctx.Error(ctx, "Could not block trust", "error", err) msg := err.Error() resp.SetMessage(msg) - return resp, status.Error(codes.Internal, "failed to block Trust mapping: "+msg) + return resp, status.Error(codes.Internal, "failed to block trust: "+msg) } resp.SetSuccess(true) return resp, nil } -// RemoveTrustMapping removes the Trust configuration for the tenant. -// It calls the underlying service to remove the mapping. +// RemoveTrustMapping removes the trust configuration for the tenant. +// It calls the underlying service to remove the trust. // Returns a respose containing an optional error message if removing fails. func (srv *TrustMappingServer) RemoveTrustMapping(ctx context.Context, req *trustmappingv1.RemoveTrustMappingRequest) (*trustmappingv1.RemoveTrustMappingResponse, error) { ctx = slogctx.With(ctx, "tenantId", req.GetTenantId()) slogctx.Debug(ctx, "RemoveTrustMapping called") resp := &trustmappingv1.RemoveTrustMappingResponse{} - err := srv.trust.RemoveMapping(ctx, req.GetTenantId()) + err := srv.trust.Remove(ctx, req.GetTenantId()) if err != nil { if !errors.Is(err, serviceerr.ErrNotFound) { - slogctx.Error(ctx, "Could not remove Trust mapping", "error", err) + slogctx.Error(ctx, "Could not remove trust", "error", err) msg := err.Error() resp.SetMessage(msg) - return resp, status.Error(codes.Internal, "failed to remove Trust mapping: "+msg) + return resp, status.Error(codes.Internal, "failed to remove trust: "+msg) } else { slogctx.Warn(ctx, "RemoveTrustMapping is called but the tenant does not exist", "error", err) } @@ -118,20 +118,20 @@ func (srv *TrustMappingServer) RemoveTrustMapping(ctx context.Context, req *trus return resp, nil } -// UnblockTrustMapping unblocks the Trust mapping for the specified tenant. -// It calls the underlying service to set the mapping as unblocked. +// UnblockTrustMapping unblocks the trust for the specified tenant. +// It calls the underlying service to set the trust as unblocked. // Returns a response containing an optional error message if unblocking fails. func (srv *TrustMappingServer) UnblockTrustMapping(ctx context.Context, req *trustmappingv1.UnblockTrustMappingRequest) (*trustmappingv1.UnblockTrustMappingResponse, error) { ctx = slogctx.With(ctx, "tenantId", req.GetTenantId()) slogctx.Debug(ctx, "UnblockTrustMapping called") resp := &trustmappingv1.UnblockTrustMappingResponse{} - err := srv.trust.UnblockMapping(ctx, req.GetTenantId()) + err := srv.trust.Unblock(ctx, req.GetTenantId()) if err != nil { - slogctx.Error(ctx, "Could not unblock Trust mapping", "error", err) + slogctx.Error(ctx, "Could not unblock trust", "error", err) msg := err.Error() resp.SetMessage(msg) - return resp, status.Error(codes.Internal, "failed to unblock Trust mapping: "+msg) + return resp, status.Error(codes.Internal, "failed to unblock trust: "+msg) } resp.SetSuccess(true) diff --git a/internal/session/housekeeper.go b/internal/session/housekeeper.go index 1d0928a1..dde2297d 100644 --- a/internal/session/housekeeper.go +++ b/internal/session/housekeeper.go @@ -98,7 +98,7 @@ func (m *Manager) housekeepSession(ctx context.Context, s Session, refreshTrigge func (m *Manager) refreshAccessToken(ctx context.Context, s Session) error { trust, err := m.trust.Get(ctx, s.TenantID) if err != nil { - return fmt.Errorf("could not get trust mapping: %w", err) + return fmt.Errorf("could not get trust: %w", err) } oidc := trust.GetOidc() diff --git a/internal/session/housekeeper_test.go b/internal/session/housekeeper_test.go index 4367d445..9ee88f28 100644 --- a/internal/session/housekeeper_test.go +++ b/internal/session/housekeeper_test.go @@ -15,10 +15,10 @@ import ( trustv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/v1" "github.com/openkcm/session-manager/internal/config" - "github.com/openkcm/session-manager/pkg/serviceerr" "github.com/openkcm/session-manager/internal/session" sessionmock "github.com/openkcm/session-manager/internal/session/mock" mocktrust "github.com/openkcm/session-manager/modules/oidctrust/mocks" + "github.com/openkcm/session-manager/pkg/serviceerr" ) func TestDeleteIdleSessions(t *testing.T) { @@ -106,14 +106,14 @@ func TestRefreshAccessToken(t *testing.T) { defer tokenServer.Close() tokenServerURL = tokenServer.URL + "/token" - mapping := trustv1.Trust_builder{ + trustData := trustv1.Trust_builder{ TenantId: new(tenantID), Oidc: oidcv1.OIDC_builder{ Issuer: new(discoveryServerURL), }.Build(), }.Build() - oidcRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(mapping)) + oidcRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(trustData)) trust := newTrust(oidcRepo) sess := session.Session{ @@ -156,7 +156,7 @@ func TestRefreshAccessToken(t *testing.T) { assert.Equal(t, "new-refresh-token", updatedSess.RefreshToken) }) - t.Run("Error - trust mapping not found", func(t *testing.T) { + t.Run("Error - trust not found", func(t *testing.T) { oidcRepo := mocktrust.NewInMemRepository() trust := newTrust(oidcRepo) @@ -208,14 +208,14 @@ func TestRefreshAccessToken(t *testing.T) { defer tokenServer.Close() tokenServerURL = tokenServer.URL + "/token" - mapping := trustv1.Trust_builder{ + trustData := trustv1.Trust_builder{ TenantId: new(tenantID), Oidc: oidcv1.OIDC_builder{ Issuer: new(discoveryServerURL), }.Build(), }.Build() - oidcRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(mapping)) + oidcRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(trustData)) sess := session.Session{ ID: sessionID, @@ -268,14 +268,14 @@ func TestRefreshAccessToken(t *testing.T) { defer discoveryServer.Close() discoveryServerURL = discoveryServer.URL - mapping := trustv1.Trust_builder{ + trustData := trustv1.Trust_builder{ TenantId: new(tenantID), Oidc: oidcv1.OIDC_builder{ Issuer: new(discoveryServer.URL), }.Build(), }.Build() - oidcRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(mapping)) + oidcRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(trustData)) sess := session.Session{ ID: sessionID, @@ -459,14 +459,14 @@ func TestHousekeepSession_ErrorCases(t *testing.T) { })) defer tokenServer.Close() - mapping := trustv1.Trust_builder{ + trustData := trustv1.Trust_builder{ TenantId: new(tenantID), Oidc: oidcv1.OIDC_builder{ Issuer: new(discoveryServerURL), }.Build(), }.Build() - oidcRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(mapping)) + oidcRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(trustData)) trust := newTrust(oidcRepo) sess := session.Session{ diff --git a/internal/session/manager.go b/internal/session/manager.go index 17d5eb60..455696d7 100644 --- a/internal/session/manager.go +++ b/internal/session/manager.go @@ -114,7 +114,7 @@ func NewManager( func (m *Manager) MakeAuthURI(ctx context.Context, tenantID, fingerprint, requestURI string) (string, string, error) { trust, err := m.trust.Get(ctx, tenantID) if err != nil { - return "", "", fmt.Errorf("getting trust mapping: %w", err) + return "", "", fmt.Errorf("getting trust: %w", err) } oidc := trust.GetOidc() @@ -228,8 +228,8 @@ func (m *Manager) FinaliseOIDCLogin(ctx context.Context, stateID, code, fingerpr trust, err := m.trust.Get(ctx, state.TenantID) if err != nil { - m.sendUserLoginFailureAudit(ctx, metadata, state.TenantID, "failed to get trust mapping") - return OIDCSessionData{}, fmt.Errorf("getting trust mapping: %w", err) + m.sendUserLoginFailureAudit(ctx, metadata, state.TenantID, "failed to get trust") + return OIDCSessionData{}, fmt.Errorf("getting trust: %w", err) } oidc := trust.GetOidc() @@ -375,8 +375,8 @@ func (m *Manager) Logout(ctx context.Context, sessionID, postLogoutRedirectURL s trust, err := m.trust.Get(ctx, session.TenantID) if err != nil { - slogctx.Error(ctx, "failed to get trust mapping for a tenant", "error", err) - return "", fmt.Errorf("getting trust mapping: %w", err) + slogctx.Error(ctx, "failed to get trust for a tenant", "error", err) + return "", fmt.Errorf("getting trust: %w", err) } oidc := trust.GetOidc() @@ -471,7 +471,7 @@ func (m *Manager) BCLogout(ctx context.Context, logoutJWT string) error { trust, err := m.trust.Get(ctx, session.TenantID) if err != nil { - return fmt.Errorf("getting trust mapping: %w", err) + return fmt.Errorf("getting trust: %w", err) } oidc := trust.GetOidc() diff --git a/internal/session/manager_test.go b/internal/session/manager_test.go index ab79edad..f37866b6 100644 --- a/internal/session/manager_test.go +++ b/internal/session/manager_test.go @@ -50,7 +50,7 @@ func TestManager_Auth(t *testing.T) { auditServer := StartAuditServer(t) defer auditServer.Close() - oidcMapping := trustv1.Trust_builder{ + oidcTrust := trustv1.Trust_builder{ TenantId: new(tenantID), Blocked: new(false), Oidc: oidcv1.OIDC_builder{ @@ -70,11 +70,11 @@ func TestManager_Auth(t *testing.T) { fingerprint string wantURL string errAssert assert.ErrorAssertionFunc - mapping *trustv1.Trust + trust *trustv1.Trust }{ { name: "Success", - oidc: mocktrust.NewInMemRepository(mocktrust.WithTrust(oidcMapping)), + oidc: mocktrust.NewInMemRepository(mocktrust.WithTrust(oidcTrust)), sessions: sessionmock.NewInMemRepository(), requestURI: requestURI, cfg: &config.SessionManager{ @@ -91,10 +91,10 @@ func TestManager_Auth(t *testing.T) { errAssert: assert.NoError, }, { - name: "Get trust mapping error", + name: "Get trust error", oidc: mocktrust.NewInMemRepository( - mocktrust.WithTrust(oidcMapping), - mocktrust.WithGetError(errors.New("failed to get trust mapping")), + mocktrust.WithTrust(oidcTrust), + mocktrust.WithGetError(errors.New("failed to get trust")), ), sessions: sessionmock.NewInMemRepository(), requestURI: requestURI, @@ -110,7 +110,7 @@ func TestManager_Auth(t *testing.T) { }, { name: "Save state error", - oidc: mocktrust.NewInMemRepository(mocktrust.WithTrust(oidcMapping)), + oidc: mocktrust.NewInMemRepository(mocktrust.WithTrust(oidcTrust)), sessions: sessionmock.NewInMemRepository(sessionmock.WithStoreStateError(errors.New("failed to save state"))), requestURI: requestURI, cfg: &config.SessionManager{ @@ -155,7 +155,7 @@ func TestManager_Auth(t *testing.T) { } // Validate that the data has been inserted into the repository - assert.Equal(t, oidcMapping, tt.oidc.TGet(tt.tenantID), "Trust mapping has not been inserted") + assert.Equal(t, oidcTrust, tt.oidc.TGet(tt.tenantID), "Trust has not been inserted") // Check the returned URL u, err := url.Parse(got) @@ -307,8 +307,8 @@ func TestManager_FinaliseOIDCLogin(t *testing.T) { errAssert: assert.Error, }, { - name: "Trust mapping get error", - oidc: mocktrust.NewInMemRepository(mocktrust.WithGetError(errors.New("trust mapping not found"))), + name: "Trust get error", + oidc: mocktrust.NewInMemRepository(mocktrust.WithGetError(errors.New("trust not found"))), sessions: sessionmock.NewInMemRepository(sessionmock.WithState(validState)), stateID: stateID, code: code, @@ -354,7 +354,7 @@ func TestManager_FinaliseOIDCLogin(t *testing.T) { jwksURI, err := url.JoinPath(oidcServer.URL, "/.well-known/jwks.json") require.NoError(t, err) - localOIDCMapping := trustv1.Trust_builder{ + localOIDCTrust := trustv1.Trust_builder{ TenantId: new(tenantID), Blocked: new(false), Oidc: oidcv1.OIDC_builder{ @@ -364,7 +364,7 @@ func TestManager_FinaliseOIDCLogin(t *testing.T) { }.Build(), }.Build() - tt.oidc.TAdd(localOIDCMapping) + tt.oidc.TAdd(localOIDCTrust) m, err := session.NewManager(ctx, tt.cfg, @@ -562,7 +562,7 @@ func TestManager_LogoutEdgeCases(t *testing.T) { errAssert: assert.Error, }, { - name: "Trust mapping not found", + name: "Trust not found", sessionID: sessionID, setupMock: func(oidcs *mocktrust.Repository, sessions *sessionmock.Repository) { _ = sessions.StoreSession(context.Background(), session.Session{ diff --git a/internal/session/mock/repository.go b/internal/session/mock/repository.go index 6d5989c3..459dc75e 100644 --- a/internal/session/mock/repository.go +++ b/internal/session/mock/repository.go @@ -4,8 +4,8 @@ import ( "context" "time" - "github.com/openkcm/session-manager/pkg/serviceerr" "github.com/openkcm/session-manager/internal/session" + "github.com/openkcm/session-manager/pkg/serviceerr" ) type RepositoryOption func(*Repository) diff --git a/internal/session/valkey/repository.go b/internal/session/valkey/repository.go index 87fa3e2f..92050b35 100644 --- a/internal/session/valkey/repository.go +++ b/internal/session/valkey/repository.go @@ -10,8 +10,8 @@ import ( slogctx "github.com/veqryn/slog-context" - "github.com/openkcm/session-manager/pkg/serviceerr" "github.com/openkcm/session-manager/internal/session" + "github.com/openkcm/session-manager/pkg/serviceerr" ) type ObjectType string diff --git a/modules/oidctrust/internal/sql/queries.sql b/modules/oidctrust/internal/sql/queries.sql index 5fdc1244..a50560fe 100644 --- a/modules/oidctrust/internal/sql/queries.sql +++ b/modules/oidctrust/internal/sql/queries.sql @@ -1,4 +1,4 @@ --- name: GetOIDCMapping :one +-- name: GetTrust :one SELECT issuer, blocked, @@ -8,7 +8,7 @@ SELECT FROM trust WHERE tenant_id = sqlc.arg(tenant_id); --- name: CreateOIDCMapping :exec +-- name: CreateTrust :exec INSERT INTO trust ( tenant_id, blocked, @@ -24,11 +24,11 @@ VALUES ( COALESCE(sqlc.arg(audiences)::text[], '{}'::text[]), sqlc.arg(client_id)); --- name: DeleteOIDCMapping :execrows +-- name: DeleteTrust :execrows DELETE FROM trust WHERE tenant_id = sqlc.arg(tenant_id); --- name: UpdateOIDCMapping :execrows +-- name: UpdateTrust :execrows UPDATE trust SET blocked = sqlc.arg(blocked), diff --git a/modules/oidctrust/internal/sql/queries/db.go b/modules/oidctrust/internal/sql/queries/db.go index c69f0c53..2b5c1c72 100644 --- a/modules/oidctrust/internal/sql/queries/db.go +++ b/modules/oidctrust/internal/sql/queries/db.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.31.1 +// sqlc v1.30.0 package queries diff --git a/modules/oidctrust/internal/sql/queries/queries.sql.go b/modules/oidctrust/internal/sql/queries/queries.sql.go index 7cea03e6..ec9a845c 100644 --- a/modules/oidctrust/internal/sql/queries/queries.sql.go +++ b/modules/oidctrust/internal/sql/queries/queries.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.31.1 +// sqlc v1.30.0 // source: queries.sql package queries @@ -11,7 +11,7 @@ import ( "github.com/jackc/pgx/v5/pgtype" ) -const createOIDCMapping = `-- name: CreateOIDCMapping :exec +const createTrust = `-- name: CreateTrust :exec INSERT INTO trust ( tenant_id, blocked, @@ -28,7 +28,7 @@ VALUES ( $6) ` -type CreateOIDCMappingParams struct { +type CreateTrustParams struct { TenantID string `db:"tenant_id"` Blocked bool `db:"blocked"` Issuer string `db:"issuer"` @@ -37,8 +37,8 @@ type CreateOIDCMappingParams struct { ClientID pgtype.Text `db:"client_id"` } -func (q *Queries) CreateOIDCMapping(ctx context.Context, arg CreateOIDCMappingParams) error { - _, err := q.db.Exec(ctx, createOIDCMapping, +func (q *Queries) CreateTrust(ctx context.Context, arg CreateTrustParams) error { + _, err := q.db.Exec(ctx, createTrust, arg.TenantID, arg.Blocked, arg.Issuer, @@ -49,20 +49,20 @@ func (q *Queries) CreateOIDCMapping(ctx context.Context, arg CreateOIDCMappingPa return err } -const deleteOIDCMapping = `-- name: DeleteOIDCMapping :execrows +const deleteTrust = `-- name: DeleteTrust :execrows DELETE FROM trust WHERE tenant_id = $1 ` -func (q *Queries) DeleteOIDCMapping(ctx context.Context, tenantID string) (int64, error) { - result, err := q.db.Exec(ctx, deleteOIDCMapping, tenantID) +func (q *Queries) DeleteTrust(ctx context.Context, tenantID string) (int64, error) { + result, err := q.db.Exec(ctx, deleteTrust, tenantID) if err != nil { return 0, err } return result.RowsAffected(), nil } -const getOIDCMapping = `-- name: GetOIDCMapping :one +const getTrust = `-- name: GetTrust :one SELECT issuer, blocked, @@ -73,7 +73,7 @@ FROM trust WHERE tenant_id = $1 ` -type GetOIDCMappingRow struct { +type GetTrustRow struct { Issuer string `db:"issuer"` Blocked bool `db:"blocked"` JwksUri string `db:"jwks_uri"` @@ -81,9 +81,9 @@ type GetOIDCMappingRow struct { ClientID pgtype.Text `db:"client_id"` } -func (q *Queries) GetOIDCMapping(ctx context.Context, tenantID string) (GetOIDCMappingRow, error) { - row := q.db.QueryRow(ctx, getOIDCMapping, tenantID) - var i GetOIDCMappingRow +func (q *Queries) GetTrust(ctx context.Context, tenantID string) (GetTrustRow, error) { + row := q.db.QueryRow(ctx, getTrust, tenantID) + var i GetTrustRow err := row.Scan( &i.Issuer, &i.Blocked, @@ -94,7 +94,7 @@ func (q *Queries) GetOIDCMapping(ctx context.Context, tenantID string) (GetOIDCM return i, err } -const updateOIDCMapping = `-- name: UpdateOIDCMapping :execrows +const updateTrust = `-- name: UpdateTrust :execrows UPDATE trust SET blocked = $1, @@ -106,7 +106,7 @@ WHERE tenant_id = $6 ` -type UpdateOIDCMappingParams struct { +type UpdateTrustParams struct { Blocked bool `db:"blocked"` Issuer string `db:"issuer"` JwksUri string `db:"jwks_uri"` @@ -115,8 +115,8 @@ type UpdateOIDCMappingParams struct { TenantID string `db:"tenant_id"` } -func (q *Queries) UpdateOIDCMapping(ctx context.Context, arg UpdateOIDCMappingParams) (int64, error) { - result, err := q.db.Exec(ctx, updateOIDCMapping, +func (q *Queries) UpdateTrust(ctx context.Context, arg UpdateTrustParams) (int64, error) { + result, err := q.db.Exec(ctx, updateTrust, arg.Blocked, arg.Issuer, arg.JwksUri, diff --git a/modules/oidctrust/internal/sql/sql.go b/modules/oidctrust/internal/sql/sql.go index 8a2a184a..1e98da60 100644 --- a/modules/oidctrust/internal/sql/sql.go +++ b/modules/oidctrust/internal/sql/sql.go @@ -14,8 +14,8 @@ import ( trustv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/v1" sessionmanager "github.com/openkcm/session-manager" - "github.com/openkcm/session-manager/pkg/serviceerr" "github.com/openkcm/session-manager/modules/oidctrust/internal/sql/queries" + "github.com/openkcm/session-manager/pkg/serviceerr" ) type Repository struct { @@ -30,10 +30,10 @@ func NewRepository(db sessionmanager.Database) *Repository { func (r *Repository) Get(ctx context.Context, tenantID string) (*trustv1.Trust, error) { tracer := otel.GetTracerProvider() - ctx, span := tracer.Tracer("").Start(ctx, "get_oidc_mapping_sql") + ctx, span := tracer.Tracer("").Start(ctx, "get_trust_sql") defer span.End() - row, err := r.queries.GetOIDCMapping(ctx, tenantID) + row, err := r.queries.GetTrust(ctx, tenantID) if err != nil { span.RecordError(err) if errors.Is(err, pgx.ErrNoRows) { @@ -68,12 +68,12 @@ func (r *Repository) Get(ctx context.Context, tenantID string) (*trustv1.Trust, func (r *Repository) Create(ctx context.Context, trust *trustv1.Trust) error { tracer := otel.GetTracerProvider() - ctx, span := tracer.Tracer("").Start(ctx, "create_oidc_mapping_sql") + ctx, span := tracer.Tracer("").Start(ctx, "create_trust_sql") defer span.End() oidc := trust.GetOidc() - if err := r.queries.CreateOIDCMapping(ctx, queries.CreateOIDCMappingParams{ + if err := r.queries.CreateTrust(ctx, queries.CreateTrustParams{ TenantID: trust.GetTenantId(), Blocked: trust.GetBlocked(), Issuer: oidc.GetIssuer(), @@ -94,10 +94,10 @@ func (r *Repository) Create(ctx context.Context, trust *trustv1.Trust) error { func (r *Repository) Delete(ctx context.Context, tenantID string) error { tracer := otel.GetTracerProvider() - ctx, span := tracer.Tracer("").Start(ctx, "delete_oidc_mapping_sql") + ctx, span := tracer.Tracer("").Start(ctx, "delete_trust_sql") defer span.End() - affected, err := r.queries.DeleteOIDCMapping(ctx, tenantID) + affected, err := r.queries.DeleteTrust(ctx, tenantID) if err != nil { span.RecordError(err) return fmt.Errorf("executing sql query: %w", err) @@ -112,12 +112,12 @@ func (r *Repository) Delete(ctx context.Context, tenantID string) error { func (r *Repository) Update(ctx context.Context, trust *trustv1.Trust) error { tracer := otel.GetTracerProvider() - ctx, span := tracer.Tracer("").Start(ctx, "update_oidc_mapping_sql") + ctx, span := tracer.Tracer("").Start(ctx, "update_trust_sql") defer span.End() oidc := trust.GetOidc() - affected, err := r.queries.UpdateOIDCMapping(ctx, queries.UpdateOIDCMappingParams{ + affected, err := r.queries.UpdateTrust(ctx, queries.UpdateTrustParams{ Blocked: trust.GetBlocked(), Issuer: oidc.GetIssuer(), JwksUri: oidc.GetJwksUri(), diff --git a/modules/oidctrust/internal/sql/sql_test.go b/modules/oidctrust/internal/sql/sql_test.go index 7742818d..45650388 100644 --- a/modules/oidctrust/internal/sql/sql_test.go +++ b/modules/oidctrust/internal/sql/sql_test.go @@ -22,8 +22,8 @@ import ( sessionmanager "github.com/openkcm/session-manager" "github.com/openkcm/session-manager/internal/dbtest/postgrestest" - "github.com/openkcm/session-manager/pkg/serviceerr" sqltrust "github.com/openkcm/session-manager/modules/oidctrust/internal/sql" + "github.com/openkcm/session-manager/pkg/serviceerr" ) var dbPool sessionmanager.Database @@ -50,16 +50,16 @@ func TestMain(m *testing.M) { func TestRepository_Get(t *testing.T) { tests := []struct { - name string - tenantID string - wantMapping *trustv1.Trust - assertErr assert.ErrorAssertionFunc + name string + tenantID string + wantTrust *trustv1.Trust + assertErr assert.ErrorAssertionFunc }{ { - name: "Success", - tenantID: "tenant1-id", - wantMapping: trustv1.Trust_builder{TenantId: new("tenant1-id"), Blocked: new(false), Oidc: oidcv1.OIDC_builder{Issuer: new("url-one"), Audiences: make([]string, 0)}.Build()}.Build(), - assertErr: assert.NoError, + name: "Success", + tenantID: "tenant1-id", + wantTrust: trustv1.Trust_builder{TenantId: new("tenant1-id"), Blocked: new(false), Oidc: oidcv1.OIDC_builder{Issuer: new("url-one"), Audiences: make([]string, 0)}.Build()}.Build(), + assertErr: assert.NoError, }, { name: "Error does not exist", @@ -71,14 +71,14 @@ func TestRepository_Get(t *testing.T) { t.Run(tt.name, func(t *testing.T) { r := sqltrust.NewRepository(dbPool) - gotMapping, err := r.Get(t.Context(), tt.tenantID) + gotTrust, err := r.Get(t.Context(), tt.tenantID) if !tt.assertErr(t, err, fmt.Sprintf("Repository.Get() error %v", err)) || err != nil { - assert.Zerof(t, gotMapping, "Repository.Get() extected zero value if an error is returned, got %v", gotMapping) + assert.Zerof(t, gotTrust, "Repository.Get() extected zero value if an error is returned, got %v", gotTrust) return } - if diff := cmp.Diff(tt.wantMapping, gotMapping, protocmp.Transform()); diff != "" { - t.Fatalf("mapping not equal:\n%s", diff) + if diff := cmp.Diff(tt.wantTrust, gotTrust, protocmp.Transform()); diff != "" { + t.Fatalf("trust not equal:\n%s", diff) } }) } @@ -87,32 +87,32 @@ func TestRepository_Get(t *testing.T) { func TestRepository_Create(t *testing.T) { tests := []struct { name string - mapping *trustv1.Trust + trust *trustv1.Trust assertErr assert.ErrorAssertionFunc }{ { name: "Create succeeds", - mapping: trustv1.Trust_builder{TenantId: new("tenant-id-create-success"), Blocked: new(false), Oidc: oidcv1.OIDC_builder{Issuer: new("http://oidc-success.example.com"), JwksUri: new("jwks.example.com"), Audiences: []string{"cmk.example.com"}}.Build()}.Build(), + trust: trustv1.Trust_builder{TenantId: new("tenant-id-create-success"), Blocked: new(false), Oidc: oidcv1.OIDC_builder{Issuer: new("http://oidc-success.example.com"), JwksUri: new("jwks.example.com"), Audiences: []string{"cmk.example.com"}}.Build()}.Build(), assertErr: assert.NoError, }, { name: "Duplicate", - mapping: trustv1.Trust_builder{TenantId: new("tenant1-id"), Blocked: new(false), Oidc: oidcv1.OIDC_builder{Issuer: new("url-one"), JwksUri: new("jwks.example.com"), Audiences: []string{"cmk.example.com"}}.Build()}.Build(), + trust: trustv1.Trust_builder{TenantId: new("tenant1-id"), Blocked: new(false), Oidc: oidcv1.OIDC_builder{Issuer: new("url-one"), JwksUri: new("jwks.example.com"), Audiences: []string{"cmk.example.com"}}.Build()}.Build(), assertErr: assert.Error, }, { name: "Create without JWKSURI and Audiences succeeds", - mapping: trustv1.Trust_builder{TenantId: new("tenant-id-create-without-jwks-aud-success"), Blocked: new(false), Oidc: oidcv1.OIDC_builder{Issuer: new("http://oidc-success-2.example.com"), Audiences: []string{}}.Build()}.Build(), + trust: trustv1.Trust_builder{TenantId: new("tenant-id-create-without-jwks-aud-success"), Blocked: new(false), Oidc: oidcv1.OIDC_builder{Issuer: new("http://oidc-success-2.example.com"), Audiences: []string{}}.Build()}.Build(), assertErr: assert.NoError, }, { name: "Create without JWKSURI succeeds", - mapping: trustv1.Trust_builder{TenantId: new("tenant-id-create-without-jwks-success"), Blocked: new(false), Oidc: oidcv1.OIDC_builder{Issuer: new("http://oidc-success-3.example.com"), Audiences: []string{"cmk.example.com"}}.Build()}.Build(), + trust: trustv1.Trust_builder{TenantId: new("tenant-id-create-without-jwks-success"), Blocked: new(false), Oidc: oidcv1.OIDC_builder{Issuer: new("http://oidc-success-3.example.com"), Audiences: []string{"cmk.example.com"}}.Build()}.Build(), assertErr: assert.NoError, }, { name: "Create without Audiences succeeds", - mapping: trustv1.Trust_builder{TenantId: new("tenant-id-create-without-aud-success"), Blocked: new(false), Oidc: oidcv1.OIDC_builder{Issuer: new("http://oidc-success-4.example.com"), JwksUri: new("jwks.example.com"), Audiences: []string{}}.Build()}.Build(), + trust: trustv1.Trust_builder{TenantId: new("tenant-id-create-without-aud-success"), Blocked: new(false), Oidc: oidcv1.OIDC_builder{Issuer: new("http://oidc-success-4.example.com"), JwksUri: new("jwks.example.com"), Audiences: []string{}}.Build()}.Build(), assertErr: assert.NoError, }, } @@ -122,17 +122,17 @@ func TestRepository_Create(t *testing.T) { r := sqltrust.NewRepository(dbPool) // When - err := r.Create(t.Context(), tt.mapping) + err := r.Create(t.Context(), tt.trust) if !tt.assertErr(t, err, fmt.Sprintf("Repository.Create() error %v", err)) || err != nil { return } // Then - mapping, err := r.Get(t.Context(), tt.mapping.GetTenantId()) + gotCreated, err := r.Get(t.Context(), tt.trust.GetTenantId()) require.NoError(t, err) - if diff := cmp.Diff(tt.mapping, mapping, protocmp.Transform()); diff != "" { - t.Fatalf("Unexpected mapping in the database (-want, +got):\n%s", diff) + if diff := cmp.Diff(tt.trust, gotCreated, protocmp.Transform()); diff != "" { + t.Fatalf("Unexpected trust in the database (-want, +got):\n%s", diff) } }) } @@ -140,9 +140,9 @@ func TestRepository_Create(t *testing.T) { func TestRepository_Delete(t *testing.T) { const tenantID = "tenant-id-delete-success" - mapping := trustv1.Trust_builder{TenantId: new(tenantID), Blocked: new(false), Oidc: oidcv1.OIDC_builder{Issuer: new("http://oidc-to-delete.example.com"), JwksUri: new("jwks.example.com"), Audiences: []string{"cmk.example.com"}}.Build()}.Build() + trust := trustv1.Trust_builder{TenantId: new(tenantID), Blocked: new(false), Oidc: oidcv1.OIDC_builder{Issuer: new("http://oidc-to-delete.example.com"), JwksUri: new("jwks.example.com"), Audiences: []string{"cmk.example.com"}}.Build()}.Build() r := sqltrust.NewRepository(dbPool) - err := r.Create(t.Context(), mapping) + err := r.Create(t.Context(), trust) require.NoError(t, err, "Inserting test data") tests := []struct { @@ -168,65 +168,65 @@ func TestRepository_Delete(t *testing.T) { return } - gotMapping, err := r.Get(t.Context(), tt.tenantID) + gotTrust, err := r.Get(t.Context(), tt.tenantID) if !errors.Is(err, serviceerr.ErrNotFound) { - t.Error("The mapping is expected to be deleted") + t.Error("The trust is expected to be deleted") } - assert.Zero(t, gotMapping, "The mapping is expected to be deleted, instead a value is returned") + assert.Zero(t, gotTrust, "The trust is expected to be deleted, instead a value is returned") }) } } func TestRepository_Update(t *testing.T) { const tenantID = "tenant-id-update-success" - mapping := trustv1.Trust_builder{TenantId: new(tenantID), Blocked: new(false), Oidc: oidcv1.OIDC_builder{Issuer: new("http://oidc-to-update.example.com"), JwksUri: new("jwks.example.com"), Audiences: []string{"cmk.example.com"}}.Build()}.Build() + trust := trustv1.Trust_builder{TenantId: new(tenantID), Blocked: new(false), Oidc: oidcv1.OIDC_builder{Issuer: new("http://oidc-to-update.example.com"), JwksUri: new("jwks.example.com"), Audiences: []string{"cmk.example.com"}}.Build()}.Build() r := sqltrust.NewRepository(dbPool) - err := r.Create(t.Context(), mapping) + err := r.Create(t.Context(), trust) require.NoError(t, err, "Inserting test data") tests := []struct { name string - mapping *trustv1.Trust + trust *trustv1.Trust assertErr assert.ErrorAssertionFunc }{ { name: "Update succeeds", - mapping: trustv1.Trust_builder{TenantId: new(tenantID), Blocked: new(true), Oidc: oidcv1.OIDC_builder{Issuer: new(mapping.GetOidc().GetIssuer()), JwksUri: new("jwks-updated.example.com"), Audiences: mapping.GetOidc().GetAudiences()}.Build()}.Build(), + trust: trustv1.Trust_builder{TenantId: new(tenantID), Blocked: new(true), Oidc: oidcv1.OIDC_builder{Issuer: new(trust.GetOidc().GetIssuer()), JwksUri: new("jwks-updated.example.com"), Audiences: trust.GetOidc().GetAudiences()}.Build()}.Build(), assertErr: assert.NoError, }, { name: "Does not exist", - mapping: trustv1.Trust_builder{TenantId: new("does-not-exist"), Blocked: new(true), Oidc: oidcv1.OIDC_builder{Issuer: new("does-not-exist"), JwksUri: new("jwks-updated.example.com"), Audiences: mapping.GetOidc().GetAudiences()}.Build()}.Build(), + trust: trustv1.Trust_builder{TenantId: new("does-not-exist"), Blocked: new(true), Oidc: oidcv1.OIDC_builder{Issuer: new("does-not-exist"), JwksUri: new("jwks-updated.example.com"), Audiences: trust.GetOidc().GetAudiences()}.Build()}.Build(), assertErr: assert.Error, }, { name: "Update without JWKSURI and Audiences succeeds", - mapping: trustv1.Trust_builder{TenantId: new(tenantID), Blocked: new(true), Oidc: oidcv1.OIDC_builder{Issuer: new(mapping.GetOidc().GetIssuer()), Audiences: []string{}}.Build()}.Build(), + trust: trustv1.Trust_builder{TenantId: new(tenantID), Blocked: new(true), Oidc: oidcv1.OIDC_builder{Issuer: new(trust.GetOidc().GetIssuer()), Audiences: []string{}}.Build()}.Build(), assertErr: assert.NoError, }, { name: "Update without JWKSURI succeeds", - mapping: trustv1.Trust_builder{TenantId: new(tenantID), Blocked: new(true), Oidc: oidcv1.OIDC_builder{Issuer: new(mapping.GetOidc().GetIssuer()), Audiences: mapping.GetOidc().GetAudiences()}.Build()}.Build(), + trust: trustv1.Trust_builder{TenantId: new(tenantID), Blocked: new(true), Oidc: oidcv1.OIDC_builder{Issuer: new(trust.GetOidc().GetIssuer()), Audiences: trust.GetOidc().GetAudiences()}.Build()}.Build(), assertErr: assert.NoError, }, { name: "Update without Audiences succeeds", - mapping: trustv1.Trust_builder{TenantId: new(tenantID), Blocked: new(true), Oidc: oidcv1.OIDC_builder{Issuer: new(mapping.GetOidc().GetIssuer()), JwksUri: new("jwks-updated.example.com"), Audiences: []string{}}.Build()}.Build(), + trust: trustv1.Trust_builder{TenantId: new(tenantID), Blocked: new(true), Oidc: oidcv1.OIDC_builder{Issuer: new(trust.GetOidc().GetIssuer()), JwksUri: new("jwks-updated.example.com"), Audiences: []string{}}.Build()}.Build(), assertErr: assert.NoError, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := r.Update(t.Context(), tt.mapping) + err := r.Update(t.Context(), tt.trust) if !tt.assertErr(t, err, fmt.Sprintf("Repository.Update() error %v", err)) || err != nil { return } - gotMapping, err := r.Get(t.Context(), tt.mapping.GetTenantId()) + gotTrust, err := r.Get(t.Context(), tt.trust.GetTenantId()) require.NoError(t, err) - if diff := cmp.Diff(tt.mapping, gotMapping, protocmp.Transform()); diff != "" { - t.Fatalf("Unexpected mapping in the database (-want, +got):\n%s", diff) + if diff := cmp.Diff(tt.trust, gotTrust, protocmp.Transform()); diff != "" { + t.Fatalf("Unexpected trust in the database (-want, +got):\n%s", diff) } }) } diff --git a/modules/oidctrust/mocks/repository.go b/modules/oidctrust/mocks/repository.go index d200b8eb..e9774385 100644 --- a/modules/oidctrust/mocks/repository.go +++ b/modules/oidctrust/mocks/repository.go @@ -5,8 +5,8 @@ import ( trustv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/v1" - "github.com/openkcm/session-manager/pkg/serviceerr" "github.com/openkcm/session-manager/modules/oidctrust" + "github.com/openkcm/session-manager/pkg/serviceerr" ) type RepositoryOption func(*Repository) @@ -17,8 +17,8 @@ type Repository struct { getErr, createErr, deleteErr, updateErr error } -func WithTrust(mapping *trustv1.Trust) RepositoryOption { - return func(r *Repository) { r.tenantTrust[mapping.GetTenantId()] = mapping } +func WithTrust(trust *trustv1.Trust) RepositoryOption { + return func(r *Repository) { r.tenantTrust[trust.GetTenantId()] = trust } } func WithGetError(err error) RepositoryOption { return func(r *Repository) { r.getErr = err } @@ -48,8 +48,8 @@ func NewInMemRepository(opts ...RepositoryOption) *Repository { } // TAdd is a helper method for tests to add a trust relationship. -func (r *Repository) TAdd(mapping *trustv1.Trust) { - r.tenantTrust[mapping.GetTenantId()] = mapping +func (r *Repository) TAdd(trust *trustv1.Trust) { + r.tenantTrust[trust.GetTenantId()] = trust } // TGet is a helper method for tests to get a trust relationship. @@ -61,17 +61,17 @@ func (r *Repository) Get(_ context.Context, tenantID string) (*trustv1.Trust, er if r.getErr != nil { return nil, r.getErr } - if mapping, ok := r.tenantTrust[tenantID]; ok { - return mapping, nil + if trust, ok := r.tenantTrust[tenantID]; ok { + return trust, nil } return nil, serviceerr.ErrNotFound } -func (r *Repository) Create(_ context.Context, mapping *trustv1.Trust) error { +func (r *Repository) Create(_ context.Context, trust *trustv1.Trust) error { if r.createErr != nil { return r.createErr } - r.tenantTrust[mapping.GetTenantId()] = mapping + r.tenantTrust[trust.GetTenantId()] = trust return nil } @@ -86,10 +86,10 @@ func (r *Repository) Delete(_ context.Context, tenantID string) error { return nil } -func (r *Repository) Update(_ context.Context, mapping *trustv1.Trust) error { +func (r *Repository) Update(_ context.Context, trust *trustv1.Trust) error { if r.updateErr != nil { return r.updateErr } - r.tenantTrust[mapping.GetTenantId()] = mapping + r.tenantTrust[trust.GetTenantId()] = trust return nil } diff --git a/modules/oidctrust/repository.go b/modules/oidctrust/repository.go index a114b37a..ae03a668 100644 --- a/modules/oidctrust/repository.go +++ b/modules/oidctrust/repository.go @@ -6,7 +6,7 @@ import ( trustv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/v1" ) -// TrustRepository allows to read OIDC mapping data for a tenant stored in the context. +// TrustRepository allows to read OIDC trust data for a tenant stored in the context. type TrustRepository interface { Get(ctx context.Context, tenantID string) (*trustv1.Trust, error) Create(ctx context.Context, trust *trustv1.Trust) error diff --git a/modules/oidctrust/trust.go b/modules/oidctrust/trust.go index fe6d2e9b..fa0d30de 100644 --- a/modules/oidctrust/trust.go +++ b/modules/oidctrust/trust.go @@ -11,31 +11,31 @@ import ( "github.com/openkcm/session-manager/pkg/serviceerr" ) -// ApplyMapping implements [sessionmanager.Trust]. -func (m *TrustModule) ApplyMapping(ctx context.Context, trust *trustv1.Trust) error { +// Apply implements [sessionmanager.Trust]. +func (m *TrustModule) Apply(ctx context.Context, trust *trustv1.Trust) error { if _, err := m.repository.Get(ctx, trust.GetTenantId()); err != nil { err = m.repository.Create(ctx, trust) if err != nil { - return fmt.Errorf("creating mapping for tenant: %w", err) + return fmt.Errorf("creating trust for tenant: %w", err) } } else { err = m.repository.Update(ctx, trust) if err != nil { - return fmt.Errorf("updating mapping for tenant: %w", err) + return fmt.Errorf("updating trust for tenant: %w", err) } } return nil } -// BlockMapping implements [sessionmanager.Trust]. -func (m *TrustModule) BlockMapping(ctx context.Context, tenantID string) error { +// Block implements [sessionmanager.Trust]. +func (m *TrustModule) Block(ctx context.Context, tenantID string) error { trust, err := m.repository.Get(ctx, tenantID) if err != nil { if errors.Is(err, serviceerr.ErrNotFound) { return nil } - return fmt.Errorf("getting mapping for tenant: %w", err) + return fmt.Errorf("getting trust for tenant: %w", err) } if trust.GetBlocked() { return nil @@ -46,29 +46,29 @@ func (m *TrustModule) BlockMapping(ctx context.Context, tenantID string) error { if errors.Is(err, serviceerr.ErrNotFound) { return nil } - return fmt.Errorf("updating mapping for blocking tenant: %w", err) + return fmt.Errorf("updating trust for blocking tenant: %w", err) } return nil } -// RemoveMapping implements [sessionmanager.Trust]. -func (m *TrustModule) RemoveMapping(ctx context.Context, tenantID string) error { +// Remove implements [sessionmanager.Trust]. +func (m *TrustModule) Remove(ctx context.Context, tenantID string) error { err := m.repository.Delete(ctx, tenantID) if err != nil { - return fmt.Errorf("deleting mapping for tenant: %w", err) + return fmt.Errorf("deleting trust for tenant: %w", err) } return nil } -// UnblockMapping implements [sessionmanager.Trust]. -func (m *TrustModule) UnblockMapping(ctx context.Context, tenantID string) error { +// Unblock implements [sessionmanager.Trust]. +func (m *TrustModule) Unblock(ctx context.Context, tenantID string) error { trust, err := m.repository.Get(ctx, tenantID) if err != nil { if errors.Is(err, serviceerr.ErrNotFound) { return nil } - return fmt.Errorf("getting mapping for tenant: %w", err) + return fmt.Errorf("getting trust for tenant: %w", err) } if !trust.GetBlocked() { return nil @@ -78,7 +78,7 @@ func (m *TrustModule) UnblockMapping(ctx context.Context, tenantID string) error if errors.Is(err, serviceerr.ErrNotFound) { return nil } - return fmt.Errorf("updating mapping for unblocking tenant: %w", err) + return fmt.Errorf("updating trust for unblocking tenant: %w", err) } return nil } diff --git a/modules/oidctrust/trust_test.go b/modules/oidctrust/trust_test.go index 76d2b19c..2c67c69f 100644 --- a/modules/oidctrust/trust_test.go +++ b/modules/oidctrust/trust_test.go @@ -40,13 +40,13 @@ func TestMain(m *testing.M) { os.Exit(code) } -func TestService_ApplyMapping(t *testing.T) { +func TestService_Apply(t *testing.T) { ctx := t.Context() t.Run("success if", func(t *testing.T) { t.Run("the trust does not exist", func(t *testing.T) { expTenantID := uuid.Must(uuid.NewV4()).String() - expMapping := trustv1.Trust_builder{ + expTrust := trustv1.Trust_builder{ TenantId: new(expTenantID), Blocked: new(false), Oidc: oidcv1.OIDC_builder{ @@ -59,19 +59,19 @@ func TestService_ApplyMapping(t *testing.T) { wrapper := &RepoWrapper{Repo: repo} subj := oidctrust.NewModule(wrapper) - err := subj.ApplyMapping(ctx, expMapping) + err := subj.Apply(ctx, expTrust) assert.NoError(t, err) - actMapping, err := wrapper.Repo.Get(ctx, expTenantID) + actTrust, err := wrapper.Repo.Get(ctx, expTenantID) assert.NoError(t, err) - if diff := cmp.Diff(expMapping, actMapping, protocmp.Transform()); diff != "" { - t.Fatalf("mapping not equal:\n%s", diff) + if diff := cmp.Diff(expTrust, actTrust, protocmp.Transform()); diff != "" { + t.Fatalf("trust not equal:\n%s", diff) } }) t.Run("the trust exists", func(t *testing.T) { expTenantID := uuid.Must(uuid.NewV4()).String() - expMapping := trustv1.Trust_builder{ + expTrust := trustv1.Trust_builder{ TenantId: new(expTenantID), Blocked: new(false), Oidc: oidcv1.OIDC_builder{ @@ -84,26 +84,26 @@ func TestService_ApplyMapping(t *testing.T) { wrapper := &RepoWrapper{Repo: repo} subj := oidctrust.NewModule(wrapper) - err := subj.ApplyMapping(ctx, expMapping) + err := subj.Apply(ctx, expTrust) assert.NoError(t, err) - expUpdatedMapping := trustv1.Trust_builder{ + expUpdatedTrust := trustv1.Trust_builder{ TenantId: new(expTenantID), Blocked: new(false), Oidc: oidcv1.OIDC_builder{ - Issuer: new(expMapping.GetOidc().GetIssuer()), + Issuer: new(expTrust.GetOidc().GetIssuer()), JwksUri: new("http://updated-jwks.example.com"), Audiences: []string{requestURI, "http://new-aud.example.com"}, }.Build(), }.Build() - err = subj.ApplyMapping(ctx, expUpdatedMapping) + err = subj.Apply(ctx, expUpdatedTrust) assert.NoError(t, err) - actMapping, err := wrapper.Repo.Get(ctx, expTenantID) + actTrust, err := wrapper.Repo.Get(ctx, expTenantID) assert.NoError(t, err) - if diff := cmp.Diff(expUpdatedMapping, actMapping, protocmp.Transform()); diff != "" { - t.Fatalf("mapping not equal:\n%s", diff) + if diff := cmp.Diff(expUpdatedTrust, actTrust, protocmp.Transform()); diff != "" { + t.Fatalf("trust not equal:\n%s", diff) } }) }) @@ -111,7 +111,7 @@ func TestService_ApplyMapping(t *testing.T) { t.Run("should return error if", func(t *testing.T) { t.Run("Create returns an error", func(t *testing.T) { expTenantID := uuid.Must(uuid.NewV4()).String() - expMapping := trustv1.Trust_builder{ + expTrust := trustv1.Trust_builder{ TenantId: new(expTenantID), Oidc: oidcv1.OIDC_builder{ Issuer: new(uuid.Must(uuid.NewV4()).String()), @@ -124,13 +124,13 @@ func TestService_ApplyMapping(t *testing.T) { noOfCalls := 0 wrapper.MockCreate = func(ctx context.Context, trust *trustv1.Trust) error { assert.Equal(t, expTenantID, trust.GetTenantId()) - assert.Equal(t, expMapping, trust) + assert.Equal(t, expTrust, trust) noOfCalls++ return assert.AnError } subj := oidctrust.NewModule(wrapper) - err := subj.ApplyMapping(ctx, expMapping) + err := subj.Apply(ctx, expTrust) assert.ErrorIs(t, err, assert.AnError) assert.Equal(t, 1, noOfCalls) @@ -138,7 +138,7 @@ func TestService_ApplyMapping(t *testing.T) { t.Run("Update returns an error", func(t *testing.T) { expTenantID := uuid.Must(uuid.NewV4()).String() - expMapping := trustv1.Trust_builder{ + expTrust := trustv1.Trust_builder{ TenantId: new(expTenantID), Oidc: oidcv1.OIDC_builder{ Issuer: new(uuid.Must(uuid.NewV4()).String()), @@ -151,15 +151,15 @@ func TestService_ApplyMapping(t *testing.T) { noOfCalls := 0 wrapper.MockUpdate = func(ctx context.Context, trust *trustv1.Trust) error { assert.Equal(t, expTenantID, trust.GetTenantId()) - assert.Equal(t, expMapping, trust) + assert.Equal(t, expTrust, trust) noOfCalls++ return assert.AnError } subj := oidctrust.NewModule(wrapper) - err := subj.ApplyMapping(ctx, expMapping) + err := subj.Apply(ctx, expTrust) assert.NoError(t, err) - err = subj.ApplyMapping(ctx, expMapping) + err = subj.Apply(ctx, expTrust) assert.ErrorIs(t, err, assert.AnError) assert.Equal(t, 1, noOfCalls) @@ -167,14 +167,14 @@ func TestService_ApplyMapping(t *testing.T) { }) } -func TestService_BlockMapping(t *testing.T) { +func TestService_Block(t *testing.T) { ctx := t.Context() t.Run("success if ", func(t *testing.T) { t.Run("the trust is unblocked", func(t *testing.T) { // given expTenantID := uuid.Must(uuid.NewV4()).String() - expUnblockedMapping := trustv1.Trust_builder{ + expUnblockedTrust := trustv1.Trust_builder{ TenantId: new(expTenantID), Blocked: new(false), Oidc: oidcv1.OIDC_builder{ @@ -185,28 +185,28 @@ func TestService_BlockMapping(t *testing.T) { }.Build() wrapper := &RepoWrapper{Repo: repo} - err := wrapper.Repo.Create(ctx, expUnblockedMapping) + err := wrapper.Repo.Create(ctx, expUnblockedTrust) require.NoError(t, err) subj := oidctrust.NewModule(wrapper) // when - err = subj.BlockMapping(ctx, expTenantID) + err = subj.Block(ctx, expTenantID) // then assert.NoError(t, err) - actMapping, err := wrapper.Repo.Get(ctx, expTenantID) + actTrust, err := wrapper.Repo.Get(ctx, expTenantID) assert.NoError(t, err) - assert.True(t, actMapping.GetBlocked()) - assert.Equal(t, expUnblockedMapping.GetOidc().GetIssuer(), actMapping.GetOidc().GetIssuer()) - assert.Equal(t, expUnblockedMapping.GetOidc().GetAudiences(), actMapping.GetOidc().GetAudiences()) - assert.Equal(t, expUnblockedMapping.GetOidc().GetJwksUri(), actMapping.GetOidc().GetJwksUri()) + assert.True(t, actTrust.GetBlocked()) + assert.Equal(t, expUnblockedTrust.GetOidc().GetIssuer(), actTrust.GetOidc().GetIssuer()) + assert.Equal(t, expUnblockedTrust.GetOidc().GetAudiences(), actTrust.GetOidc().GetAudiences()) + assert.Equal(t, expUnblockedTrust.GetOidc().GetJwksUri(), actTrust.GetOidc().GetJwksUri()) }) t.Run("the trust is blocked then it should not call Update", func(t *testing.T) { // given expTenantID := uuid.Must(uuid.NewV4()).String() - expBlockedMapping := trustv1.Trust_builder{ + expBlockedTrust := trustv1.Trust_builder{ TenantId: new(expTenantID), Blocked: new(true), Oidc: oidcv1.OIDC_builder{ @@ -216,7 +216,7 @@ func TestService_BlockMapping(t *testing.T) { }.Build(), }.Build() repoWrapper := &RepoWrapper{Repo: repo} - err := repoWrapper.Repo.Create(ctx, expBlockedMapping) + err := repoWrapper.Repo.Create(ctx, expBlockedTrust) require.NoError(t, err) noOfUpdateCalls := 0 @@ -227,20 +227,20 @@ func TestService_BlockMapping(t *testing.T) { subj := oidctrust.NewModule(repoWrapper) // when - err = subj.BlockMapping(t.Context(), expTenantID) + err = subj.Block(t.Context(), expTenantID) // then assert.NoError(t, err) assert.Equal(t, 0, noOfUpdateCalls) - actMapping, err := repoWrapper.Repo.Get(ctx, expTenantID) + actTrust, err := repoWrapper.Repo.Get(ctx, expTenantID) assert.NoError(t, err) - assert.Equal(t, expBlockedMapping, actMapping) + assert.Equal(t, expBlockedTrust, actTrust) }) t.Run("the trust is not found during the Update", func(t *testing.T) { // given expTenantID := uuid.Must(uuid.NewV4()).String() - expBlockedMapping := trustv1.Trust_builder{ + expBlockedTrust := trustv1.Trust_builder{ TenantId: new(expTenantID), Blocked: new(false), Oidc: oidcv1.OIDC_builder{ @@ -250,7 +250,7 @@ func TestService_BlockMapping(t *testing.T) { }.Build(), }.Build() repoWrapper := &RepoWrapper{Repo: repo} - err := repoWrapper.Repo.Create(ctx, expBlockedMapping) + err := repoWrapper.Repo.Create(ctx, expBlockedTrust) require.NoError(t, err) noOfUpdateCalls := 0 @@ -264,7 +264,7 @@ func TestService_BlockMapping(t *testing.T) { subj := oidctrust.NewModule(repoWrapper) // when - err = subj.BlockMapping(t.Context(), expTenantID) + err = subj.Block(t.Context(), expTenantID) // then assert.NoError(t, err) @@ -278,7 +278,7 @@ func TestService_BlockMapping(t *testing.T) { subj := oidctrust.NewModule(repoWrapper) // when - err := subj.BlockMapping(t.Context(), expTenantID) + err := subj.Block(t.Context(), expTenantID) // then assert.NoError(t, err) @@ -302,7 +302,7 @@ func TestService_BlockMapping(t *testing.T) { subj := oidctrust.NewModule(repoWrapper) // when - err := subj.BlockMapping(t.Context(), expTenantID) + err := subj.Block(t.Context(), expTenantID) // then assert.ErrorIs(t, err, assert.AnError) @@ -312,7 +312,7 @@ func TestService_BlockMapping(t *testing.T) { t.Run("if Update returns an error", func(t *testing.T) { // given expTenantID := uuid.Must(uuid.NewV4()).String() - expMapping := trustv1.Trust_builder{ + expTrust := trustv1.Trust_builder{ TenantId: new(expTenantID), Blocked: new(false), Oidc: oidcv1.OIDC_builder{ @@ -322,7 +322,7 @@ func TestService_BlockMapping(t *testing.T) { }.Build(), }.Build() repoWrapper := &RepoWrapper{Repo: repo} - err := repoWrapper.Repo.Create(ctx, expMapping) + err := repoWrapper.Repo.Create(ctx, expTrust) require.NoError(t, err) noOfUpdateCalls := 0 @@ -334,27 +334,27 @@ func TestService_BlockMapping(t *testing.T) { subj := oidctrust.NewModule(repoWrapper) // when - err = subj.BlockMapping(t.Context(), expTenantID) + err = subj.Block(t.Context(), expTenantID) // then assert.ErrorIs(t, err, assert.AnError) assert.Equal(t, 1, noOfUpdateCalls) - actMapping, err := repoWrapper.Repo.Get(ctx, expTenantID) + actTrust, err := repoWrapper.Repo.Get(ctx, expTenantID) assert.NoError(t, err) - assert.Equal(t, expMapping, actMapping) + assert.Equal(t, expTrust, actTrust) }) }) } -func TestService_UnblockMapping(t *testing.T) { +func TestService_Unblock(t *testing.T) { ctx := t.Context() t.Run("success if ", func(t *testing.T) { t.Run("the trust is blocked", func(t *testing.T) { // given expTenantID := uuid.Must(uuid.NewV4()).String() - expBlockedMapping := trustv1.Trust_builder{ + expBlockedTrust := trustv1.Trust_builder{ TenantId: new(expTenantID), Blocked: new(true), Oidc: oidcv1.OIDC_builder{ @@ -365,28 +365,28 @@ func TestService_UnblockMapping(t *testing.T) { }.Build() wrapper := &RepoWrapper{Repo: repo} - err := wrapper.Repo.Create(ctx, expBlockedMapping) + err := wrapper.Repo.Create(ctx, expBlockedTrust) require.NoError(t, err) subj := oidctrust.NewModule(wrapper) // when - err = subj.UnblockMapping(t.Context(), expTenantID) + err = subj.Unblock(t.Context(), expTenantID) // then assert.NoError(t, err) - actMapping, err := wrapper.Repo.Get(ctx, expTenantID) + actTrust, err := wrapper.Repo.Get(ctx, expTenantID) assert.NoError(t, err) - assert.False(t, actMapping.GetBlocked()) - assert.Equal(t, expBlockedMapping.GetOidc().GetIssuer(), actMapping.GetOidc().GetIssuer()) - assert.Equal(t, expBlockedMapping.GetOidc().GetAudiences(), actMapping.GetOidc().GetAudiences()) - assert.Equal(t, expBlockedMapping.GetOidc().GetJwksUri(), actMapping.GetOidc().GetJwksUri()) + assert.False(t, actTrust.GetBlocked()) + assert.Equal(t, expBlockedTrust.GetOidc().GetIssuer(), actTrust.GetOidc().GetIssuer()) + assert.Equal(t, expBlockedTrust.GetOidc().GetAudiences(), actTrust.GetOidc().GetAudiences()) + assert.Equal(t, expBlockedTrust.GetOidc().GetJwksUri(), actTrust.GetOidc().GetJwksUri()) }) t.Run("the trust is unblocked then it should not call Update", func(t *testing.T) { // given expTenantID := uuid.Must(uuid.NewV4()).String() - expUnblockedMapping := trustv1.Trust_builder{ + expUnblockedTrust := trustv1.Trust_builder{ TenantId: new(expTenantID), Blocked: new(false), Oidc: oidcv1.OIDC_builder{ @@ -396,7 +396,7 @@ func TestService_UnblockMapping(t *testing.T) { }.Build(), }.Build() repoWrapper := &RepoWrapper{Repo: repo} - err := repoWrapper.Repo.Create(ctx, expUnblockedMapping) + err := repoWrapper.Repo.Create(ctx, expUnblockedTrust) require.NoError(t, err) noOfUpdateCalls := 0 @@ -407,20 +407,20 @@ func TestService_UnblockMapping(t *testing.T) { subj := oidctrust.NewModule(repoWrapper) // when - err = subj.UnblockMapping(t.Context(), expTenantID) + err = subj.Unblock(t.Context(), expTenantID) // then assert.NoError(t, err) assert.Equal(t, 0, noOfUpdateCalls) - actMapping, err := repoWrapper.Repo.Get(ctx, expTenantID) + actTrust, err := repoWrapper.Repo.Get(ctx, expTenantID) assert.NoError(t, err) - assert.False(t, actMapping.GetBlocked()) + assert.False(t, actTrust.GetBlocked()) }) t.Run("the trust is not found during the Update", func(t *testing.T) { // given expTenantID := uuid.Must(uuid.NewV4()).String() - expUnblockedMapping := trustv1.Trust_builder{ + expUnblockedTrust := trustv1.Trust_builder{ TenantId: new(expTenantID), Blocked: new(true), Oidc: oidcv1.OIDC_builder{ @@ -430,7 +430,7 @@ func TestService_UnblockMapping(t *testing.T) { }.Build(), }.Build() repoWrapper := &RepoWrapper{Repo: repo} - err := repoWrapper.Repo.Create(ctx, expUnblockedMapping) + err := repoWrapper.Repo.Create(ctx, expUnblockedTrust) require.NoError(t, err) noOfUpdateCalls := 0 @@ -444,7 +444,7 @@ func TestService_UnblockMapping(t *testing.T) { subj := oidctrust.NewModule(repoWrapper) // when - err = subj.UnblockMapping(t.Context(), expTenantID) + err = subj.Unblock(t.Context(), expTenantID) // then assert.NoError(t, err) @@ -458,7 +458,7 @@ func TestService_UnblockMapping(t *testing.T) { subj := oidctrust.NewModule(repoWrapper) // when - err := subj.UnblockMapping(t.Context(), expTenantID) + err := subj.Unblock(t.Context(), expTenantID) // then assert.NoError(t, err) @@ -481,7 +481,7 @@ func TestService_UnblockMapping(t *testing.T) { subj := oidctrust.NewModule(mockRepo) // when - err := subj.UnblockMapping(t.Context(), expTenantID) + err := subj.Unblock(t.Context(), expTenantID) // then assert.ErrorIs(t, err, assert.AnError) @@ -491,7 +491,7 @@ func TestService_UnblockMapping(t *testing.T) { t.Run("if Update returns an error", func(t *testing.T) { // given expTenantIDtoUpdate := uuid.Must(uuid.NewV4()).String() - expBlockedMapping := trustv1.Trust_builder{ + expBlockedTrust := trustv1.Trust_builder{ TenantId: new(expTenantIDtoUpdate), Blocked: new(true), Oidc: oidcv1.OIDC_builder{ @@ -501,7 +501,7 @@ func TestService_UnblockMapping(t *testing.T) { }.Build(), }.Build() repoWrapper := &RepoWrapper{Repo: repo} - err := repoWrapper.Repo.Create(ctx, expBlockedMapping) + err := repoWrapper.Repo.Create(ctx, expBlockedTrust) require.NoError(t, err) noOfUpdateCalls := 0 @@ -513,27 +513,27 @@ func TestService_UnblockMapping(t *testing.T) { subj := oidctrust.NewModule(repoWrapper) // when - err = subj.UnblockMapping(t.Context(), expTenantIDtoUpdate) + err = subj.Unblock(t.Context(), expTenantIDtoUpdate) // then assert.ErrorIs(t, err, assert.AnError) assert.Equal(t, 1, noOfUpdateCalls) - actMapping, err := repoWrapper.Repo.Get(ctx, expTenantIDtoUpdate) + actTrust, err := repoWrapper.Repo.Get(ctx, expTenantIDtoUpdate) assert.NoError(t, err) - assert.Equal(t, expBlockedMapping, actMapping) + assert.Equal(t, expBlockedTrust, actTrust) }) }) } -func TestService_RemoveMapping(t *testing.T) { +func TestService_Remove(t *testing.T) { ctx := t.Context() t.Run("success if", func(t *testing.T) { t.Run("the trust exists", func(t *testing.T) { // given expTenantID := uuid.Must(uuid.NewV4()).String() - expMapping := trustv1.Trust_builder{ + expTrust := trustv1.Trust_builder{ TenantId: new(expTenantID), Oidc: oidcv1.OIDC_builder{ Issuer: new(uuid.Must(uuid.NewV4()).String()), @@ -543,13 +543,13 @@ func TestService_RemoveMapping(t *testing.T) { }.Build() wrapper := &RepoWrapper{Repo: repo} - err := wrapper.Repo.Create(ctx, expMapping) + err := wrapper.Repo.Create(ctx, expTrust) require.NoError(t, err) subj := oidctrust.NewModule(wrapper) // when - err = subj.RemoveMapping(ctx, expTenantID) + err = subj.Remove(ctx, expTenantID) // then assert.NoError(t, err) @@ -568,7 +568,7 @@ func TestService_RemoveMapping(t *testing.T) { subj := oidctrust.NewModule(wrapper) // when - err := subj.RemoveMapping(ctx, expTenantID) + err := subj.Remove(ctx, expTenantID) // then assert.Error(t, err) @@ -589,7 +589,7 @@ func TestService_RemoveMapping(t *testing.T) { subj := oidctrust.NewModule(wrapper) // when - err := subj.RemoveMapping(ctx, expTenantID) + err := subj.Remove(ctx, expTenantID) // then assert.ErrorIs(t, err, assert.AnError) diff --git a/trust.go b/trust.go index 7d69d1e2..4f8eb0fc 100644 --- a/trust.go +++ b/trust.go @@ -7,18 +7,18 @@ import ( ) type Trust interface { - // ApplyMapping applies and stores the provided Trust. - ApplyMapping(ctx context.Context, trust *trustv1.Trust) error - // BlockMapping sets the Blocked flag to true for the OIDC mapping associated with the given tenantID. - // If the mapping is already blocked, it does nothing. - // Returns an error if the mapping cannot be retrieved or updated. - BlockMapping(ctx context.Context, tenantID string) error - // RemoveMapping removes the specified mapping from the trust. - RemoveMapping(ctx context.Context, tenantID string) error - // UnblockMapping sets the Blocked flag to false for the OIDC mapping associated with the given tenantID. - // If the mapping is not blocked, it does nothing. - // Returns an error if the mapping cannot be retrieved or updated. - UnblockMapping(ctx context.Context, tenantID string) error + // Apply applies and stores the provided Trust. + Apply(ctx context.Context, trust *trustv1.Trust) error + // Block sets the Blocked flag to true for the trust associated with the given tenantID. + // If the trust is already blocked, it does nothing. + // Returns an error if the trust cannot be retrieved or updated. + Block(ctx context.Context, tenantID string) error + // Remove removes the trust for the given tenantID. + Remove(ctx context.Context, tenantID string) error + // Unblock sets the Blocked flag to false for the trust associated with the given tenantID. + // If the trust is not blocked, it does nothing. + // Returns an error if the trust cannot be retrieved or updated. + Unblock(ctx context.Context, tenantID string) error // Get returns a trust message with optional extensions set. Get(ctx context.Context, tenantID string) (*trustv1.Trust, error) } From 3b9944125707903b9228801dc8192885d5dcca0a Mon Sep 17 00:00:00 2001 From: Danylo Shevchenko Date: Tue, 26 May 2026 09:01:55 +0200 Subject: [PATCH 4/5] feat: grpc app module Signed-off-by: Danylo Shevchenko --- .../session-manager/templates/configmap.yaml | 9 + charts/session-manager/values-dev.yaml | 16 ++ charts/session-manager/values.yaml | 16 ++ config.yaml | 22 +- context.go | 70 ++++-- context_test.go | 237 ++++++++++++++++++ internal/business/business.go | 58 +---- internal/business/business_test.go | 17 -- internal/business/housekeeper.go | 12 +- internal/business/server/grpc_server.go | 59 ----- internal/business/server/grpc_server_test.go | 52 ---- internal/config/config.go | 71 +++++- internal/config/config_test.go | 86 +++++++ internal/config/load.go | 13 + internal/grpc/options.go | 23 -- internal/sessionwiring/sessionwiring.go | 127 ++++------ internal/sessionwiring/sessionwiring_test.go | 147 ----------- modules/app/grpcserver/module.go | 160 ++++++++++++ modules/app/grpcserver/module_test.go | 131 ++++++++++ modules/credentials/oauth2/module.go | 90 +++++++ modules/credentials/oauth2/module_test.go | 81 ++++++ .../grpc/session}/import_test.go | 2 +- modules/grpc/session/module.go | 113 +++++++++ modules/grpc/session/options.go | 23 ++ .../grpc/session/server.go | 44 ++-- .../grpc/session/server_test.go | 96 +++---- .../grpc/session}/violations.go | 2 +- modules/grpc/trustmapping/import_test.go | 12 + modules/grpc/trustmapping/module.go | 60 +++++ modules/grpc/trustmapping/module_test.go | 83 ++++++ .../grpc/trustmapping/server.go | 24 +- .../grpc/trustmapping/server_test.go | 38 +-- modules/sessionstore/valkey/module.go | 98 ++++++++ modules/sessionstore/valkey/module_test.go | 27 ++ modules/standard/imports.go | 5 + 35 files changed, 1579 insertions(+), 545 deletions(-) delete mode 100644 internal/business/server/grpc_server.go delete mode 100644 internal/business/server/grpc_server_test.go delete mode 100644 internal/grpc/options.go delete mode 100644 internal/sessionwiring/sessionwiring_test.go create mode 100644 modules/app/grpcserver/module.go create mode 100644 modules/app/grpcserver/module_test.go create mode 100644 modules/credentials/oauth2/module.go create mode 100644 modules/credentials/oauth2/module_test.go rename {internal/grpc => modules/grpc/session}/import_test.go (94%) create mode 100644 modules/grpc/session/module.go create mode 100644 modules/grpc/session/options.go rename internal/grpc/session.go => modules/grpc/session/server.go (86%) rename internal/grpc/session_test.go => modules/grpc/session/server_test.go (89%) rename {internal/grpc => modules/grpc/session}/violations.go (77%) create mode 100644 modules/grpc/trustmapping/import_test.go create mode 100644 modules/grpc/trustmapping/module.go create mode 100644 modules/grpc/trustmapping/module_test.go rename internal/grpc/trustmapping.go => modules/grpc/trustmapping/server.go (78%) rename internal/grpc/trust_test.go => modules/grpc/trustmapping/server_test.go (93%) create mode 100644 modules/sessionstore/valkey/module.go create mode 100644 modules/sessionstore/valkey/module_test.go diff --git a/charts/session-manager/templates/configmap.yaml b/charts/session-manager/templates/configmap.yaml index acdc66fd..2ed3ad58 100644 --- a/charts/session-manager/templates/configmap.yaml +++ b/charts/session-manager/templates/configmap.yaml @@ -43,6 +43,12 @@ data: database: {{- toYaml .database | nindent 6 }} + trust: + {{- toYaml .trust | nindent 6 }} + + credentials: + {{- toYaml .credentials | nindent 6 }} + valkey: {{- toYaml .valkey | nindent 6 }} @@ -54,4 +60,7 @@ data: housekeeper: {{- toYaml .housekeeper | nindent 6 }} + + apps: + {{- toYaml .apps | nindent 6 }} {{- end }} diff --git a/charts/session-manager/values-dev.yaml b/charts/session-manager/values-dev.yaml index 6d80faec..7522ab14 100644 --- a/charts/session-manager/values-dev.yaml +++ b/charts/session-manager/values-dev.yaml @@ -272,6 +272,7 @@ config: enabled: false database: + module: database.module.pgxpool name: session_manager port: 5432 host: @@ -284,7 +285,14 @@ config: source: embedded value: secret + trust: + module: trust.module.oidc + + credentials: + module: credentials.module.oauth2 + valkey: + module: sessionstore.module.valkey host: source: embedded value: valkey-headless.session-manager.svc.cluster.local:6379 @@ -335,9 +343,17 @@ config: value: my-csrf-secret-at-least-thirty-two-bits-size migrate: + module: trust.migration.module.oidc source: file:///sql housekeeper: triggerInterval: 10m concurrencyLimit: 10 tokenRefreshTriggerInterval: 15m + + apps: + grpc: + module: app.module.grpcserver + services: + - module: service.module.grpc.session + - module: service.module.grpc.trustmapping diff --git a/charts/session-manager/values.yaml b/charts/session-manager/values.yaml index 69fd0013..5ef91b5b 100644 --- a/charts/session-manager/values.yaml +++ b/charts/session-manager/values.yaml @@ -280,6 +280,7 @@ config: enabled: false database: + module: database.module.pgxpool name: session_manager port: 5432 host: @@ -292,7 +293,14 @@ config: source: embedded value: secret + trust: + module: trust.module.oidc + + credentials: + module: credentials.module.oauth2 + valkey: + module: sessionstore.module.valkey host: source: embedded value: host.ns.svc.cluster.local @@ -348,9 +356,17 @@ config: value: my-csrf-secret-at-least-thirty-two-bits-size migrate: + module: trust.migration.module.oidc source: file:///sql housekeeper: triggerInterval: 10m concurrencyLimit: 10 tokenRefreshTriggerInterval: 15m + + apps: + grpc: + module: app.module.grpcserver + services: + - module: service.module.grpc.session + - module: service.module.grpc.trustmapping diff --git a/config.yaml b/config.yaml index a724449f..703c79a3 100644 --- a/config.yaml +++ b/config.yaml @@ -181,6 +181,7 @@ telemetry: jsonPath: "$.server-ca" database: + module: database.module.pgxpool name: session_manager port: 5432 host: @@ -193,7 +194,18 @@ database: source: embedded value: secret +trust: + module: trust.module.oidc + +migrate: + module: trust.migration.module.oidc + source: file://./sql + +credentials: + module: credentials.module.oauth2 + valkey: + module: sessionstore.module.valkey host: source: embedded value: localhost:6379 @@ -248,10 +260,14 @@ sessionManager: source: embedded value: my-csrf-secret-at-least-thirty-two-bits-size -migrate: - source: file://./sql - housekeeper: triggerInterval: 10m concurrencyLimit: 10 tokenRefreshTriggerInterval: 15m + +apps: + grpc: + module: app.module.grpcserver + services: + - module: service.module.grpc.session + - module: service.module.grpc.trustmapping diff --git a/context.go b/context.go index 9e150bf9..45b5f503 100644 --- a/context.go +++ b/context.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "reflect" + "slices" slogctx "github.com/veqryn/slog-context" ) @@ -14,15 +15,17 @@ type Context struct { //nolint:containedctx context.Context - mods map[string]Module - apps map[string]App + mods map[string]Module + modOrder []string + apps map[string]App } func (c *Context) cloneWithParent(parent context.Context) *Context { return &Context{ - Context: parent, - mods: c.mods, - apps: c.apps, + Context: parent, + mods: c.mods, + modOrder: c.modOrder, + apps: c.apps, } } @@ -33,23 +36,29 @@ func (c *Context) WithValue(key, val any) *Context { func NewContext(ctx context.Context) (*Context, context.CancelCauseFunc) { ctx, cancelCause := context.WithCancelCause(ctx) c := &Context{ - Context: ctx, - mods: make(map[string]Module), - apps: make(map[string]App), + Context: ctx, + mods: make(map[string]Module), + modOrder: nil, + apps: make(map[string]App), } return c, func(cause error) { cancelCause(cause) - for name, mod := range c.mods { - if closer, ok := mod.(io.Closer); ok { + for name, app := range c.apps { + if closer, ok := app.(io.Closer); ok { if err := closer.Close(); err != nil { - slogctx.Error(c, "failed to close a module", "module", name, "error", err) + slogctx.Error(c, "failed to close an app", "app", name, "error", err) } } } - for name, app := range c.apps { - if closer, ok := app.(io.Closer); ok { + for _, v := range slices.Backward(c.modOrder) { + id := v + mod, ok := c.mods[id] + if !ok { + continue + } + if closer, ok := mod.(io.Closer); ok { if err := closer.Close(); err != nil { - slogctx.Error(c, "failed to close an app", "app", name, "error", err) + slogctx.Error(c, "failed to close a module", "module", id, "error", err) } } } @@ -78,32 +87,40 @@ func (c *Context) GetApp(id string) (App, error) { } func (c *Context) LoadModule(cfg ExtensionConfig) (Module, error) { + before := len(c.modOrder) mod, modInfo, err := c.instantiate(cfg) if err != nil { + c.unloadModulesAfter(before) return nil, err } if _, ok := c.mods[modInfo.ID]; ok { + c.unloadModulesAfter(before) return nil, errors.New("module has already been loaded") } c.mods[modInfo.ID] = mod + c.modOrder = append(c.modOrder, modInfo.ID) return mod, nil } func (c *Context) LoadApp(cfg ExtensionConfig) (App, error) { + before := len(c.modOrder) mod, modInfo, err := c.instantiate(cfg) if err != nil { + c.unloadModulesAfter(before) return nil, err } app, ok := mod.(App) if !ok { + c.unloadModulesAfter(before) return nil, fmt.Errorf("module %q does not implement the App interface", modInfo.ID) } if _, ok := c.apps[modInfo.ID]; ok { + c.unloadModulesAfter(before) return nil, errors.New("app has already been loaded") } @@ -112,6 +129,31 @@ func (c *Context) LoadApp(cfg ExtensionConfig) (App, error) { return app, nil } +// unloadModulesAfter rolls back any modules appended to modOrder at or after +// the snapshot index. Modules are closed in reverse load order. It is the +// recovery path for a failed LoadModule or LoadApp call: every successfully +// loaded child module is closed and removed from the registry before the +// error surfaces to the caller. +func (c *Context) unloadModulesAfter(snapshot int) { + if snapshot >= len(c.modOrder) { + return + } + for i := len(c.modOrder) - 1; i >= snapshot; i-- { + id := c.modOrder[i] + mod, ok := c.mods[id] + if !ok { + continue + } + if closer, ok := mod.(io.Closer); ok { + if err := closer.Close(); err != nil { + slogctx.Error(c, "failed to close a module during rollback", "module", id, "error", err) + } + } + delete(c.mods, id) + } + c.modOrder = c.modOrder[:snapshot] +} + // instantiate resolves cfg.Module(), calls New(), unmarshals the extension, and // runs Provision if the resulting instance is a Provisioner. It is shared by // LoadModule and LoadApp. diff --git a/context_test.go b/context_test.go index c3ba86ca..4c263335 100644 --- a/context_test.go +++ b/context_test.go @@ -352,6 +352,243 @@ func TestNewContext_CancelClosesApps(t *testing.T) { assert.True(t, cam.closed, "Close() should be called on apps when context is cancelled") } +// childLoadingProvisioner is a Module whose Provision loads the configured +// child module IDs, in order, via ctx.LoadModule. If failAfter is non-negative +// it returns an error immediately after loading that many children. +type childLoadingProvisioner struct { + stubModule + + childIDs []string + failAfter int // -1 = never fail + failReason string // error text used when failAfter triggers +} + +func (m *childLoadingProvisioner) Provision(ctx *sessionmanager.Context) error { + for i, id := range m.childIDs { + if m.failAfter >= 0 && i == m.failAfter { + return errors.New(m.failReason) + } + if _, err := ctx.LoadModule(&simpleExtensionConfig{moduleID: id}); err != nil { + return err + } + } + if m.failAfter >= 0 && m.failAfter >= len(m.childIDs) { + return errors.New(m.failReason) + } + return nil +} + +// childLoadingApp is an App whose Provision loads the given child modules +// before the framework registers it as an app. +type childLoadingApp struct { + appModule + childLoadingProvisioner +} + +func (a *childLoadingApp) Module() sessionmanager.ModuleInfo { + return a.appModule.Module() +} + +func (a *childLoadingApp) Provision(ctx *sessionmanager.Context) error { + return a.childLoadingProvisioner.Provision(ctx) +} + +// orderRecorder is shared across closableOrderModule instances so tests can +// observe Close ordering across the framework's reverse-load-order shutdown. +type orderRecorder struct { + closes []string +} + +// closableOrderModule appends its ID to the recorder when Close is called. +type closableOrderModule struct { + stubModule + + rec *orderRecorder +} + +func (m *closableOrderModule) Close() error { + m.rec.closes = append(m.rec.closes, m.id) + return nil +} + +func TestLoadModule_ChildLoadFailureRollsBackEarlierSiblings(t *testing.T) { + parentID := uniqueID(t, "parent") + child1ID := uniqueID(t, "child1") + // child2ID is intentionally unregistered so its load fails. + child2ID := uniqueID(t, "child2-missing") + + c1 := &closableModule{stubModule: stubModule{id: child1ID}} + sessionmanager.RegisterModule(&customNewModule{ + id: child1ID, + newFn: func() sessionmanager.Module { return c1 }, + }) + + parent := &childLoadingProvisioner{ + stubModule: stubModule{id: parentID}, + childIDs: []string{child1ID, child2ID}, + failAfter: -1, + } + sessionmanager.RegisterModule(&customNewModule{ + id: parentID, + newFn: func() sessionmanager.Module { return parent }, + }) + + ctx, cancel := sessionmanager.NewContext(t.Context()) + defer cancel(nil) + + _, err := ctx.LoadModule(&simpleExtensionConfig{moduleID: parentID}) + require.Error(t, err) + assert.True(t, c1.closed, "earlier sibling must be closed during rollback") + + // child1 must be removed from the registry. + _, err = ctx.GetModule(child1ID) + require.Error(t, err) + + // parent itself was never registered (its Provision failed). + _, err = ctx.GetModule(parentID) + require.Error(t, err) +} + +func TestLoadApp_ProvisionErrorRollsBackChildren(t *testing.T) { + appID := uniqueID(t, "app") + child1ID := uniqueID(t, "ch1") + child2ID := uniqueID(t, "ch2") + + c1 := &closableModule{stubModule: stubModule{id: child1ID}} + c2 := &closableModule{stubModule: stubModule{id: child2ID}} + + sessionmanager.RegisterModule(&customNewModule{ + id: child1ID, + newFn: func() sessionmanager.Module { return c1 }, + }) + sessionmanager.RegisterModule(&customNewModule{ + id: child2ID, + newFn: func() sessionmanager.Module { return c2 }, + }) + + app := &childLoadingApp{ + appModule: appModule{stubModule: stubModule{id: appID}}, + childLoadingProvisioner: childLoadingProvisioner{ + childIDs: []string{child1ID, child2ID}, + failAfter: 2, // fail after both children loaded + failReason: "app provision boom", + }, + } + sessionmanager.RegisterModule(&customNewModule{ + id: appID, + newFn: func() sessionmanager.Module { return app }, + }) + + ctx, cancel := sessionmanager.NewContext(t.Context()) + defer cancel(nil) + + _, err := ctx.LoadApp(&simpleExtensionConfig{moduleID: appID}) + require.Error(t, err) + assert.Contains(t, err.Error(), "app provision boom") + + assert.True(t, c1.closed, "child1 must be closed during rollback") + assert.True(t, c2.closed, "child2 must be closed during rollback") + + // Neither child is in the registry. + _, err = ctx.GetModule(child1ID) + require.Error(t, err) + _, err = ctx.GetModule(child2ID) + require.Error(t, err) + + // The app itself was never registered. + _, err = ctx.GetApp(appID) + require.Error(t, err) +} + +func TestLoadModule_NonCloserChildIsRemovedOnRollback(t *testing.T) { + parentID := uniqueID(t, "parent-noncloser") + plainID := uniqueID(t, "plain-child") + missingID := uniqueID(t, "missing-child") + + sessionmanager.RegisterModule(&customNewModule{ + id: plainID, + newFn: func() sessionmanager.Module { return &stubModule{id: plainID} }, + }) + + parent := &childLoadingProvisioner{ + stubModule: stubModule{id: parentID}, + childIDs: []string{plainID, missingID}, + failAfter: -1, + } + sessionmanager.RegisterModule(&customNewModule{ + id: parentID, + newFn: func() sessionmanager.Module { return parent }, + }) + + ctx, cancel := sessionmanager.NewContext(t.Context()) + defer cancel(nil) + + _, err := ctx.LoadModule(&simpleExtensionConfig{moduleID: parentID}) + require.Error(t, err) + + // Non-closer child must still be removed from the registry. + _, err = ctx.GetModule(plainID) + require.Error(t, err) +} + +func TestNewContext_CloseInReverseLoadOrder(t *testing.T) { + rec := &orderRecorder{} + + idA := uniqueID(t, "ord-A") + idB := uniqueID(t, "ord-B") + idC := uniqueID(t, "ord-C") + + for _, id := range []string{idA, idB, idC} { + sessionmanager.RegisterModule(&customNewModule{ + id: id, + newFn: func() sessionmanager.Module { return &closableOrderModule{stubModule: stubModule{id: id}, rec: rec} }, + }) + } + + ctx, cancel := sessionmanager.NewContext(t.Context()) + + for _, id := range []string{idA, idB, idC} { + _, err := ctx.LoadModule(&simpleExtensionConfig{moduleID: id}) + require.NoError(t, err) + } + + cancel(nil) + + require.Equal(t, []string{idC, idB, idA}, rec.closes, + "modules must be closed in reverse load order") +} + +func TestNewContext_CloseSkipsNonClosersInReverseOrder(t *testing.T) { + rec := &orderRecorder{} + + idA := uniqueID(t, "mix-A") // closer + idB := uniqueID(t, "mix-B") // non-closer + idC := uniqueID(t, "mix-C") // closer + + sessionmanager.RegisterModule(&customNewModule{ + id: idA, + newFn: func() sessionmanager.Module { return &closableOrderModule{stubModule: stubModule{id: idA}, rec: rec} }, + }) + sessionmanager.RegisterModule(&customNewModule{ + id: idB, + newFn: func() sessionmanager.Module { return &stubModule{id: idB} }, + }) + sessionmanager.RegisterModule(&customNewModule{ + id: idC, + newFn: func() sessionmanager.Module { return &closableOrderModule{stubModule: stubModule{id: idC}, rec: rec} }, + }) + + ctx, cancel := sessionmanager.NewContext(t.Context()) + for _, id := range []string{idA, idB, idC} { + _, err := ctx.LoadModule(&simpleExtensionConfig{moduleID: id}) + require.NoError(t, err) + } + + cancel(nil) + require.Equal(t, []string{idC, idA}, rec.closes, + "only Closer modules must be closed, in reverse load order") +} + // Ensure stubModule satisfies the Module interface at compile time. var _ sessionmanager.Module = (*stubModule)(nil) diff --git a/internal/business/business.go b/internal/business/business.go index 3b25105a..75ad2665 100644 --- a/internal/business/business.go +++ b/internal/business/business.go @@ -13,12 +13,10 @@ import ( sessionmanager "github.com/openkcm/session-manager" "github.com/openkcm/session-manager/internal/business/server" "github.com/openkcm/session-manager/internal/config" - "github.com/openkcm/session-manager/internal/grpc" - sessionvalkey "github.com/openkcm/session-manager/internal/session/valkey" "github.com/openkcm/session-manager/internal/sessionwiring" ) -// Main starts both API servers +// Main starts the public HTTP API server and the configured apps. func Main(ctx context.Context, cfg *config.Config) error { c, cancelCause := sessionmanager.NewContext(ctx) defer cancelCause(nil) @@ -33,15 +31,23 @@ func Main(ctx context.Context, cfg *config.Config) error { return fmt.Errorf("loading trust module: %w", err) } + if _, err := c.LoadModule(&cfg.ValKey); err != nil { + return fmt.Errorf("loading session-store module: %w", err) + } + + if _, err := c.LoadModule(&cfg.Credentials); err != nil { + return fmt.Errorf("loading credentials module: %w", err) + } + stopApps, err := startApps(c, cfg) if err != nil { return fmt.Errorf("starting apps: %w", err) } - // errChan is used to capture the first error and shutdown the servers. + // errChan captures the first error and triggers shutdown. errChan := make(chan error, 1) - // wg is used to wait for all servers to shutdown. + // wg is used to wait for all goroutines to shutdown. var wg sync.WaitGroup // start public HTTP REST API server @@ -49,11 +55,6 @@ func Main(ctx context.Context, cfg *config.Config) error { errChan <- publicMain(c, cfg) }) - // start internal gRPC API server - wg.Go(func() { - errChan <- internalMain(c, cfg) - }) - // wait for any error to initiate the shutdown err = <-errChan if err != nil { @@ -63,7 +64,6 @@ func Main(ctx context.Context, cfg *config.Config) error { stopErr := stopApps() cancelCause(err) - // wait for all servers to shutdown wg.Wait() return errors.Join(err, stopErr) @@ -98,39 +98,3 @@ func publicMain(ctx *sessionmanager.Context, cfg *config.Config) error { return server.StartHTTPServer(ctx, cfg, sessionManager) } - -// internalMain starts the gRPC private API server. -func internalMain(ctx *sessionmanager.Context, cfg *config.Config) error { - // Create session repository - valkeyClient, err := sessionwiring.ValkeyClient(cfg) - if err != nil { - return fmt.Errorf("failed to create valkey client: %w", err) - } - defer valkeyClient.Close() - sessionRepo := sessionvalkey.NewRepository(valkeyClient, cfg.ValKey.Prefix) - - credsBuilder, err := sessionwiring.CredsBuilder(cfg) - if err != nil { - return fmt.Errorf("failed to create a credentials builder: %w", err) - } - - trustMod, err := ctx.GetModule(cfg.Trust.Module()) - if err != nil { - return fmt.Errorf("getting trust module: %w", err) - } - - //nolint:forcetypeassert - trust := trustMod.(sessionmanager.Trust) - - // Initialize the gRPC servers. - trustsrv := grpc.NewTrustMappingServer(trust) - sessionsrv := grpc.NewSessionServer(ctx, - sessionRepo, - trust, - cfg.SessionManager.IdleSessionTimeout, - cfg.SessionManager.ClientAuth.ClientID, - grpc.WithTransportCredentials(credsBuilder), - ) - - return server.StartGRPCServer(ctx, cfg, trustsrv, sessionsrv) -} diff --git a/internal/business/business_test.go b/internal/business/business_test.go index de0292d2..c2519fcc 100644 --- a/internal/business/business_test.go +++ b/internal/business/business_test.go @@ -40,23 +40,6 @@ func TestPublicMain_ShortCSRFSecret(t *testing.T) { assert.Contains(t, err.Error(), "CSRF secret must be at least 32 bytes") } -func TestInternalMain_InvalidValkeyConfig(t *testing.T) { - cfg := &config.Config{ - ValKey: config.ValKey{ - Host: commoncfg.SourceRef{Source: "file", File: commoncfg.CredentialFile{Path: "/nonexistent/file"}}, - User: commoncfg.SourceRef{Source: "embedded", Value: "user"}, - Password: commoncfg.SourceRef{Source: "embedded", Value: "pass"}, - }, - } - - ctx, cancel := sessionmanager.NewContext(t.Context()) - defer cancel(nil) - - err := internalMain(ctx, cfg) - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to create valkey client") -} - func TestMain_InvalidCSRFSecret(t *testing.T) { cfg := &config.Config{ SessionManager: config.SessionManager{ diff --git a/internal/business/housekeeper.go b/internal/business/housekeeper.go index d33a9dc1..fde823a5 100644 --- a/internal/business/housekeeper.go +++ b/internal/business/housekeeper.go @@ -17,6 +17,8 @@ func HousekeeperMain(ctx context.Context, cfg *config.Config) error { c, cancelCause := sessionmanager.NewContext(ctx) defer cancelCause(nil) + c = config.WithContext(c, cfg) + _, err := c.LoadModule(&cfg.Database) if err != nil { return fmt.Errorf("loading database module: %w", err) @@ -27,10 +29,18 @@ func HousekeeperMain(ctx context.Context, cfg *config.Config) error { return fmt.Errorf("loading trust module: %w", err) } + if _, err := c.LoadModule(&cfg.ValKey); err != nil { + return fmt.Errorf("loading session-store module: %w", err) + } + + if _, err := c.LoadModule(&cfg.Credentials); err != nil { + return fmt.Errorf("loading credentials module: %w", err) + } + //nolint:forcetypeassert trust := trustMod.(sessionmanager.Trust) - sessionManager, closeFn, err := sessionwiring.InitSessionManager(ctx, cfg, trust) + sessionManager, closeFn, err := sessionwiring.InitSessionManager(c, cfg, trust) if err != nil { return fmt.Errorf("failed to initialise the session manager: %w", err) } diff --git a/internal/business/server/grpc_server.go b/internal/business/server/grpc_server.go deleted file mode 100644 index b541bec3..00000000 --- a/internal/business/server/grpc_server.go +++ /dev/null @@ -1,59 +0,0 @@ -package server - -import ( - "context" - "net" - - "github.com/openkcm/common-sdk/pkg/commongrpc" - "github.com/samber/oops" - - sessionv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/sessionmanager/session/v1" - trustmappingv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/sessionmanager/trustmapping/v1" - slogctx "github.com/veqryn/slog-context" - - "github.com/openkcm/session-manager/internal/config" - "github.com/openkcm/session-manager/internal/grpc" -) - -func StartGRPCServer(ctx context.Context, cfg *config.Config, - trustsrv *grpc.TrustMappingServer, - sessionsrv *grpc.SessionServer, -) error { - grpcServer := commongrpc.NewServer(ctx, &cfg.GRPC.GRPCServer) - - // Register Trust server for the regional tenant manager - trustmappingv1.RegisterServiceServer(grpcServer, trustsrv) - // Register Session server for ExtAuthZ - sessionv1.RegisterServiceServer(grpcServer, sessionsrv) - - slogctx.Info(ctx, "Starting a listener", "address", cfg.GRPC.Address) - - listener, err := new(net.ListenConfig).Listen(ctx, "tcp", cfg.GRPC.Address) - if err != nil { - return oops.In("gRPC Server"). - WithContext(ctx). - Wrapf(err, "creating listener") - } - - slogctx.Info(ctx, "A listener started", "address", listener.Addr().String()) - - go func() { - slogctx.Info(ctx, "Serving a gRPC server", "address", listener.Addr().String()) - err := grpcServer.Serve(listener) - if err != nil { - slogctx.Error(ctx, "Failed to serve gRPC endpoint", "error", err) - } - - slogctx.Info(ctx, "Stopped gRPC server") - }() - - <-ctx.Done() - - shutdownCtx, cancel := context.WithTimeout(ctx, cfg.GRPC.ShutdownTimeout) - defer cancel() - - grpcServer.GracefulStop() - slogctx.Info(shutdownCtx, "Completed graceful shutdown of gRPC server") - - return nil -} diff --git a/internal/business/server/grpc_server_test.go b/internal/business/server/grpc_server_test.go deleted file mode 100644 index be0fb226..00000000 --- a/internal/business/server/grpc_server_test.go +++ /dev/null @@ -1,52 +0,0 @@ -package server - -import ( - "context" - "testing" - "time" - - "github.com/openkcm/common-sdk/pkg/commoncfg" - "github.com/stretchr/testify/assert" - - "github.com/openkcm/session-manager/internal/config" - "github.com/openkcm/session-manager/internal/grpc" -) - -func TestStartGRPCServer_ContextCancellation(t *testing.T) { - t.Run("gracefully shuts down when context is cancelled", func(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - - cfg := &config.Config{ - GRPC: config.GRPCServer{ - GRPCServer: commoncfg.GRPCServer{ - Address: "localhost:0", // Use port 0 to get a random available port - }, - ShutdownTimeout: 1 * time.Second, - }, - } - - // Create minimal server instances - trustsrv := grpc.NewTrustMappingServer(nil) - sessionsrv := grpc.NewSessionServer(ctx, nil, nil, 0, "") - - // Start the server in a goroutine - errChan := make(chan error, 1) - go func() { - errChan <- StartGRPCServer(ctx, cfg, trustsrv, sessionsrv) - }() - - // Give the server a moment to start - time.Sleep(100 * time.Millisecond) - - // Cancel the context to trigger shutdown - cancel() - - // Wait for shutdown to complete - select { - case err := <-errChan: - assert.NoError(t, err) - case <-time.After(5 * time.Second): - t.Fatal("Server did not shut down within timeout") - } - }) -} diff --git a/internal/config/config.go b/internal/config/config.go index 876f4cf5..96690291 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -26,6 +26,7 @@ type Config struct { SessionManager SessionManager `yaml:"sessionManager"` Housekeeper Housekeeper `yaml:"housekeeper"` Trust Trust `yaml:"trust"` + Credentials Credentials `yaml:"credentials"` // Apps configures long-running components that satisfy the sessionmanager.App // interface. The map key is an operator-chosen name. Each entry MUST set @@ -42,8 +43,9 @@ type Config struct { // App is the per-entry configuration under the top-level apps: section. It // implements sessionmanager.ExtensionConfig so it can be passed to LoadApp. type App struct { - Mod string `yaml:"module"` - koanf *koanf.Koanf + Mod string `yaml:"module"` + Services []*ServiceCfg `yaml:"services"` + koanf *koanf.Koanf } func (c *App) setKoanf(ko *koanf.Koanf) { @@ -61,6 +63,29 @@ func (c *App) UnmarshalExtension(into sessionmanager.Module) error { return unmarshalExtension(into, c.koanf) } +// ServiceCfg is a per-entry configuration under an App's services: list. It +// implements sessionmanager.ExtensionConfig so an App's Provision can pass it +// to ctx.LoadModule. +type ServiceCfg struct { + Mod string `yaml:"module"` + koanf *koanf.Koanf +} + +func (c *ServiceCfg) setKoanf(ko *koanf.Koanf) { + c.koanf = ko +} + +func (c *ServiceCfg) Module() string { + return c.Mod +} + +func (c *ServiceCfg) UnmarshalExtension(into sessionmanager.Module) error { + if c.koanf == nil { + return nil + } + return unmarshalExtension(into, c.koanf) +} + type Trust struct { Mod string `yaml:"module" default:"trust.module.oidc"` koanf *koanf.Koanf @@ -130,11 +155,53 @@ func (c *Database) UnmarshalExtension(into sessionmanager.Module) error { } type ValKey struct { + Mod string `yaml:"module" default:"sessionstore.module.valkey"` Host commoncfg.SourceRef `yaml:"host"` User commoncfg.SourceRef `yaml:"user"` Password commoncfg.SourceRef `yaml:"password"` Prefix string `yaml:"prefix"` SecretRef commoncfg.SecretRef `yaml:"secretRef"` + + koanf *koanf.Koanf +} + +func (c *ValKey) setKoanf(ko *koanf.Koanf) { + c.koanf = ko +} + +func (c *ValKey) Module() string { + return c.Mod +} + +func (c *ValKey) UnmarshalExtension(into sessionmanager.Module) error { + if c.koanf == nil { + return nil + } + return unmarshalExtension(into, c.koanf) +} + +// Credentials is a thin top-level entry whose only purpose is to make the +// credentials module ID swappable. The actual auth-type/secret/mTLS data +// continues to live under sessionManager.clientAuth and the credentials +// module reads it via config.FromContext. +type Credentials struct { + Mod string `yaml:"module" default:"credentials.module.oauth2"` + koanf *koanf.Koanf +} + +func (c *Credentials) setKoanf(ko *koanf.Koanf) { + c.koanf = ko +} + +func (c *Credentials) Module() string { + return c.Mod +} + +func (c *Credentials) UnmarshalExtension(into sessionmanager.Module) error { + if c.koanf == nil { + return nil + } + return unmarshalExtension(into, c.koanf) } type SessionManager struct { diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 6520aaee..a265d5ad 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -70,3 +70,89 @@ func TestLoad_AppsAbsentIsNoop(t *testing.T) { assert.Empty(t, cfg.Apps) assert.Empty(t, cfg.AppsOrder) } + +// fakeServiceModule mirrors the per-service YAML fields a gRPC service module +// would expose. Used to assert per-entry koanf subtree wiring under +// apps[].services[]. +type fakeServiceModule struct { + Trust string `yaml:"trust"` + AllowHttpScheme bool `yaml:"allowHttpScheme"` +} + +func (*fakeServiceModule) Module() sessionmanager.ModuleInfo { + return sessionmanager.ModuleInfo{ID: "config_test.fake.service"} +} + +func TestLoad_ValkeyDefaultModule(t *testing.T) { + dir := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(dir, configFile), []byte("valkey:\n prefix: foo\n"), 0o600)) + + cfg, err := Load("", dir) + require.NoError(t, err) + assert.Equal(t, "sessionstore.module.valkey", cfg.ValKey.Module()) + assert.Equal(t, "foo", cfg.ValKey.Prefix) +} + +func TestLoad_ValkeyCustomModule(t *testing.T) { + dir := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(dir, configFile), []byte("valkey:\n module: my.custom.sessionstore\n"), 0o600)) + + cfg, err := Load("", dir) + require.NoError(t, err) + assert.Equal(t, "my.custom.sessionstore", cfg.ValKey.Module()) +} + +func TestLoad_CredentialsDefaultModule(t *testing.T) { + dir := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(dir, configFile), []byte("# empty\n"), 0o600)) + + cfg, err := Load("", dir) + require.NoError(t, err) + assert.Equal(t, "credentials.module.oauth2", cfg.Credentials.Module()) +} + +func TestLoad_CredentialsCustomModule(t *testing.T) { + dir := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(dir, configFile), []byte("credentials:\n module: my.custom.credentials\n"), 0o600)) + + cfg, err := Load("", dir) + require.NoError(t, err) + assert.Equal(t, "my.custom.credentials", cfg.Credentials.Module()) +} + +func TestLoad_AppServicesPerEntryKoanf(t *testing.T) { + yaml := ` +apps: + grpc: + module: app.module.grpcserver + services: + - module: service.module.grpc.session + trust: trust.module.oidc + allowHttpScheme: true + - module: service.module.grpc.trustmapping + trust: trust.module.alt +` + + dir := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(dir, configFile), []byte(yaml), 0o600)) + + cfg, err := Load("", dir) + require.NoError(t, err) + + require.Contains(t, cfg.Apps, "grpc") + app := cfg.Apps["grpc"] + assert.Equal(t, "app.module.grpcserver", app.Module()) + require.Len(t, app.Services, 2) + + assert.Equal(t, "service.module.grpc.session", app.Services[0].Module()) + assert.Equal(t, "service.module.grpc.trustmapping", app.Services[1].Module()) + + svc0 := &fakeServiceModule{} + require.NoError(t, app.Services[0].UnmarshalExtension(svc0)) + assert.Equal(t, "trust.module.oidc", svc0.Trust) + assert.True(t, svc0.AllowHttpScheme) + + svc1 := &fakeServiceModule{} + require.NoError(t, app.Services[1].UnmarshalExtension(svc1)) + assert.Equal(t, "trust.module.alt", svc1.Trust) +} diff --git a/internal/config/load.go b/internal/config/load.go index ef68e0f8..0d7692e9 100644 --- a/internal/config/load.go +++ b/internal/config/load.go @@ -101,6 +101,19 @@ func setKoanf(v reflect.Value, ko *koanf.Koanf) { val = val.Addr() } + // Slice-of-pointer fields need per-entry sub-koanfs that koanf can + // only construct from the parent path via Slices(); a positional + // Cut() does not work for array indices. + indirect := reflect.Indirect(val) + if indirect.Kind() == reflect.Slice && indirect.Type().Elem().Kind() == reflect.Pointer { + subs := ko.Slices(name) + n := min(indirect.Len(), len(subs)) + for i := range n { + setKoanf(indirect.Index(i), subs[i]) + } + continue + } + setKoanf(val, ko.Cut(name)) } case reflect.Map: diff --git a/internal/grpc/options.go b/internal/grpc/options.go deleted file mode 100644 index 6b5909ee..00000000 --- a/internal/grpc/options.go +++ /dev/null @@ -1,23 +0,0 @@ -package grpc - -import "github.com/openkcm/session-manager/internal/credentials" - -type SessionServerOption func(*SessionServer) - -func WithQueryParametersIntrospect(params []string) SessionServerOption { - return func(s *SessionServer) { - s.queryParametersIntrospect = params - } -} - -func WithAllowHttpScheme(allow bool) SessionServerOption { - return func(s *SessionServer) { - s.allowHttpScheme = allow - } -} - -func WithTransportCredentials(b credentials.Builder) SessionServerOption { - return func(s *SessionServer) { - s.newCreds = b - } -} diff --git a/internal/sessionwiring/sessionwiring.go b/internal/sessionwiring/sessionwiring.go index 94d036c7..935f92af 100644 --- a/internal/sessionwiring/sessionwiring.go +++ b/internal/sessionwiring/sessionwiring.go @@ -1,17 +1,14 @@ -// Package sessionwiring centralises the construction of long-lived -// session-manager dependencies (valkey client, credentials builder, the -// session.Manager itself) so callers in cmd/, internal/business, and apps -// configured via the apps: lifecycle loop can build them identically. +// Package sessionwiring centralises the construction of the long-lived +// session.Manager that the HTTP API server and the housekeeper subcommand +// share. The Valkey-backed session.Repository and OAuth2 credentials.Builder +// it needs are no longer built here; both come from the module registry, +// loaded by business.Main (or the housekeeper subcommand) before this is +// invoked. package sessionwiring import ( "context" - "errors" "fmt" - "log/slog" - - "github.com/openkcm/common-sdk/pkg/commoncfg" - "github.com/valkey-io/valkey-go" otlpaudit "github.com/openkcm/common-sdk/pkg/otlp/audit" @@ -19,29 +16,29 @@ import ( "github.com/openkcm/session-manager/internal/config" "github.com/openkcm/session-manager/internal/credentials" "github.com/openkcm/session-manager/internal/session" - sessionvalkey "github.com/openkcm/session-manager/internal/session/valkey" ) -const ( - insecure = "insecure" - mtls = "mtls" - clientSecret = "client_secret" // Alias for clientSecretPost. - clientSecretPost = "client_secret_post" -) +// credentialsBuilder is the interface satisfied by a credentials module +// (e.g. credentials.module.oauth2). Defined locally so this package does not +// need to import the credentials module. +type credentialsBuilder interface { + Builder() credentials.Builder +} // InitSessionManager builds a session.Manager from the supplied config and -// trust module. The returned closeFn must be invoked once the manager is no -// longer in use to release the underlying valkey client. -func InitSessionManager(ctx context.Context, cfg *config.Config, trust sessionmanager.Trust) (_ *session.Manager, closeFn func(), _ error) { - valkeyClient, err := ValkeyClient(cfg) +// trust module, using session repository and credential modules already loaded +// in ctx. The returned closeFn is a no-op kept for API compatibility — the +// underlying valkey client is owned by the sessionstore module and closed by +// the framework's reverse-load-order shutdown. +func InitSessionManager(ctx *sessionmanager.Context, cfg *config.Config, trust sessionmanager.Trust) (_ *session.Manager, closeFn func(), _ error) { + repo, err := SessionRepository(ctx, cfg) if err != nil { - return nil, nil, fmt.Errorf("failed to create valkey client: %w", err) + return nil, nil, fmt.Errorf("getting session repository: %w", err) } - sessionRepo := sessionvalkey.NewRepository(valkeyClient, cfg.ValKey.Prefix) - credsBuilder, err := CredsBuilder(cfg) + credsBuilder, err := CredsBuilder(ctx, cfg) if err != nil { - return nil, nil, fmt.Errorf("failed to load http client: %w", err) + return nil, nil, fmt.Errorf("getting credentials builder: %w", err) } auditLogger, err := otlpaudit.NewLogger(&cfg.Audit) @@ -52,7 +49,7 @@ func InitSessionManager(ctx context.Context, cfg *config.Config, trust sessionma sessManager, err := session.NewManager(ctx, &cfg.SessionManager, trust, - sessionRepo, + repo, auditLogger, session.WithTransportCredentials(credsBuilder), ) @@ -60,71 +57,37 @@ func InitSessionManager(ctx context.Context, cfg *config.Config, trust sessionma return nil, nil, fmt.Errorf("failed to create session manager: %w", err) } - return sessManager, valkeyClient.Close, nil + return sessManager, func() {}, nil } -// ValkeyClient creates a valkey client from the valkey-related fields on cfg. -func ValkeyClient(cfg *config.Config) (valkey.Client, error) { - valkeyHost, err := commoncfg.LoadValueFromSourceRef(cfg.ValKey.Host) +// SessionRepository resolves the session repository module loaded under the +// ID configured in cfg.ValKey.Module() and returns its session.Repository. +func SessionRepository(ctx *sessionmanager.Context, cfg *config.Config) (session.Repository, error) { + mod, err := ctx.GetModule(cfg.ValKey.Module()) if err != nil { - return nil, fmt.Errorf("failed to load valkey host: %w", err) + return nil, fmt.Errorf("getting session-store module %q: %w", cfg.ValKey.Module(), err) } - - valkeyUsername, err := commoncfg.LoadValueFromSourceRef(cfg.ValKey.User) - if err != nil { - return nil, fmt.Errorf("failed to load valkey username: %w", err) + repo, ok := mod.(session.Repository) + if !ok { + return nil, fmt.Errorf("module %q does not implement session.Repository", cfg.ValKey.Module()) } + return repo, nil +} - valkeyPassword, err := commoncfg.LoadValueFromSourceRef(cfg.ValKey.Password) +// CredsBuilder resolves the credentials module loaded under the ID configured +// in cfg.Credentials.Module() and returns its credentials.Builder. +func CredsBuilder(ctx *sessionmanager.Context, cfg *config.Config) (credentials.Builder, error) { + mod, err := ctx.GetModule(cfg.Credentials.Module()) if err != nil { - return nil, fmt.Errorf("failed to load valkey password: %w", err) - } - - valkeyOpts := valkey.ClientOption{ - InitAddress: []string{string(valkeyHost)}, - Username: string(valkeyUsername), - Password: string(valkeyPassword), + return nil, fmt.Errorf("getting credentials module %q: %w", cfg.Credentials.Module(), err) } - - if cfg.ValKey.SecretRef.Type == commoncfg.MTLSSecretType { - tlsConfig, err := commoncfg.LoadMTLSConfig(&cfg.ValKey.SecretRef.MTLS) - if err != nil { - return nil, fmt.Errorf("failed to load valkey mTLS config from secret ref: %w", err) - } - valkeyOpts.TLSConfig = tlsConfig - } - - valkeyClient, err := valkey.NewClient(valkeyOpts) - if err != nil { - return nil, fmt.Errorf("failed to create a new valkey client: %w", err) + cb, ok := mod.(credentialsBuilder) + if !ok { + return nil, fmt.Errorf("module %q does not expose Builder()", cfg.Credentials.Module()) } - return valkeyClient, nil + return cb.Builder(), nil } -// CredsBuilder returns a credentials.Builder that matches the configured -// client-auth strategy. -func CredsBuilder(cfg *config.Config) (credentials.Builder, error) { - switch cfg.SessionManager.ClientAuth.Type { - case mtls: - tlsConfig, err := commoncfg.LoadMTLSConfig(cfg.SessionManager.ClientAuth.MTLS) - if err != nil { - return nil, fmt.Errorf("failed to load mTLS config: %w", err) - } - - return func(clientID string) credentials.TransportCredentials { return credentials.NewTLS(clientID, tlsConfig) }, nil - case clientSecretPost, clientSecret: - secret, err := commoncfg.LoadValueFromSourceRef(cfg.SessionManager.ClientAuth.ClientSecret) - if err != nil { - return nil, fmt.Errorf("failed to load client secret: %w", err) - } - - return func(clientID string) credentials.TransportCredentials { - return credentials.NewClientSecretPost(clientID, string(secret)) - }, nil - case insecure: - slog.Warn("insecure credentials are used. Do not use this in production") - return func(clientID string) credentials.TransportCredentials { return credentials.NewInsecure(clientID) }, nil - default: - return nil, errors.New("unknown Client Auth type") - } -} +// Reference to context.Context to keep imports stable for callers using +// (ctx context.Context) signatures. +var _ context.Context = (*sessionmanager.Context)(nil) diff --git a/internal/sessionwiring/sessionwiring_test.go b/internal/sessionwiring/sessionwiring_test.go deleted file mode 100644 index 794f75b3..00000000 --- a/internal/sessionwiring/sessionwiring_test.go +++ /dev/null @@ -1,147 +0,0 @@ -package sessionwiring - -import ( - "testing" - - "github.com/openkcm/common-sdk/pkg/commoncfg" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/openkcm/session-manager/internal/config" - "github.com/openkcm/session-manager/internal/credentials" -) - -func TestCredsBuilder_MTLS(t *testing.T) { - cfg := &config.Config{ - SessionManager: config.SessionManager{ - ClientAuth: config.ClientAuth{ - Type: "mtls", - ClientID: "test-client", - MTLS: &commoncfg.MTLS{ - Cert: commoncfg.SourceRef{File: commoncfg.CredentialFile{Path: "/nonexistent/cert.pem"}}, - CertKey: commoncfg.SourceRef{File: commoncfg.CredentialFile{Path: "/nonexistent/key.pem"}}, - }, - }, - }, - } - - _, err := CredsBuilder(cfg) - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to load mTLS config") -} - -func TestCredsBuilder_ClientSecret(t *testing.T) { - cfg := &config.Config{ - SessionManager: config.SessionManager{ - ClientAuth: config.ClientAuth{ - Type: "client_secret", - ClientID: "test-client", - ClientSecret: commoncfg.SourceRef{Source: "embedded", Value: "test-secret"}, - }, - }, - } - - builder, err := CredsBuilder(cfg) - require.NoError(t, err) - require.NotNil(t, builder) - - creds := builder(cfg.SessionManager.ClientAuth.ClientID) - clientSecretCreds, ok := creds.(*credentials.ClientSecretPost) - require.True(t, ok) - - assert.Equal(t, "test-client", clientSecretCreds.ClientID) - assert.Equal(t, "test-secret", clientSecretCreds.ClientSecret) -} - -func TestCredsBuilder_Insecure(t *testing.T) { - cfg := &config.Config{ - SessionManager: config.SessionManager{ - ClientAuth: config.ClientAuth{ - Type: "insecure", - ClientID: "test-client", - }, - }, - } - - builder, err := CredsBuilder(cfg) - require.NoError(t, err) - assert.IsType(t, &credentials.Insecure{}, builder("")) -} - -func TestCredsBuilder_UnknownType(t *testing.T) { - cfg := &config.Config{ - SessionManager: config.SessionManager{ - ClientAuth: config.ClientAuth{ - Type: "unknown", - ClientID: "test-client", - }, - }, - } - - _, err := CredsBuilder(cfg) - assert.Error(t, err) - assert.Contains(t, err.Error(), "unknown Client Auth type") -} - -func TestValkeyClient_InvalidHostRef(t *testing.T) { - cfg := &config.Config{ - ValKey: config.ValKey{ - Host: commoncfg.SourceRef{Source: "file", File: commoncfg.CredentialFile{Path: "/nonexistent/file"}}, - User: commoncfg.SourceRef{Source: "embedded", Value: "user"}, - Password: commoncfg.SourceRef{Source: "embedded", Value: "pass"}, - }, - } - - _, err := ValkeyClient(cfg) - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to load valkey host") -} - -func TestValkeyClient_InvalidUserRef(t *testing.T) { - cfg := &config.Config{ - ValKey: config.ValKey{ - Host: commoncfg.SourceRef{Source: "embedded", Value: "localhost:6379"}, - User: commoncfg.SourceRef{Source: "file", File: commoncfg.CredentialFile{Path: "/nonexistent/file"}}, - Password: commoncfg.SourceRef{Source: "embedded", Value: "pass"}, - }, - } - - _, err := ValkeyClient(cfg) - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to load valkey username") -} - -func TestValkeyClient_InvalidPasswordRef(t *testing.T) { - cfg := &config.Config{ - ValKey: config.ValKey{ - Host: commoncfg.SourceRef{Source: "embedded", Value: "localhost:6379"}, - User: commoncfg.SourceRef{Source: "embedded", Value: "user"}, - Password: commoncfg.SourceRef{Source: "file", File: commoncfg.CredentialFile{Path: "/nonexistent/file"}}, - }, - } - - _, err := ValkeyClient(cfg) - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to load valkey password") -} - -func TestValkeyClient_WithMTLS(t *testing.T) { - cfg := &config.Config{ - ValKey: config.ValKey{ - Host: commoncfg.SourceRef{Source: "embedded", Value: "localhost:6379"}, - User: commoncfg.SourceRef{Source: "embedded", Value: "user"}, - Password: commoncfg.SourceRef{Source: "embedded", Value: "pass"}, - SecretRef: commoncfg.SecretRef{ - Type: commoncfg.MTLSSecretType, - MTLS: commoncfg.MTLS{ - Cert: commoncfg.SourceRef{Source: "file", File: commoncfg.CredentialFile{Path: "/nonexistent/cert.pem"}}, - CertKey: commoncfg.SourceRef{Source: "file", File: commoncfg.CredentialFile{Path: "/nonexistent/key.pem"}}, - }, - }, - }, - } - - _, err := ValkeyClient(cfg) - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to load valkey mTLS config from secret ref") -} diff --git a/modules/app/grpcserver/module.go b/modules/app/grpcserver/module.go new file mode 100644 index 00000000..613f974f --- /dev/null +++ b/modules/app/grpcserver/module.go @@ -0,0 +1,160 @@ +// Package grpcserver provides the app.module.grpcserver app module: a +// long-running gRPC server that hosts service modules registered through its +// services: config block. The Service interface is exported here so that +// service modules can satisfy it without importing google.golang.org/grpc +// from the top-level sessionmanager package. +package grpcserver + +import ( + "context" + "errors" + "fmt" + "net" + "sync" + + "github.com/openkcm/common-sdk/pkg/commongrpc" + "google.golang.org/grpc" + + slogctx "github.com/veqryn/slog-context" + + sessionmanager "github.com/openkcm/session-manager" + "github.com/openkcm/session-manager/internal/config" +) + +const moduleID = "app.module.grpcserver" + +// Service is the interface that every gRPC service module loaded under an +// app.module.grpcserver entry must satisfy. The grpc app calls Register on +// each child service, in declaration order, before invoking Serve. +type Service interface { + Register(s grpc.ServiceRegistrar) +} + +func init() { + sessionmanager.RegisterModule(new(Module)) +} + +func newModule() sessionmanager.Module { + return new(Module) +} + +// Module is the gRPC server app. Its lifecycle: +// - Provision: load every service module listed under services: from the +// config; type-assert each against Service; collect them in declaration +// order. +// - Start: build the underlying *grpc.Server via commongrpc.NewServer using +// the top-level cfg.GRPC block, register every collected service onto it, +// listen on cfg.GRPC.Address, and begin Serve in a goroutine. +// - Stop: GracefulStop bounded by cfg.GRPC.ShutdownTimeout; if the timeout +// fires, fall back to a forceful Stop. +type Module struct { + Mod string `yaml:"module"` + Services []*config.ServiceCfg `yaml:"services"` + + ctx context.Context //nolint:containedctx + cfg *config.Config + services []Service + + server *grpc.Server + listener net.Listener + + stopOnce sync.Once + stopErr error + + serveDone chan struct{} +} + +func (m *Module) Module() sessionmanager.ModuleInfo { + return sessionmanager.ModuleInfo{ + ID: moduleID, + New: newModule, + } +} + +func (m *Module) Provision(ctx *sessionmanager.Context) error { + cfg, ok := config.FromContext(ctx) + if !ok { + return errors.New("config not found in context") + } + m.ctx = ctx + m.cfg = cfg + + if len(m.Services) == 0 { + return errors.New("app.module.grpcserver requires at least one service under services") + } + + m.services = make([]Service, 0, len(m.Services)) + for i, svcCfg := range m.Services { + mod, err := ctx.LoadModule(svcCfg) + if err != nil { + return fmt.Errorf("loading service[%d] %q: %w", i, svcCfg.Module(), err) + } + svc, ok := mod.(Service) + if !ok { + return fmt.Errorf("service[%d] module %q does not implement grpcserver.Service", i, svcCfg.Module()) + } + m.services = append(m.services, svc) + } + + return nil +} + +func (m *Module) Start() error { + m.server = commongrpc.NewServer(m.ctx, &m.cfg.GRPC.GRPCServer) + + for _, svc := range m.services { + svc.Register(m.server) + } + + listener, err := new(net.ListenConfig).Listen(m.ctx, "tcp", m.cfg.GRPC.Address) + if err != nil { + return fmt.Errorf("listening on %s: %w", m.cfg.GRPC.Address, err) + } + m.listener = listener + + slogctx.Info(m.ctx, "Starting a gRPC listener", "address", listener.Addr().String()) + + m.serveDone = make(chan struct{}) + go func() { + defer close(m.serveDone) + if err := m.server.Serve(listener); err != nil { + slogctx.Error(m.ctx, "gRPC server stopped with error", "error", err) + } + }() + + return nil +} + +func (m *Module) Stop() error { + m.stopOnce.Do(func() { + if m.server == nil { + return + } + + gracefulDone := make(chan struct{}) + go func() { + m.server.GracefulStop() + close(gracefulDone) + }() + + timeout := m.cfg.GRPC.ShutdownTimeout + if timeout <= 0 { + <-gracefulDone + } else { + tctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + select { + case <-gracefulDone: + case <-tctx.Done(): + slogctx.Warn(m.ctx, "gRPC graceful stop exceeded timeout; forcing Stop", "timeout", timeout) + m.server.Stop() + <-gracefulDone + } + } + + if m.serveDone != nil { + <-m.serveDone + } + }) + return m.stopErr +} diff --git a/modules/app/grpcserver/module_test.go b/modules/app/grpcserver/module_test.go new file mode 100644 index 00000000..be8995e9 --- /dev/null +++ b/modules/app/grpcserver/module_test.go @@ -0,0 +1,131 @@ +package grpcserver_test + +import ( + "net" + "sync" + "testing" + "time" + + "github.com/openkcm/common-sdk/pkg/commoncfg" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + + sessionmanager "github.com/openkcm/session-manager" + "github.com/openkcm/session-manager/internal/config" + "github.com/openkcm/session-manager/modules/app/grpcserver" +) + +// fakeService satisfies grpcserver.Service. It records that Register was +// called on a non-nil ServiceRegistrar. The test then exercises Start/Stop +// without registering any real proto services. +type fakeService struct { + mu sync.Mutex + registered bool +} + +func (f *fakeService) Module() sessionmanager.ModuleInfo { + return sessionmanager.ModuleInfo{ + ID: "test.fake.service", + New: func() sessionmanager.Module { return f }, + } +} + +func (f *fakeService) Register(_ grpc.ServiceRegistrar) { + f.mu.Lock() + defer f.mu.Unlock() + f.registered = true +} + +// notService is a Module that does NOT implement grpcserver.Service. +type notService struct{ id string } + +func (n *notService) Module() sessionmanager.ModuleInfo { + return sessionmanager.ModuleInfo{ + ID: n.id, + New: func() sessionmanager.Module { return n }, + } +} + +func newCtx(t *testing.T) (*sessionmanager.Context, *config.Config) { + t.Helper() + cfg := &config.Config{} + cfg.GRPC = config.GRPCServer{ + GRPCServer: commoncfg.GRPCServer{Address: "127.0.0.1:0"}, + ShutdownTimeout: 2 * time.Second, + } + // Find a free port the deterministic way. + l, err := new(net.ListenConfig).Listen(t.Context(), "tcp", "127.0.0.1:0") + require.NoError(t, err) + addr := l.Addr().String() + require.NoError(t, l.Close()) + cfg.GRPC.Address = addr + + ctx, cancel := sessionmanager.NewContext(t.Context()) + t.Cleanup(func() { cancel(nil) }) + ctx = config.WithContext(ctx, cfg) + return ctx, cfg +} + +func TestModule_StartRegistersServicesAndStops(t *testing.T) { + ctx, _ := newCtx(t) + + fakeID := "test.fake.service." + t.Name() + svc := &fakeService{} + sessionmanager.RegisterModule(&customMod{id: fakeID, mod: svc}) + + m := &grpcserver.Module{ + Services: []*config.ServiceCfg{newSvcCfg(fakeID)}, + } + require.NoError(t, m.Provision(ctx)) + require.NoError(t, m.Start()) + + assert.True(t, svc.registered, "service must be registered before Serve") + + require.NoError(t, m.Stop()) + + // Stop is idempotent. + require.NoError(t, m.Stop()) +} + +func TestModule_NonServiceUnderServicesIsRejected(t *testing.T) { + ctx, _ := newCtx(t) + + id := "test.notservice." + t.Name() + sessionmanager.RegisterModule(&customMod{id: id, mod: ¬Service{id: id}}) + + m := &grpcserver.Module{ + Services: []*config.ServiceCfg{newSvcCfg(id)}, + } + err := m.Provision(ctx) + require.Error(t, err) + assert.Contains(t, err.Error(), "does not implement") +} + +func TestModule_EmptyServicesRejected(t *testing.T) { + ctx, _ := newCtx(t) + m := &grpcserver.Module{} + err := m.Provision(ctx) + require.Error(t, err) + assert.Contains(t, err.Error(), "at least one service") +} + +// customMod is a tiny ExtensionConfig + module registration helper. +type customMod struct { + id string + mod sessionmanager.Module +} + +func (c *customMod) Module() sessionmanager.ModuleInfo { + return sessionmanager.ModuleInfo{ + ID: c.id, + New: func() sessionmanager.Module { return c.mod }, + } +} + +// newSvcCfg builds a config.ServiceCfg pointing at the given module ID. +// We bypass koanf entirely; the module just needs Module() to return the ID. +func newSvcCfg(modID string) *config.ServiceCfg { + c := &config.ServiceCfg{Mod: modID} + return c +} diff --git a/modules/credentials/oauth2/module.go b/modules/credentials/oauth2/module.go new file mode 100644 index 00000000..5a95747e --- /dev/null +++ b/modules/credentials/oauth2/module.go @@ -0,0 +1,90 @@ +// Package oauth2 provides the credentials.module.oauth2 module: a +// credentials.Builder that produces transport credentials for OAuth2/OIDC +// client authentication. Source data lives under sessionManager.clientAuth in +// the top-level config and is read via config.FromContext. +package oauth2 + +import ( + "errors" + "fmt" + "log/slog" + + "github.com/openkcm/common-sdk/pkg/commoncfg" + + sessionmanager "github.com/openkcm/session-manager" + "github.com/openkcm/session-manager/internal/config" + "github.com/openkcm/session-manager/internal/credentials" +) + +const moduleID = "credentials.module.oauth2" + +const ( + authMTLS = "mtls" + authClientSecret = "client_secret" + authClientSecretPost = "client_secret_post" + authInsecure = "insecure" +) + +func init() { + sessionmanager.RegisterModule(new(Module)) +} + +func newModule() sessionmanager.Module { + return new(Module) +} + +// Module is the credentials.module.oauth2 module. It exposes a +// credentials.Builder constructed from sessionManager.clientAuth. +type Module struct { + Mod string `yaml:"module"` + + builder credentials.Builder +} + +func (m *Module) Module() sessionmanager.ModuleInfo { + return sessionmanager.ModuleInfo{ + ID: moduleID, + New: newModule, + } +} + +func (m *Module) Provision(ctx *sessionmanager.Context) error { + cfg, ok := config.FromContext(ctx) + if !ok { + return errors.New("config not found in context") + } + + clientAuth := cfg.SessionManager.ClientAuth + switch clientAuth.Type { + case authMTLS: + tlsConfig, err := commoncfg.LoadMTLSConfig(clientAuth.MTLS) + if err != nil { + return fmt.Errorf("loading mTLS config: %w", err) + } + m.builder = func(clientID string) credentials.TransportCredentials { + return credentials.NewTLS(clientID, tlsConfig) + } + case authClientSecret, authClientSecretPost: + secret, err := commoncfg.LoadValueFromSourceRef(clientAuth.ClientSecret) + if err != nil { + return fmt.Errorf("loading client secret: %w", err) + } + m.builder = func(clientID string) credentials.TransportCredentials { + return credentials.NewClientSecretPost(clientID, string(secret)) + } + case authInsecure: + slog.Warn("insecure credentials are used. Do not use this in production") + m.builder = func(clientID string) credentials.TransportCredentials { + return credentials.NewInsecure(clientID) + } + default: + return fmt.Errorf("unknown client auth type %q", clientAuth.Type) + } + + return nil +} + +// Builder returns the credentials.Builder produced during Provision. +func (m *Module) Builder() credentials.Builder { + return m.builder +} diff --git a/modules/credentials/oauth2/module_test.go b/modules/credentials/oauth2/module_test.go new file mode 100644 index 00000000..d811526a --- /dev/null +++ b/modules/credentials/oauth2/module_test.go @@ -0,0 +1,81 @@ +package oauth2_test + +import ( + "testing" + + "github.com/openkcm/common-sdk/pkg/commoncfg" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + sessionmanager "github.com/openkcm/session-manager" + "github.com/openkcm/session-manager/internal/config" + credentialsoauth2 "github.com/openkcm/session-manager/modules/credentials/oauth2" +) + +func TestModule_RegistrationAndID(t *testing.T) { + info, err := sessionmanager.GetModule("credentials.module.oauth2") + require.NoError(t, err) + assert.Equal(t, "credentials.module.oauth2", info.ID) + + mod := info.New() + require.NotNil(t, mod) + _, ok := mod.(*credentialsoauth2.Module) + assert.True(t, ok, "New() must return *Module") +} + +func provisionWithAuth(t *testing.T, auth config.ClientAuth) (*credentialsoauth2.Module, error) { + t.Helper() + cfg := &config.Config{} + cfg.SessionManager.ClientAuth = auth + + ctx, cancel := sessionmanager.NewContext(t.Context()) + t.Cleanup(func() { cancel(nil) }) + ctx = config.WithContext(ctx, cfg) + + m := new(credentialsoauth2.Module) + return m, m.Provision(ctx) +} + +func TestModule_ProvisionInsecure(t *testing.T) { + m, err := provisionWithAuth(t, config.ClientAuth{Type: "insecure"}) + require.NoError(t, err) + require.NotNil(t, m.Builder()) + + creds := m.Builder()("client-id") + assert.NotNil(t, creds) +} + +func TestModule_ProvisionClientSecret(t *testing.T) { + m, err := provisionWithAuth(t, config.ClientAuth{ + Type: "client_secret", + ClientSecret: commoncfg.SourceRef{Source: "embedded", Value: "shh"}, + }) + require.NoError(t, err) + creds := m.Builder()("cid") + assert.NotNil(t, creds) +} + +func TestModule_ProvisionClientSecretPost(t *testing.T) { + m, err := provisionWithAuth(t, config.ClientAuth{ + Type: "client_secret_post", + ClientSecret: commoncfg.SourceRef{Source: "embedded", Value: "shh"}, + }) + require.NoError(t, err) + creds := m.Builder()("cid") + assert.NotNil(t, creds) +} + +func TestModule_ProvisionUnknownTypeFails(t *testing.T) { + _, err := provisionWithAuth(t, config.ClientAuth{Type: "totally-bogus"}) + require.Error(t, err) + assert.Contains(t, err.Error(), "totally-bogus") +} + +func TestModule_ProvisionWithoutConfigFails(t *testing.T) { + ctx, cancel := sessionmanager.NewContext(t.Context()) + defer cancel(nil) + + m := new(credentialsoauth2.Module) + err := m.Provision(ctx) + require.Error(t, err) +} diff --git a/internal/grpc/import_test.go b/modules/grpc/session/import_test.go similarity index 94% rename from internal/grpc/import_test.go rename to modules/grpc/session/import_test.go index 6c5da46f..cfe5ca6f 100644 --- a/internal/grpc/import_test.go +++ b/modules/grpc/session/import_test.go @@ -1,4 +1,4 @@ -package grpc_test +package session_test import ( _ "unsafe" diff --git a/modules/grpc/session/module.go b/modules/grpc/session/module.go new file mode 100644 index 00000000..187abc01 --- /dev/null +++ b/modules/grpc/session/module.go @@ -0,0 +1,113 @@ +// Package session provides the service.module.grpc.session module: a gRPC +// service module that registers the kms.api.cmk.sessionmanager.session.v1.Service +// proto onto a grpc.ServiceRegistrar supplied by app.module.grpcserver. +package session + +import ( + "errors" + "fmt" + + "google.golang.org/grpc" + + sessionv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/sessionmanager/session/v1" + + sessionmanager "github.com/openkcm/session-manager" + "github.com/openkcm/session-manager/internal/config" + "github.com/openkcm/session-manager/internal/credentials" + internalsession "github.com/openkcm/session-manager/internal/session" +) + +const moduleID = "service.module.grpc.session" + +func init() { + sessionmanager.RegisterModule(new(Module)) +} + +func newModule() sessionmanager.Module { + return new(Module) +} + +// credentialsBuilder is the interface satisfied by a credentials module +// (e.g. credentials.module.oauth2). +type credentialsBuilder interface { + Builder() credentials.Builder +} + +// Module is the service.module.grpc.session module. It wires its three +// dependencies (trust, session store, credentials) by ID via ctx.GetModule +// and owns a *Server that implements the proto. +type Module struct { + Mod string `yaml:"module"` + Trust string `yaml:"trust" default:"trust.module.oidc"` + SessionStore string `yaml:"sessionStore" default:"sessionstore.module.valkey"` + Credentials string `yaml:"credentials" default:"credentials.module.oauth2"` + + AllowHttpScheme bool `yaml:"allowHttpScheme"` + QueryParametersIntrospect []string `yaml:"queryParametersIntrospect"` + + server *Server +} + +func (m *Module) Module() sessionmanager.ModuleInfo { + return sessionmanager.ModuleInfo{ + ID: moduleID, + New: newModule, + } +} + +func (m *Module) Provision(ctx *sessionmanager.Context) error { + cfg, ok := config.FromContext(ctx) + if !ok { + return errors.New("config not found in context") + } + + trustMod, err := ctx.GetModule(m.Trust) + if err != nil { + return fmt.Errorf("getting trust module %q: %w", m.Trust, err) + } + trust, ok := trustMod.(sessionmanager.Trust) + if !ok { + return fmt.Errorf("module %q does not implement sessionmanager.Trust", m.Trust) + } + + storeMod, err := ctx.GetModule(m.SessionStore) + if err != nil { + return fmt.Errorf("getting session-store module %q: %w", m.SessionStore, err) + } + repo, ok := storeMod.(internalsession.Repository) + if !ok { + return fmt.Errorf("module %q does not implement session.Repository", m.SessionStore) + } + + credsMod, err := ctx.GetModule(m.Credentials) + if err != nil { + return fmt.Errorf("getting credentials module %q: %w", m.Credentials, err) + } + creds, ok := credsMod.(credentialsBuilder) + if !ok { + return fmt.Errorf("module %q does not expose Builder()", m.Credentials) + } + + opts := []Option{ + WithTransportCredentials(creds.Builder()), + WithAllowHttpScheme(m.AllowHttpScheme), + } + if m.QueryParametersIntrospect != nil { + opts = append(opts, WithQueryParametersIntrospect(m.QueryParametersIntrospect)) + } + + m.server = NewServer( + ctx, + repo, + trust, + cfg.SessionManager.IdleSessionTimeout, + cfg.SessionManager.ClientAuth.ClientID, + opts..., + ) + + return nil +} + +func (m *Module) Register(s grpc.ServiceRegistrar) { + sessionv1.RegisterServiceServer(s, m.server) +} diff --git a/modules/grpc/session/options.go b/modules/grpc/session/options.go new file mode 100644 index 00000000..26731c2b --- /dev/null +++ b/modules/grpc/session/options.go @@ -0,0 +1,23 @@ +package session + +import "github.com/openkcm/session-manager/internal/credentials" + +type Option func(*Server) + +func WithQueryParametersIntrospect(params []string) Option { + return func(s *Server) { + s.queryParametersIntrospect = params + } +} + +func WithAllowHttpScheme(allow bool) Option { + return func(s *Server) { + s.allowHttpScheme = allow + } +} + +func WithTransportCredentials(b credentials.Builder) Option { + return func(s *Server) { + s.newCreds = b + } +} diff --git a/internal/grpc/session.go b/modules/grpc/session/server.go similarity index 86% rename from internal/grpc/session.go rename to modules/grpc/session/server.go index a71ad1e9..312eb26d 100644 --- a/internal/grpc/session.go +++ b/modules/grpc/session/server.go @@ -1,4 +1,4 @@ -package grpc +package session import ( "context" @@ -25,17 +25,17 @@ import ( sessionmanager "github.com/openkcm/session-manager" "github.com/openkcm/session-manager/internal/credentials" "github.com/openkcm/session-manager/internal/debugtools" - "github.com/openkcm/session-manager/internal/session" + internalsession "github.com/openkcm/session-manager/internal/session" ) const defaultIntrospectionCacheExpiration = 30 * time.Second var debugSettingSMDumpTransport = debugtools.NewSetting("smdumptransport") -type SessionServer struct { +type Server struct { sessionv1.UnimplementedServiceServer - sessionRepo session.Repository + sessionRepo internalsession.Repository trust sessionmanager.Trust newCreds credentials.Builder @@ -48,15 +48,15 @@ type SessionServer struct { introspectionCache *ttlcache.Cache[string, oidc.Introspection] } -func NewSessionServer( +func NewServer( ctx context.Context, - sessionRepo session.Repository, + sessionRepo internalsession.Repository, trust sessionmanager.Trust, idleSessionTimeout time.Duration, clientID string, - opts ...SessionServerOption, -) *SessionServer { - s := &SessionServer{ + opts ...Option, +) *Server { + s := &Server{ sessionRepo: sessionRepo, trust: trust, idleSessionTimeout: idleSessionTimeout, @@ -79,7 +79,7 @@ func NewSessionServer( return s } -func (s *SessionServer) GetSession(ctx context.Context, req *sessionv1.GetSessionRequest) (*sessionv1.GetSessionResponse, error) { +func (s *Server) GetSession(ctx context.Context, req *sessionv1.GetSessionRequest) (*sessionv1.GetSessionResponse, error) { tracer := otel.GetTracerProvider() ctx, span := tracer.Tracer("").Start(ctx, "get_session") defer span.End() @@ -194,7 +194,7 @@ func (s *SessionServer) GetSession(ctx context.Context, req *sessionv1.GetSessio return response, nil } -func (s *SessionServer) GetOIDCProvider(ctx context.Context, req *sessionv1.GetOIDCProviderRequest) (*sessionv1.GetOIDCProviderResponse, error) { +func (s *Server) GetOIDCProvider(ctx context.Context, req *sessionv1.GetOIDCProviderRequest) (*sessionv1.GetOIDCProviderResponse, error) { tracer := otel.GetTracerProvider() ctx, span := tracer.Tracer("").Start(ctx, "get_oidc_provider") defer span.End() @@ -218,7 +218,23 @@ func (s *SessionServer) GetOIDCProvider(ctx context.Context, req *sessionv1.GetO }, nil } -func (s *SessionServer) getClientID(oidcTrust *oidcv1.OIDC) string { +func (s *Server) GetTrust(ctx context.Context, req *sessionv1.GetTrustRequest) (*sessionv1.GetTrustResponse, error) { + tracer := otel.GetTracerProvider() + ctx, span := tracer.Tracer("").Start(ctx, "get_trust") + defer span.End() + + trust, err := s.trust.Get(ctx, req.GetTenantId()) + if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, "failed to get an oidc provider") + return nil, fmt.Errorf("getting odic provider: %w", err) + } + + span.SetStatus(codes.Ok, "") + return &sessionv1.GetTrustResponse{Trust: trust}, nil +} + +func (s *Server) getClientID(oidcTrust *oidcv1.OIDC) string { if clientID := oidcTrust.GetClientId(); clientID != "" { return clientID } @@ -226,7 +242,7 @@ func (s *SessionServer) getClientID(oidcTrust *oidcv1.OIDC) string { return s.clientID } -func (s *SessionServer) httpClient(oidcTrust *oidcv1.OIDC) *http.Client { +func (s *Server) httpClient(oidcTrust *oidcv1.OIDC) *http.Client { creds := s.newCreds(s.getClientID(oidcTrust)) transport := creds.Transport() if debugSettingSMDumpTransport.Value() == "1" { @@ -238,7 +254,7 @@ func (s *SessionServer) httpClient(oidcTrust *oidcv1.OIDC) *http.Client { } } -func (s *SessionServer) introspectToken(ctx context.Context, token string, oidcTrust *oidcv1.OIDC) (oidc.Introspection, error) { +func (s *Server) introspectToken(ctx context.Context, token string, oidcTrust *oidcv1.OIDC) (oidc.Introspection, error) { // first check the cache for a recent introspection result for this token hashedSuffix := sha256.Sum256([]byte(token)) cacheKey := base64.RawURLEncoding.EncodeToString(hashedSuffix[:]) diff --git a/internal/grpc/session_test.go b/modules/grpc/session/server_test.go similarity index 89% rename from internal/grpc/session_test.go rename to modules/grpc/session/server_test.go index 04885d28..de649b26 100644 --- a/internal/grpc/session_test.go +++ b/modules/grpc/session/server_test.go @@ -1,4 +1,4 @@ -package grpc_test +package session_test import ( "encoding/json" @@ -19,9 +19,9 @@ import ( oidcv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/oidc/v1" trustv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/v1" - "github.com/openkcm/session-manager/internal/grpc" - "github.com/openkcm/session-manager/internal/session" + internalsession "github.com/openkcm/session-manager/internal/session" sessionmock "github.com/openkcm/session-manager/internal/session/mock" + "github.com/openkcm/session-manager/modules/grpc/session" mocktrust "github.com/openkcm/session-manager/modules/oidctrust/mocks" ) @@ -33,7 +33,7 @@ func TestNewSessionServer(t *testing.T) { trust := newTrust(trustRepo) idleSessionTimeout := 90 * time.Minute - server := grpc.NewSessionServer(ctx, sessionRepo, trust, idleSessionTimeout, "") + server := session.NewServer(ctx, sessionRepo, trust, idleSessionTimeout, "") assert.NotNil(t, server) }) @@ -44,12 +44,12 @@ func TestNewSessionServer(t *testing.T) { trust := newTrust(trustRepo) idleSessionTimeout := 90 * time.Minute - server := grpc.NewSessionServer(ctx, + server := session.NewServer(ctx, sessionRepo, trust, idleSessionTimeout, "", - grpc.WithQueryParametersIntrospect([]string{"param1", "param2"}), + session.WithQueryParametersIntrospect([]string{"param1", "param2"}), ) assert.NotNil(t, server) @@ -61,7 +61,7 @@ func TestNewSessionServer(t *testing.T) { trust := newTrust(trustRepo) idleSessionTimeout := 90 * time.Minute - server := grpc.NewSessionServer(ctx, + server := session.NewServer(ctx, sessionRepo, trust, idleSessionTimeout, @@ -96,13 +96,13 @@ func TestGetSession(t *testing.T) { })) defer testServer.Close() - sess := session.Session{ + sess := internalsession.Session{ ID: "session-123", TenantID: "tenant-123", Fingerprint: "fingerprint-123", Issuer: testServer.URL, AccessToken: "access-token-123", - Claims: session.Claims{ + Claims: internalsession.Claims{ Subject: "user-123", GivenName: "John", FamilyName: "Doe", @@ -128,8 +128,8 @@ func TestGetSession(t *testing.T) { trustRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(trustData)) trust := newTrust(trustRepo) - server := grpc.NewSessionServer(ctx, sessionRepo, trust, 90*time.Minute, "", - grpc.WithAllowHttpScheme(true), + server := session.NewServer(ctx, sessionRepo, trust, 90*time.Minute, "", + session.WithAllowHttpScheme(true), ) req := &sessionv1.GetSessionRequest{ SessionId: "session-123", @@ -172,13 +172,13 @@ func TestGetSession(t *testing.T) { })) defer testServer.Close() - sess := session.Session{ + sess := internalsession.Session{ ID: "session-groups", TenantID: "tenant-groups", Fingerprint: "fingerprint-groups", Issuer: testServer.URL, AccessToken: "access-token-groups", - Claims: session.Claims{ + Claims: internalsession.Claims{ Subject: "user-groups", Groups: []string{"session-group1", "session-group2"}, }, @@ -199,8 +199,8 @@ func TestGetSession(t *testing.T) { trustRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(trustData)) trust := newTrust(trustRepo) - server := grpc.NewSessionServer(ctx, sessionRepo, trust, 90*time.Minute, "", - grpc.WithAllowHttpScheme(true), + server := session.NewServer(ctx, sessionRepo, trust, 90*time.Minute, "", + session.WithAllowHttpScheme(true), ) req := &sessionv1.GetSessionRequest{ @@ -231,12 +231,12 @@ func TestGetSession(t *testing.T) { })) defer testServer.Close() - sess := session.Session{ + sess := internalsession.Session{ ID: "session-456", TenantID: "tenant-456", Fingerprint: "fingerprint-456", Issuer: testServer.URL, - Claims: session.Claims{ + Claims: internalsession.Claims{ Subject: "user-456", }, } @@ -256,8 +256,8 @@ func TestGetSession(t *testing.T) { trustRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(trustData)) trust := newTrust(trustRepo) - server := grpc.NewSessionServer(ctx, sessionRepo, trust, 90*time.Minute, "", - grpc.WithAllowHttpScheme(true), + server := session.NewServer(ctx, sessionRepo, trust, 90*time.Minute, "", + session.WithAllowHttpScheme(true), ) req := &sessionv1.GetSessionRequest{ @@ -280,7 +280,7 @@ func TestGetSession(t *testing.T) { ) trustRepo := mocktrust.NewInMemRepository() trust := newTrust(trustRepo) - server := grpc.NewSessionServer(ctx, sessionRepo, trust, 90*time.Minute, "") + server := session.NewServer(ctx, sessionRepo, trust, 90*time.Minute, "") req := &sessionv1.GetSessionRequest{ SessionId: "session-123", @@ -296,7 +296,7 @@ func TestGetSession(t *testing.T) { }) t.Run("invalid - session not active", func(t *testing.T) { - sess := session.Session{ + sess := internalsession.Session{ ID: "session-789", TenantID: "tenant-789", Fingerprint: "fingerprint-789", @@ -309,7 +309,7 @@ func TestGetSession(t *testing.T) { trustRepo := mocktrust.NewInMemRepository() trust := newTrust(trustRepo) - server := grpc.NewSessionServer(ctx, sessionRepo, trust, 90*time.Minute, "") + server := session.NewServer(ctx, sessionRepo, trust, 90*time.Minute, "") req := &sessionv1.GetSessionRequest{ SessionId: "session-789", @@ -329,14 +329,14 @@ func TestGetSession(t *testing.T) { sessionmock.WithLoadSessionError(errors.New("load error")), ) // Create a session and mark as active but LoadSession will error - sess := session.Session{ID: "session-fail"} + sess := internalsession.Session{ID: "session-fail"} err := sessionRepo.StoreSession(ctx, sess) assert.NoError(t, err) _ = sessionRepo.BumpActive(ctx, sess.ID, 1*time.Hour) trustRepo := mocktrust.NewInMemRepository() trust := newTrust(trustRepo) - server := grpc.NewSessionServer(ctx, sessionRepo, trust, 90*time.Minute, "") + server := session.NewServer(ctx, sessionRepo, trust, 90*time.Minute, "") req := &sessionv1.GetSessionRequest{ SessionId: "session-fail", @@ -352,7 +352,7 @@ func TestGetSession(t *testing.T) { }) t.Run("invalid - trust not found", func(t *testing.T) { - sess := session.Session{ + sess := internalsession.Session{ ID: "session-no-provider", TenantID: "tenant-no-provider", Fingerprint: "fingerprint-123", @@ -367,7 +367,7 @@ func TestGetSession(t *testing.T) { // No trust added to repo trustRepo := mocktrust.NewInMemRepository() trust := newTrust(trustRepo) - server := grpc.NewSessionServer(ctx, sessionRepo, trust, 90*time.Minute, "") + server := session.NewServer(ctx, sessionRepo, trust, 90*time.Minute, "") req := &sessionv1.GetSessionRequest{ SessionId: "session-no-provider", @@ -383,7 +383,7 @@ func TestGetSession(t *testing.T) { }) t.Run("invalid - trust is blocked", func(t *testing.T) { - sess := session.Session{ + sess := internalsession.Session{ ID: "session-blocked", TenantID: "tenant-blocked", Fingerprint: "fingerprint-123", @@ -405,7 +405,7 @@ func TestGetSession(t *testing.T) { trustRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(trustData)) trust := newTrust(trustRepo) - server := grpc.NewSessionServer(ctx, sessionRepo, trust, 90*time.Minute, "") + server := session.NewServer(ctx, sessionRepo, trust, 90*time.Minute, "") req := &sessionv1.GetSessionRequest{ SessionId: "session-blocked", @@ -433,7 +433,7 @@ func TestGetSession(t *testing.T) { }) t.Run("invalid - fingerprint mismatch", func(t *testing.T) { - sess := session.Session{ + sess := internalsession.Session{ ID: "session-fingerprint", TenantID: "tenant-fingerprint", Fingerprint: "correct-fingerprint", @@ -455,7 +455,7 @@ func TestGetSession(t *testing.T) { trustRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(trustData)) trust := newTrust(trustRepo) - server := grpc.NewSessionServer(ctx, sessionRepo, trust, 90*time.Minute, "") + server := session.NewServer(ctx, sessionRepo, trust, 90*time.Minute, "") req := &sessionv1.GetSessionRequest{ SessionId: "session-fingerprint", @@ -471,7 +471,7 @@ func TestGetSession(t *testing.T) { }) t.Run("invalid - tenant ID mismatch", func(t *testing.T) { - sess := session.Session{ + sess := internalsession.Session{ ID: "session-tenant", TenantID: "correct-tenant", Fingerprint: "fingerprint-123", @@ -493,7 +493,7 @@ func TestGetSession(t *testing.T) { trustRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(trustData)) trust := newTrust(trustRepo) - server := grpc.NewSessionServer(ctx, sessionRepo, trust, 90*time.Minute, "") + server := session.NewServer(ctx, sessionRepo, trust, 90*time.Minute, "") req := &sessionv1.GetSessionRequest{ SessionId: "session-tenant", @@ -509,7 +509,7 @@ func TestGetSession(t *testing.T) { }) t.Run("error - GetOpenIDConfig fails", func(t *testing.T) { - sess := session.Session{ + sess := internalsession.Session{ ID: "session-config-fail", TenantID: "tenant-config-fail", Fingerprint: "fingerprint-123", @@ -531,7 +531,7 @@ func TestGetSession(t *testing.T) { trustRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(trustData)) trust := newTrust(trustRepo) - server := grpc.NewSessionServer(ctx, sessionRepo, trust, 90*time.Minute, "") + server := session.NewServer(ctx, sessionRepo, trust, 90*time.Minute, "") req := &sessionv1.GetSessionRequest{ SessionId: "session-config-fail", @@ -562,7 +562,7 @@ func TestGetSession(t *testing.T) { })) defer testServer.Close() - sess := session.Session{ + sess := internalsession.Session{ ID: "session-introspect-fail", TenantID: "tenant-introspect-fail", Fingerprint: "fingerprint-123", @@ -585,8 +585,8 @@ func TestGetSession(t *testing.T) { trustRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(trustData)) trust := newTrust(trustRepo) - server := grpc.NewSessionServer(ctx, sessionRepo, trust, 90*time.Minute, "", - grpc.WithAllowHttpScheme(true), + server := session.NewServer(ctx, sessionRepo, trust, 90*time.Minute, "", + session.WithAllowHttpScheme(true), ) req := &sessionv1.GetSessionRequest{ @@ -620,7 +620,7 @@ func TestGetSession(t *testing.T) { })) defer testServer.Close() - sess := session.Session{ + sess := internalsession.Session{ ID: "session-inactive-token", TenantID: "tenant-inactive-token", Fingerprint: "fingerprint-123", @@ -643,8 +643,8 @@ func TestGetSession(t *testing.T) { trustRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(trustData)) trust := newTrust(trustRepo) - server := grpc.NewSessionServer(ctx, sessionRepo, trust, 90*time.Minute, "", - grpc.WithAllowHttpScheme(true), + server := session.NewServer(ctx, sessionRepo, trust, 90*time.Minute, "", + session.WithAllowHttpScheme(true), ) req := &sessionv1.GetSessionRequest{ @@ -671,7 +671,7 @@ func TestGetSession(t *testing.T) { })) defer testServer.Close() - sess := session.Session{ + sess := internalsession.Session{ ID: "session-bump-fail", TenantID: "tenant-bump-fail", Fingerprint: "fingerprint-123", @@ -694,8 +694,8 @@ func TestGetSession(t *testing.T) { trustRepo := mocktrust.NewInMemRepository(mocktrust.WithTrust(trustData)) trust := newTrust(trustRepo) - server := grpc.NewSessionServer(ctx, sessionRepo, trust, 90*time.Minute, "", - grpc.WithAllowHttpScheme(true), + server := session.NewServer(ctx, sessionRepo, trust, 90*time.Minute, "", + session.WithAllowHttpScheme(true), ) req := &sessionv1.GetSessionRequest{ @@ -716,7 +716,7 @@ func TestWithQueryParametersIntrospect(t *testing.T) { ctx := t.Context() t.Run("sets query parameters correctly", func(t *testing.T) { params := []string{"param1", "param2", "param3"} - opt := grpc.WithQueryParametersIntrospect(params) + opt := session.WithQueryParametersIntrospect(params) assert.NotNil(t, opt) @@ -725,7 +725,7 @@ func TestWithQueryParametersIntrospect(t *testing.T) { trustRepo := mocktrust.NewInMemRepository() trust := newTrust(trustRepo) - server := grpc.NewSessionServer(ctx, sessionRepo, trust, 90*time.Minute, "", opt) + server := session.NewServer(ctx, sessionRepo, trust, 90*time.Minute, "", opt) assert.NotNil(t, server) }) @@ -748,7 +748,7 @@ func TestGetOIDCProvider(t *testing.T) { trust := newTrust(trustRepo) sessionRepo := sessionmock.NewInMemRepository() - server := grpc.NewSessionServer(ctx, sessionRepo, trust, 90*time.Minute, "") + server := session.NewServer(ctx, sessionRepo, trust, 90*time.Minute, "") req := &sessionv1.GetOIDCProviderRequest{ TenantId: "tenant-123", @@ -768,7 +768,7 @@ func TestGetOIDCProvider(t *testing.T) { sessionRepo := sessionmock.NewInMemRepository() trustRepo := mocktrust.NewInMemRepository() trust := newTrust(trustRepo) - server := grpc.NewSessionServer(ctx, sessionRepo, trust, 90*time.Minute, "") + server := session.NewServer(ctx, sessionRepo, trust, 90*time.Minute, "") req := &sessionv1.GetOIDCProviderRequest{ TenantId: "non-existent-tenant", } @@ -786,7 +786,7 @@ func TestGetOIDCProvider(t *testing.T) { mocktrust.WithGetError(errors.New("database connection error")), ) trust := newTrust(trustRepo) - server := grpc.NewSessionServer(ctx, sessionRepo, trust, 90*time.Minute, "") + server := session.NewServer(ctx, sessionRepo, trust, 90*time.Minute, "") req := &sessionv1.GetOIDCProviderRequest{ TenantId: "tenant-123", } diff --git a/internal/grpc/violations.go b/modules/grpc/session/violations.go similarity index 77% rename from internal/grpc/violations.go rename to modules/grpc/session/violations.go index 5479ea89..b4313e38 100644 --- a/internal/grpc/violations.go +++ b/modules/grpc/session/violations.go @@ -1,4 +1,4 @@ -package grpc +package session const ( violationTenantBlocked = "tenant_blocked" diff --git a/modules/grpc/trustmapping/import_test.go b/modules/grpc/trustmapping/import_test.go new file mode 100644 index 00000000..958539cf --- /dev/null +++ b/modules/grpc/trustmapping/import_test.go @@ -0,0 +1,12 @@ +package trustmapping_test + +import ( + _ "unsafe" + + sessionmanager "github.com/openkcm/session-manager" + "github.com/openkcm/session-manager/modules/oidctrust" + _ "github.com/openkcm/session-manager/modules/standard" +) + +//go:linkname newTrust github.com/openkcm/session-manager/modules/oidctrust.newOIDCTrustModuleWithRepo +func newTrust(r oidctrust.TrustRepository) sessionmanager.Trust diff --git a/modules/grpc/trustmapping/module.go b/modules/grpc/trustmapping/module.go new file mode 100644 index 00000000..bc949a95 --- /dev/null +++ b/modules/grpc/trustmapping/module.go @@ -0,0 +1,60 @@ +// Package trustmapping provides the service.module.grpc.trustmapping module: +// a gRPC service module that registers the +// kms.api.cmk.sessionmanager.trustmapping.v1.Service proto onto a +// grpc.ServiceRegistrar supplied by app.module.grpcserver. +package trustmapping + +import ( + "fmt" + + "google.golang.org/grpc" + + trustmappingv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/sessionmanager/trustmapping/v1" + + sessionmanager "github.com/openkcm/session-manager" +) + +const moduleID = "service.module.grpc.trustmapping" + +func init() { + sessionmanager.RegisterModule(new(Module)) +} + +func newModule() sessionmanager.Module { + return new(Module) +} + +// Module is the service.module.grpc.trustmapping module. It owns a Server +// that implements the trustmapping proto and resolves its single dependency +// (a sessionmanager.Trust implementation) by ID via ctx.GetModule. +type Module struct { + Mod string `yaml:"module"` + Trust string `yaml:"trust" default:"trust.module.oidc"` + + server *Server +} + +func (m *Module) Module() sessionmanager.ModuleInfo { + return sessionmanager.ModuleInfo{ + ID: moduleID, + New: newModule, + } +} + +func (m *Module) Provision(ctx *sessionmanager.Context) error { + trustMod, err := ctx.GetModule(m.Trust) + if err != nil { + return fmt.Errorf("getting trust module %q: %w", m.Trust, err) + } + trust, ok := trustMod.(sessionmanager.Trust) + if !ok { + return fmt.Errorf("module %q does not implement sessionmanager.Trust", m.Trust) + } + + m.server = NewServer(trust) + return nil +} + +func (m *Module) Register(s grpc.ServiceRegistrar) { + trustmappingv1.RegisterServiceServer(s, m.server) +} diff --git a/modules/grpc/trustmapping/module_test.go b/modules/grpc/trustmapping/module_test.go new file mode 100644 index 00000000..33923884 --- /dev/null +++ b/modules/grpc/trustmapping/module_test.go @@ -0,0 +1,83 @@ +package trustmapping_test + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + trustv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/v1" + + sessionmanager "github.com/openkcm/session-manager" + tmmod "github.com/openkcm/session-manager/modules/grpc/trustmapping" + _ "github.com/openkcm/session-manager/modules/standard" +) + +// stubTrust satisfies sessionmanager.Trust enough for Provision to type-assert +// successfully. No method bodies need to be exercised in these tests. +type stubTrust struct{} + +func (stubTrust) Apply(context.Context, *trustv1.Trust) error { return nil } +func (stubTrust) Block(context.Context, string) error { return nil } +func (stubTrust) Remove(context.Context, string) error { return nil } +func (stubTrust) Unblock(context.Context, string) error { return nil } + +var errStubTrustGet = errors.New("stub get not implemented") + +func (stubTrust) Get(context.Context, string) (*trustv1.Trust, error) { + return nil, errStubTrustGet +} + +// stubTrustModule lets us register a fake trust module with the registry under +// a custom ID for tests. +type stubTrustModule struct { + stubTrust + + id string +} + +func (s *stubTrustModule) Module() sessionmanager.ModuleInfo { + return sessionmanager.ModuleInfo{ + ID: s.id, + New: func() sessionmanager.Module { return s }, + } +} + +func TestModule_Registration(t *testing.T) { + info, err := sessionmanager.GetModule("service.module.grpc.trustmapping") + require.NoError(t, err) + assert.Equal(t, "service.module.grpc.trustmapping", info.ID) +} + +func TestModule_ProvisionResolvesCustomTrust(t *testing.T) { + id := "trust.module.test." + t.Name() + sessionmanager.RegisterModule(&stubTrustModule{id: id}) + + ctx, cancel := sessionmanager.NewContext(t.Context()) + defer cancel(nil) + + _, err := ctx.LoadModule(&extConfig{moduleID: id}) + require.NoError(t, err) + + m := &tmmod.Module{Trust: id} + require.NoError(t, m.Provision(ctx)) +} + +func TestModule_ProvisionMissingTrustFails(t *testing.T) { + ctx, cancel := sessionmanager.NewContext(t.Context()) + defer cancel(nil) + + m := &tmmod.Module{Trust: "no.such.trust.module"} + err := m.Provision(ctx) + require.Error(t, err) + assert.Contains(t, err.Error(), "no.such.trust.module") +} + +// extConfig is a minimal sessionmanager.ExtensionConfig used by the test to +// load fake modules through the registry. +type extConfig struct{ moduleID string } + +func (c *extConfig) Module() string { return c.moduleID } +func (c *extConfig) UnmarshalExtension(_ sessionmanager.Module) error { return nil } diff --git a/internal/grpc/trustmapping.go b/modules/grpc/trustmapping/server.go similarity index 78% rename from internal/grpc/trustmapping.go rename to modules/grpc/trustmapping/server.go index a691b4bf..0359312b 100644 --- a/internal/grpc/trustmapping.go +++ b/modules/grpc/trustmapping/server.go @@ -1,4 +1,4 @@ -package grpc +package trustmapping import ( "context" @@ -16,24 +16,19 @@ import ( "github.com/openkcm/session-manager/pkg/serviceerr" ) -type TrustMappingServer struct { +type Server struct { trustmappingv1.UnimplementedServiceServer trust sessionmanager.Trust } -func NewTrustMappingServer(trust sessionmanager.Trust) *TrustMappingServer { - srv := &TrustMappingServer{ - trust: trust, - } - - return srv +func NewServer(trust sessionmanager.Trust) *Server { + return &Server{trust: trust} } -func (srv *TrustMappingServer) ApplyTrustMapping(ctx context.Context, in *trustmappingv1.ApplyTrustMappingRequest) (*trustmappingv1.ApplyTrustMappingResponse, error) { +func (srv *Server) ApplyTrustMapping(ctx context.Context, in *trustmappingv1.ApplyTrustMappingRequest) (*trustmappingv1.ApplyTrustMappingResponse, error) { oidcIn := in.GetOidc() oidc := oidcv1.OIDC_builder{ - TenantId: new(oidcIn.GetTenantId()), Issuer: new(oidcIn.GetIssuer()), JwksUri: new(oidcIn.GetJwksUri()), Audiences: oidcIn.GetAudiences(), @@ -76,7 +71,7 @@ func (srv *TrustMappingServer) ApplyTrustMapping(ctx context.Context, in *trustm // BlockTrustMapping blocks the trust for the specified tenant. // It calls the underlying service to set the trust as blocked. // Returns a response containing an optional error message if blocking fails. -func (srv *TrustMappingServer) BlockTrustMapping(ctx context.Context, req *trustmappingv1.BlockTrustMappingRequest) (*trustmappingv1.BlockTrustMappingResponse, error) { +func (srv *Server) BlockTrustMapping(ctx context.Context, req *trustmappingv1.BlockTrustMappingRequest) (*trustmappingv1.BlockTrustMappingResponse, error) { ctx = slogctx.With(ctx, "tenantId", req.GetTenantId()) slogctx.Debug(ctx, "BlockTrustMapping called") @@ -97,7 +92,7 @@ func (srv *TrustMappingServer) BlockTrustMapping(ctx context.Context, req *trust // RemoveTrustMapping removes the trust configuration for the tenant. // It calls the underlying service to remove the trust. // Returns a respose containing an optional error message if removing fails. -func (srv *TrustMappingServer) RemoveTrustMapping(ctx context.Context, req *trustmappingv1.RemoveTrustMappingRequest) (*trustmappingv1.RemoveTrustMappingResponse, error) { +func (srv *Server) RemoveTrustMapping(ctx context.Context, req *trustmappingv1.RemoveTrustMappingRequest) (*trustmappingv1.RemoveTrustMappingResponse, error) { ctx = slogctx.With(ctx, "tenantId", req.GetTenantId()) slogctx.Debug(ctx, "RemoveTrustMapping called") @@ -109,9 +104,8 @@ func (srv *TrustMappingServer) RemoveTrustMapping(ctx context.Context, req *trus msg := err.Error() resp.SetMessage(msg) return resp, status.Error(codes.Internal, "failed to remove trust: "+msg) - } else { - slogctx.Warn(ctx, "RemoveTrustMapping is called but the tenant does not exist", "error", err) } + slogctx.Warn(ctx, "RemoveTrustMapping is called but the tenant does not exist", "error", err) } resp.SetSuccess(true) @@ -121,7 +115,7 @@ func (srv *TrustMappingServer) RemoveTrustMapping(ctx context.Context, req *trus // UnblockTrustMapping unblocks the trust for the specified tenant. // It calls the underlying service to set the trust as unblocked. // Returns a response containing an optional error message if unblocking fails. -func (srv *TrustMappingServer) UnblockTrustMapping(ctx context.Context, req *trustmappingv1.UnblockTrustMappingRequest) (*trustmappingv1.UnblockTrustMappingResponse, error) { +func (srv *Server) UnblockTrustMapping(ctx context.Context, req *trustmappingv1.UnblockTrustMappingRequest) (*trustmappingv1.UnblockTrustMappingResponse, error) { ctx = slogctx.With(ctx, "tenantId", req.GetTenantId()) slogctx.Debug(ctx, "UnblockTrustMapping called") diff --git a/internal/grpc/trust_test.go b/modules/grpc/trustmapping/server_test.go similarity index 93% rename from internal/grpc/trust_test.go rename to modules/grpc/trustmapping/server_test.go index 37f55e08..c15bb74e 100644 --- a/internal/grpc/trust_test.go +++ b/modules/grpc/trustmapping/server_test.go @@ -1,4 +1,4 @@ -package grpc_test +package trustmapping_test import ( "errors" @@ -13,7 +13,7 @@ import ( oidcv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/oidc/v1" trustv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/v1" - "github.com/openkcm/session-manager/internal/grpc" + "github.com/openkcm/session-manager/modules/grpc/trustmapping" mocktrust "github.com/openkcm/session-manager/modules/oidctrust/mocks" "github.com/openkcm/session-manager/pkg/serviceerr" ) @@ -22,7 +22,7 @@ func TestNewTrustMappingServer(t *testing.T) { t.Run("creates server successfully", func(t *testing.T) { repo := mocktrust.NewInMemRepository() svc := newTrust(repo) - server := grpc.NewTrustMappingServer(svc) + server := trustmapping.NewServer(svc) assert.NotNil(t, server) }) @@ -34,7 +34,7 @@ func TestApplyTrustMapping(t *testing.T) { t.Run("success - creates new trust", func(t *testing.T) { repo := mocktrust.NewInMemRepository() svc := newTrust(repo) - server := grpc.NewTrustMappingServer(svc) + server := trustmapping.NewServer(svc) jwksUri := "https://issuer.example.com/.well-known/jwks.json" req := trustmappingv1.ApplyTrustMappingRequest_builder{ @@ -67,7 +67,7 @@ func TestApplyTrustMapping(t *testing.T) { mocktrust.WithTrust(existingTrust), ) svc := newTrust(repo) - server := grpc.NewTrustMappingServer(svc) + server := trustmapping.NewServer(svc) jwksUri := "https://new-issuer.example.com/jwks.json" req := trustmappingv1.ApplyTrustMappingRequest_builder{ @@ -91,7 +91,7 @@ func TestApplyTrustMapping(t *testing.T) { mocktrust.WithCreateError(serviceerr.ErrNotFound), ) svc := newTrust(repo) - server := grpc.NewTrustMappingServer(svc) + server := trustmapping.NewServer(svc) jwksUri := "https://issuer.example.com/jwks.json" req := trustmappingv1.ApplyTrustMappingRequest_builder{ @@ -117,7 +117,7 @@ func TestApplyTrustMapping(t *testing.T) { mocktrust.WithCreateError(internalErr), ) svc := newTrust(repo) - server := grpc.NewTrustMappingServer(svc) + server := trustmapping.NewServer(svc) jwksUri := "https://issuer.example.com/jwks.json" req := trustmappingv1.ApplyTrustMappingRequest_builder{ @@ -152,7 +152,7 @@ func TestApplyTrustMapping(t *testing.T) { mocktrust.WithUpdateError(updateErr), ) svc := newTrust(repo) - server := grpc.NewTrustMappingServer(svc) + server := trustmapping.NewServer(svc) jwksUri := "https://new-issuer.example.com/jwks.json" req := trustmappingv1.ApplyTrustMappingRequest_builder{ @@ -189,7 +189,7 @@ func TestBlockTrustMapping(t *testing.T) { mocktrust.WithTrust(existingTrust), ) svc := newTrust(repo) - server := grpc.NewTrustMappingServer(svc) + server := trustmapping.NewServer(svc) req := trustmappingv1.BlockTrustMappingRequest_builder{ TenantId: new("tenant-123"), @@ -215,7 +215,7 @@ func TestBlockTrustMapping(t *testing.T) { mocktrust.WithTrust(existingTrust), ) svc := newTrust(repo) - server := grpc.NewTrustMappingServer(svc) + server := trustmapping.NewServer(svc) req := trustmappingv1.BlockTrustMappingRequest_builder{ TenantId: new("tenant-123"), @@ -233,7 +233,7 @@ func TestBlockTrustMapping(t *testing.T) { mocktrust.WithGetError(serviceerr.ErrNotFound), ) svc := newTrust(repo) - server := grpc.NewTrustMappingServer(svc) + server := trustmapping.NewServer(svc) req := trustmappingv1.BlockTrustMappingRequest_builder{ TenantId: new("tenant-123"), @@ -252,7 +252,7 @@ func TestBlockTrustMapping(t *testing.T) { mocktrust.WithGetError(internalErr), ) svc := newTrust(repo) - server := grpc.NewTrustMappingServer(svc) + server := trustmapping.NewServer(svc) req := trustmappingv1.BlockTrustMappingRequest_builder{ TenantId: new("tenant-123"), @@ -286,7 +286,7 @@ func TestRemoveTrustMapping(t *testing.T) { mocktrust.WithTrust(existingTrust), ) svc := newTrust(repo) - server := grpc.NewTrustMappingServer(svc) + server := trustmapping.NewServer(svc) req := trustmappingv1.RemoveTrustMappingRequest_builder{ TenantId: new("tenant-123"), @@ -306,7 +306,7 @@ func TestRemoveTrustMapping(t *testing.T) { mocktrust.WithDeleteError(deleteErr), ) svc := newTrust(repo) - server := grpc.NewTrustMappingServer(svc) + server := trustmapping.NewServer(svc) req := trustmappingv1.RemoveTrustMappingRequest_builder{ TenantId: new("tenant-123"), @@ -330,7 +330,7 @@ func TestRemoveTrustMapping(t *testing.T) { mocktrust.WithDeleteError(serviceerr.ErrNotFound), ) svc := newTrust(repo) - server := grpc.NewTrustMappingServer(svc) + server := trustmapping.NewServer(svc) req := trustmappingv1.RemoveTrustMappingRequest_builder{ TenantId: new("tenant-123"), @@ -360,7 +360,7 @@ func TestUnblockTrustMapping(t *testing.T) { mocktrust.WithTrust(existingTrust), ) svc := newTrust(repo) - server := grpc.NewTrustMappingServer(svc) + server := trustmapping.NewServer(svc) req := trustmappingv1.UnblockTrustMappingRequest_builder{ TenantId: new("tenant-123"), @@ -386,7 +386,7 @@ func TestUnblockTrustMapping(t *testing.T) { mocktrust.WithTrust(existingTrust), ) svc := newTrust(repo) - server := grpc.NewTrustMappingServer(svc) + server := trustmapping.NewServer(svc) req := trustmappingv1.UnblockTrustMappingRequest_builder{ TenantId: new("tenant-123"), @@ -404,7 +404,7 @@ func TestUnblockTrustMapping(t *testing.T) { mocktrust.WithGetError(serviceerr.ErrNotFound), ) svc := newTrust(repo) - server := grpc.NewTrustMappingServer(svc) + server := trustmapping.NewServer(svc) req := trustmappingv1.UnblockTrustMappingRequest_builder{ TenantId: new("tenant-123"), @@ -431,7 +431,7 @@ func TestUnblockTrustMapping(t *testing.T) { mocktrust.WithUpdateError(internalErr), ) svc := newTrust(repo) - server := grpc.NewTrustMappingServer(svc) + server := trustmapping.NewServer(svc) req := trustmappingv1.UnblockTrustMappingRequest_builder{ TenantId: new("tenant-123"), diff --git a/modules/sessionstore/valkey/module.go b/modules/sessionstore/valkey/module.go new file mode 100644 index 00000000..be27c026 --- /dev/null +++ b/modules/sessionstore/valkey/module.go @@ -0,0 +1,98 @@ +// Package valkey provides the sessionstore.module.valkey module: a +// session.Repository backed by Valkey, configured by the top-level valkey: +// config block. It registers itself with the sessionmanager module registry +// at init time and is loaded by business.Main as a top-level dependency. +package valkey + +import ( + "fmt" + + "github.com/openkcm/common-sdk/pkg/commoncfg" + "github.com/valkey-io/valkey-go" + + sessionmanager "github.com/openkcm/session-manager" + "github.com/openkcm/session-manager/internal/session" + sessionvalkey "github.com/openkcm/session-manager/internal/session/valkey" +) + +const moduleID = "sessionstore.module.valkey" + +func init() { + sessionmanager.RegisterModule(new(Module)) +} + +func newModule() sessionmanager.Module { + return new(Module) +} + +// Module is the sessionstore.module.valkey module. It owns a Valkey client +// and exposes a session.Repository backed by it. +type Module struct { + *sessionvalkey.Repository + + Mod string `yaml:"module"` + Host commoncfg.SourceRef `yaml:"host"` + User commoncfg.SourceRef `yaml:"user"` + Password commoncfg.SourceRef `yaml:"password"` + Prefix string `yaml:"prefix"` + SecretRef commoncfg.SecretRef `yaml:"secretRef"` + + client valkey.Client +} + +func (m *Module) Module() sessionmanager.ModuleInfo { + return sessionmanager.ModuleInfo{ + ID: moduleID, + New: newModule, + } +} + +func (m *Module) Provision(_ *sessionmanager.Context) error { + host, err := commoncfg.LoadValueFromSourceRef(m.Host) + if err != nil { + return fmt.Errorf("loading valkey host: %w", err) + } + user, err := commoncfg.LoadValueFromSourceRef(m.User) + if err != nil { + return fmt.Errorf("loading valkey user: %w", err) + } + password, err := commoncfg.LoadValueFromSourceRef(m.Password) + if err != nil { + return fmt.Errorf("loading valkey password: %w", err) + } + + opts := valkey.ClientOption{ + InitAddress: []string{string(host)}, + Username: string(user), + Password: string(password), + } + + if m.SecretRef.Type == commoncfg.MTLSSecretType { + tlsConfig, err := commoncfg.LoadMTLSConfig(&m.SecretRef.MTLS) + if err != nil { + return fmt.Errorf("loading valkey mTLS config: %w", err) + } + opts.TLSConfig = tlsConfig + } + + client, err := valkey.NewClient(opts) + if err != nil { + return fmt.Errorf("creating valkey client: %w", err) + } + m.client = client + m.Repository = sessionvalkey.NewRepository(client, m.Prefix) + + return nil +} + +func (m *Module) Close() error { + if m.client == nil { + return nil + } + m.client.Close() + return nil +} + +// Compile-time guarantee that the module satisfies session.Repository via the +// embedded *sessionvalkey.Repository. +var _ session.Repository = (*Module)(nil) diff --git a/modules/sessionstore/valkey/module_test.go b/modules/sessionstore/valkey/module_test.go new file mode 100644 index 00000000..68e118dc --- /dev/null +++ b/modules/sessionstore/valkey/module_test.go @@ -0,0 +1,27 @@ +package valkey_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + sessionmanager "github.com/openkcm/session-manager" + sessionstorevalkey "github.com/openkcm/session-manager/modules/sessionstore/valkey" +) + +func TestModule_RegistrationAndID(t *testing.T) { + info, err := sessionmanager.GetModule("sessionstore.module.valkey") + require.NoError(t, err) + assert.Equal(t, "sessionstore.module.valkey", info.ID) + + mod := info.New() + require.NotNil(t, mod) + _, ok := mod.(*sessionstorevalkey.Module) + assert.True(t, ok, "New() must return *Module") +} + +func TestModule_CloseBeforeProvisionIsSafe(t *testing.T) { + m := new(sessionstorevalkey.Module) + require.NoError(t, m.Close(), "Close before Provision must not error") +} diff --git a/modules/standard/imports.go b/modules/standard/imports.go index d21957ef..05c89ce9 100644 --- a/modules/standard/imports.go +++ b/modules/standard/imports.go @@ -1,7 +1,12 @@ package standard import ( + _ "github.com/openkcm/session-manager/modules/app/grpcserver" + _ "github.com/openkcm/session-manager/modules/credentials/oauth2" _ "github.com/openkcm/session-manager/modules/database/pgxpool" + _ "github.com/openkcm/session-manager/modules/grpc/session" + _ "github.com/openkcm/session-manager/modules/grpc/trustmapping" _ "github.com/openkcm/session-manager/modules/oidctrust" _ "github.com/openkcm/session-manager/modules/oidctrust/migrations" + _ "github.com/openkcm/session-manager/modules/sessionstore/valkey" ) From 8f02e4c0b35ccf6d0990cc805b9e7fdb5c246d09 Mon Sep 17 00:00:00 2001 From: Danylo Shevchenko Date: Tue, 26 May 2026 09:02:13 +0200 Subject: [PATCH 5/5] feat: oidcmapping module Signed-off-by: Danylo Shevchenko --- charts/session-manager/values-dev.yaml | 1 + charts/session-manager/values.yaml | 1 + config.yaml | 1 + context.go | 2 +- go.mod | 2 +- go.sum | 4 +- modules/grpc/oidcmapping/import_test.go | 12 + modules/grpc/oidcmapping/module.go | 62 +++++ modules/grpc/oidcmapping/module_test.go | 77 ++++++ modules/grpc/oidcmapping/server.go | 123 ++++++++++ modules/grpc/oidcmapping/server_test.go | 300 ++++++++++++++++++++++++ modules/oidctrust/module.go | 2 +- modules/standard/imports.go | 1 + 13 files changed, 583 insertions(+), 5 deletions(-) create mode 100644 modules/grpc/oidcmapping/import_test.go create mode 100644 modules/grpc/oidcmapping/module.go create mode 100644 modules/grpc/oidcmapping/module_test.go create mode 100644 modules/grpc/oidcmapping/server.go create mode 100644 modules/grpc/oidcmapping/server_test.go diff --git a/charts/session-manager/values-dev.yaml b/charts/session-manager/values-dev.yaml index 7522ab14..68c9b199 100644 --- a/charts/session-manager/values-dev.yaml +++ b/charts/session-manager/values-dev.yaml @@ -357,3 +357,4 @@ config: services: - module: service.module.grpc.session - module: service.module.grpc.trustmapping + - module: service.module.grpc.oidcmapping diff --git a/charts/session-manager/values.yaml b/charts/session-manager/values.yaml index 5ef91b5b..2683a740 100644 --- a/charts/session-manager/values.yaml +++ b/charts/session-manager/values.yaml @@ -370,3 +370,4 @@ config: services: - module: service.module.grpc.session - module: service.module.grpc.trustmapping + - module: service.module.grpc.oidcmapping diff --git a/config.yaml b/config.yaml index 703c79a3..4b03f8b7 100644 --- a/config.yaml +++ b/config.yaml @@ -271,3 +271,4 @@ apps: services: - module: service.module.grpc.session - module: service.module.grpc.trustmapping + - module: service.module.grpc.oidcmapping diff --git a/context.go b/context.go index 45b5f503..825ec175 100644 --- a/context.go +++ b/context.go @@ -173,7 +173,7 @@ func (c *Context) instantiate(cfg ExtensionConfig) (Module, ModuleInfo, error) { } } - slogctx.Debug(c, "instantinated module", "module", modInfo.ID) + slogctx.Debug(c, "instantiated a module", "module", modInfo.ID) if provisioner, ok := mod.(Provisioner); ok { if err := provisioner.Provision(c); err != nil { diff --git a/go.mod b/go.mod index 1661dec5..0264d926 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/knadh/koanf/v2 v2.3.4 github.com/moby/moby/api v1.54.2 github.com/oapi-codegen/runtime v1.4.0 - github.com/openkcm/api-sdk v0.17.1-0.20260522173704-546d9188a096 + github.com/openkcm/api-sdk v0.17.1-0.20260526065520-4c441e4daecf github.com/openkcm/common-sdk v1.16.0 github.com/pressly/goose/v3 v3.27.1 github.com/samber/oops v1.21.0 diff --git a/go.sum b/go.sum index 8e0f2f20..1148fbbd 100644 --- a/go.sum +++ b/go.sum @@ -272,8 +272,8 @@ github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8 github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040= github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M= -github.com/openkcm/api-sdk v0.17.1-0.20260522173704-546d9188a096 h1:k814id04b74JgTxdKzQl2+9+Th+jIzvoc3sd0tYilRE= -github.com/openkcm/api-sdk v0.17.1-0.20260522173704-546d9188a096/go.mod h1:DeG8HQLN6QjzCpluI3B0xZCXqXEHv+0eSFg1+R5BQPo= +github.com/openkcm/api-sdk v0.17.1-0.20260526065520-4c441e4daecf h1:2aGrIcRODxQQ/6/E9e61IZ9SV4YlpHwPmk1hPAIhd70= +github.com/openkcm/api-sdk v0.17.1-0.20260526065520-4c441e4daecf/go.mod h1:DeG8HQLN6QjzCpluI3B0xZCXqXEHv+0eSFg1+R5BQPo= github.com/openkcm/common-sdk v1.16.0 h1:pmLXRHvjqg+8ATEyzXarCRiRghw/8pXGn2OtoYuMEIU= github.com/openkcm/common-sdk v1.16.0/go.mod h1:4umveCyatAaTi6dSQgwaBg1O/wqHr4sjzuMIQhEuX1o= github.com/pborman/getopt v0.0.0-20170112200414-7148bc3a4c30/go.mod h1:85jBQOZwpVEaDAr341tbn15RS4fCAsIst0qp7i8ex1o= diff --git a/modules/grpc/oidcmapping/import_test.go b/modules/grpc/oidcmapping/import_test.go new file mode 100644 index 00000000..51d6268c --- /dev/null +++ b/modules/grpc/oidcmapping/import_test.go @@ -0,0 +1,12 @@ +package oidcmapping_test + +import ( + _ "unsafe" + + sessionmanager "github.com/openkcm/session-manager" + "github.com/openkcm/session-manager/modules/oidctrust" + _ "github.com/openkcm/session-manager/modules/standard" +) + +//go:linkname newTrust github.com/openkcm/session-manager/modules/oidctrust.newOIDCTrustModuleWithRepo +func newTrust(r oidctrust.TrustRepository) sessionmanager.Trust diff --git a/modules/grpc/oidcmapping/module.go b/modules/grpc/oidcmapping/module.go new file mode 100644 index 00000000..5aaf629b --- /dev/null +++ b/modules/grpc/oidcmapping/module.go @@ -0,0 +1,62 @@ +// Package oidcmapping provides the service.module.grpc.oidcmapping module: +// a gRPC service module that registers the legacy +// kms.api.cmk.sessionmanager.oidcmapping.v1.Service proto onto a +// grpc.ServiceRegistrar supplied by app.module.grpcserver. It exists for +// backward compatibility with clients that have not migrated to the +// kms.api.cmk.sessionmanager.trustmapping.v1.Service proto. +package oidcmapping + +import ( + "fmt" + + "google.golang.org/grpc" + + oidcmappingv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/sessionmanager/oidcmapping/v1" + + sessionmanager "github.com/openkcm/session-manager" +) + +const moduleID = "service.module.grpc.oidcmapping" + +func init() { + sessionmanager.RegisterModule(new(Module)) +} + +func newModule() sessionmanager.Module { + return new(Module) +} + +// Module is the service.module.grpc.oidcmapping module. It owns a Server that +// adapts the legacy oidcmapping proto onto sessionmanager.Trust and resolves +// its single dependency by ID via ctx.GetModule. +type Module struct { + Mod string `yaml:"module"` + Trust string `yaml:"trust" default:"trust.module.oidc"` + + server *Server +} + +func (m *Module) Module() sessionmanager.ModuleInfo { + return sessionmanager.ModuleInfo{ + ID: moduleID, + New: newModule, + } +} + +func (m *Module) Provision(ctx *sessionmanager.Context) error { + trustMod, err := ctx.GetModule(m.Trust) + if err != nil { + return fmt.Errorf("getting trust module %q: %w", m.Trust, err) + } + trust, ok := trustMod.(sessionmanager.Trust) + if !ok { + return fmt.Errorf("module %q does not implement sessionmanager.Trust", m.Trust) + } + + m.server = NewServer(trust) + return nil +} + +func (m *Module) Register(s grpc.ServiceRegistrar) { + oidcmappingv1.RegisterServiceServer(s, m.server) +} diff --git a/modules/grpc/oidcmapping/module_test.go b/modules/grpc/oidcmapping/module_test.go new file mode 100644 index 00000000..f730b323 --- /dev/null +++ b/modules/grpc/oidcmapping/module_test.go @@ -0,0 +1,77 @@ +package oidcmapping_test + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + trustv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/v1" + + sessionmanager "github.com/openkcm/session-manager" + ommod "github.com/openkcm/session-manager/modules/grpc/oidcmapping" + _ "github.com/openkcm/session-manager/modules/standard" +) + +type stubTrust struct{} + +func (stubTrust) Apply(context.Context, *trustv1.Trust) error { return nil } +func (stubTrust) Block(context.Context, string) error { return nil } +func (stubTrust) Remove(context.Context, string) error { return nil } +func (stubTrust) Unblock(context.Context, string) error { return nil } + +var errStubTrustGet = errors.New("stub get not implemented") + +func (stubTrust) Get(context.Context, string) (*trustv1.Trust, error) { + return nil, errStubTrustGet +} + +type stubTrustModule struct { + stubTrust + + id string +} + +func (s *stubTrustModule) Module() sessionmanager.ModuleInfo { + return sessionmanager.ModuleInfo{ + ID: s.id, + New: func() sessionmanager.Module { return s }, + } +} + +func TestModule_Registration(t *testing.T) { + info, err := sessionmanager.GetModule("service.module.grpc.oidcmapping") + require.NoError(t, err) + assert.Equal(t, "service.module.grpc.oidcmapping", info.ID) +} + +func TestModule_ProvisionResolvesCustomTrust(t *testing.T) { + id := "trust.module.test." + t.Name() + sessionmanager.RegisterModule(&stubTrustModule{id: id}) + + ctx, cancel := sessionmanager.NewContext(t.Context()) + defer cancel(nil) + + _, err := ctx.LoadModule(&extConfig{moduleID: id}) + require.NoError(t, err) + + m := &ommod.Module{Trust: id} + require.NoError(t, m.Provision(ctx)) +} + +func TestModule_ProvisionMissingTrustFails(t *testing.T) { + ctx, cancel := sessionmanager.NewContext(t.Context()) + defer cancel(nil) + + m := &ommod.Module{Trust: "no.such.trust.module"} + err := m.Provision(ctx) + require.Error(t, err) + assert.Contains(t, err.Error(), "no.such.trust.module") +} + +type extConfig struct{ moduleID string } + +func (c *extConfig) Module() string { return c.moduleID } +func (c *extConfig) UnmarshalExtension(_ sessionmanager.Module) error { return nil } diff --git a/modules/grpc/oidcmapping/server.go b/modules/grpc/oidcmapping/server.go new file mode 100644 index 00000000..ae55e673 --- /dev/null +++ b/modules/grpc/oidcmapping/server.go @@ -0,0 +1,123 @@ +package oidcmapping + +import ( + "context" + "errors" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + oidcmappingv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/sessionmanager/oidcmapping/v1" + oidcv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/oidc/v1" + trustv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/v1" + slogctx "github.com/veqryn/slog-context" + + sessionmanager "github.com/openkcm/session-manager" + "github.com/openkcm/session-manager/pkg/serviceerr" +) + +type Server struct { + oidcmappingv1.UnimplementedServiceServer + + trust sessionmanager.Trust +} + +func NewServer(trust sessionmanager.Trust) *Server { + return &Server{trust: trust} +} + +func (srv *Server) ApplyOIDCMapping(ctx context.Context, req *oidcmappingv1.ApplyOIDCMappingRequest) (*oidcmappingv1.ApplyOIDCMappingResponse, error) { + oidcBuilder := oidcv1.OIDC_builder{ + Issuer: new(req.GetIssuer()), + Audiences: req.GetAudiences(), + } + if req.JwksUri != nil { + oidcBuilder.JwksUri = new(req.GetJwksUri()) + } + if req.ClientId != nil { + oidcBuilder.ClientId = new(req.GetClientId()) + } + oidc := oidcBuilder.Build() + + trust := trustv1.Trust_builder{ + TenantId: new(req.GetTenantId()), + Oidc: oidc, + }.Build() + + ctx = slogctx.With(ctx, + "tenantId", trust.GetTenantId(), + "issuer", oidc.GetIssuer(), + "jwksUri", oidc.GetJwksUri(), + "audiences", oidc.GetAudiences(), + "client_id", oidc.GetClientId(), + ) + + slogctx.Debug(ctx, "ApplyOIDCMapping called") + + response := &oidcmappingv1.ApplyOIDCMappingResponse{} + + if err := srv.trust.Apply(ctx, trust); err != nil { + slogctx.Error(ctx, "Could not apply trust", "error", err) + if errors.Is(err, serviceerr.ErrNotFound) { + msg := serviceerr.ErrNotFound.Error() + response.Message = &msg + return response, nil + } + + return nil, status.Errorf(codes.Internal, "failed to apply trust: %v", err) + } + + response.Success = true + return response, nil +} + +func (srv *Server) RemoveOIDCMapping(ctx context.Context, req *oidcmappingv1.RemoveOIDCMappingRequest) (*oidcmappingv1.RemoveOIDCMappingResponse, error) { + ctx = slogctx.With(ctx, "tenantId", req.GetTenantId()) + slogctx.Debug(ctx, "RemoveOIDCMapping called") + + resp := &oidcmappingv1.RemoveOIDCMappingResponse{} + if err := srv.trust.Remove(ctx, req.GetTenantId()); err != nil { + if !errors.Is(err, serviceerr.ErrNotFound) { + slogctx.Error(ctx, "Could not remove trust", "error", err) + msg := err.Error() + resp.Message = &msg + return resp, status.Error(codes.Internal, "failed to remove trust: "+msg) + } + slogctx.Warn(ctx, "RemoveOIDCMapping is called but the tenant does not exist", "error", err) + } + + resp.Success = true + return resp, nil +} + +func (srv *Server) BlockOIDCMapping(ctx context.Context, req *oidcmappingv1.BlockOIDCMappingRequest) (*oidcmappingv1.BlockOIDCMappingResponse, error) { + ctx = slogctx.With(ctx, "tenantId", req.GetTenantId()) + slogctx.Debug(ctx, "BlockOIDCMapping called") + + resp := &oidcmappingv1.BlockOIDCMappingResponse{} + if err := srv.trust.Block(ctx, req.GetTenantId()); err != nil { + slogctx.Error(ctx, "Could not block trust", "error", err) + msg := err.Error() + resp.Message = &msg + return resp, status.Error(codes.Internal, "failed to block trust: "+msg) + } + + resp.Success = true + return resp, nil +} + +func (srv *Server) UnblockOIDCMapping(ctx context.Context, req *oidcmappingv1.UnblockOIDCMappingRequest) (*oidcmappingv1.UnblockOIDCMappingResponse, error) { + ctx = slogctx.With(ctx, "tenantId", req.GetTenantId()) + slogctx.Debug(ctx, "UnblockOIDCMapping called") + + resp := &oidcmappingv1.UnblockOIDCMappingResponse{} + if err := srv.trust.Unblock(ctx, req.GetTenantId()); err != nil { + slogctx.Error(ctx, "Could not unblock trust", "error", err) + msg := err.Error() + resp.Message = &msg + return resp, status.Error(codes.Internal, "failed to unblock trust: "+msg) + } + + resp.Success = true + return resp, nil +} diff --git a/modules/grpc/oidcmapping/server_test.go b/modules/grpc/oidcmapping/server_test.go new file mode 100644 index 00000000..1b3ede59 --- /dev/null +++ b/modules/grpc/oidcmapping/server_test.go @@ -0,0 +1,300 @@ +package oidcmapping_test + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + oidcmappingv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/sessionmanager/oidcmapping/v1" + oidcv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/oidc/v1" + trustv1 "github.com/openkcm/api-sdk/proto/kms/api/cmk/trust/v1" + + "github.com/openkcm/session-manager/modules/grpc/oidcmapping" + mocktrust "github.com/openkcm/session-manager/modules/oidctrust/mocks" + "github.com/openkcm/session-manager/pkg/serviceerr" +) + +func TestNewOIDCMappingServer(t *testing.T) { + repo := mocktrust.NewInMemRepository() + svc := newTrust(repo) + server := oidcmapping.NewServer(svc) + assert.NotNil(t, server) +} + +func TestApplyOIDCMapping(t *testing.T) { + ctx := t.Context() + + t.Run("forwards issuer, jwks_uri, audiences, client_id when set", func(t *testing.T) { + repo := mocktrust.NewInMemRepository() + svc := newTrust(repo) + server := oidcmapping.NewServer(svc) + + jwksURI := "https://issuer.example.com/.well-known/jwks.json" + clientID := "client-abc" + req := &oidcmappingv1.ApplyOIDCMappingRequest{ + TenantId: "tenant-123", + Issuer: "https://issuer.example.com", + JwksUri: &jwksURI, + Audiences: []string{"audience1", "audience2"}, + ClientId: &clientID, + } + + resp, err := server.ApplyOIDCMapping(ctx, req) + require.NoError(t, err) + assert.True(t, resp.GetSuccess()) + + stored := repo.TGet("tenant-123") + require.NotNil(t, stored) + assert.Equal(t, "tenant-123", stored.GetTenantId()) + require.NotNil(t, stored.GetOidc()) + assert.Equal(t, "https://issuer.example.com", stored.GetOidc().GetIssuer()) + assert.Equal(t, jwksURI, stored.GetOidc().GetJwksUri()) + assert.Equal(t, []string{"audience1", "audience2"}, stored.GetOidc().GetAudiences()) + assert.Equal(t, clientID, stored.GetOidc().GetClientId()) + assert.True(t, stored.GetOidc().HasClientId()) + }) + + t.Run("client_id omitted leaves new oidc.client_id unset", func(t *testing.T) { + repo := mocktrust.NewInMemRepository() + svc := newTrust(repo) + server := oidcmapping.NewServer(svc) + + req := &oidcmappingv1.ApplyOIDCMappingRequest{ + TenantId: "tenant-no-client", + Issuer: "https://issuer.example.com", + } + + resp, err := server.ApplyOIDCMapping(ctx, req) + require.NoError(t, err) + assert.True(t, resp.GetSuccess()) + + stored := repo.TGet("tenant-no-client") + require.NotNil(t, stored) + require.NotNil(t, stored.GetOidc()) + assert.False(t, stored.GetOidc().HasClientId(), "client_id should remain unset when request omits it") + }) + + t.Run("non-empty properties map is dropped", func(t *testing.T) { + repo := mocktrust.NewInMemRepository() + svc := newTrust(repo) + server := oidcmapping.NewServer(svc) + + clientID := "client-xyz" + reqWithProps := &oidcmappingv1.ApplyOIDCMappingRequest{ + TenantId: "tenant-with-props", + Issuer: "https://issuer.example.com", + ClientId: &clientID, + Properties: map[string]string{"foo": "bar", "baz": "qux"}, + } + + resp, err := server.ApplyOIDCMapping(ctx, reqWithProps) + require.NoError(t, err) + assert.True(t, resp.GetSuccess()) + + stored := repo.TGet("tenant-with-props") + require.NotNil(t, stored) + // The new oidc.OIDC has no properties field; verify the stored Trust matches what + // we'd get by building it from the same request without properties. + expected := trustv1.Trust_builder{ + TenantId: new("tenant-with-props"), + Oidc: oidcv1.OIDC_builder{ + Issuer: new("https://issuer.example.com"), + ClientId: new(clientID), + }.Build(), + }.Build() + assert.Equal(t, expected.GetTenantId(), stored.GetTenantId()) + assert.Equal(t, expected.GetOidc().GetIssuer(), stored.GetOidc().GetIssuer()) + assert.Equal(t, expected.GetOidc().GetClientId(), stored.GetOidc().GetClientId()) + }) + + t.Run("ErrNotFound from Apply yields non-success response with message and no gRPC error", func(t *testing.T) { + repo := mocktrust.NewInMemRepository( + mocktrust.WithCreateError(serviceerr.ErrNotFound), + ) + svc := newTrust(repo) + server := oidcmapping.NewServer(svc) + + req := &oidcmappingv1.ApplyOIDCMappingRequest{ + TenantId: "tenant-missing", + Issuer: "https://issuer.example.com", + } + + resp, err := server.ApplyOIDCMapping(ctx, req) + require.NoError(t, err) + require.NotNil(t, resp) + assert.False(t, resp.GetSuccess()) + assert.Equal(t, serviceerr.ErrNotFound.Error(), resp.GetMessage()) + }) + + t.Run("other errors map to codes.Internal", func(t *testing.T) { + internalErr := errors.New("database connection failed") + repo := mocktrust.NewInMemRepository( + mocktrust.WithCreateError(internalErr), + ) + svc := newTrust(repo) + server := oidcmapping.NewServer(svc) + + req := &oidcmappingv1.ApplyOIDCMappingRequest{ + TenantId: "tenant-boom", + Issuer: "https://issuer.example.com", + } + + resp, err := server.ApplyOIDCMapping(ctx, req) + assert.Nil(t, resp) + require.Error(t, err) + + st, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, codes.Internal, st.Code()) + assert.Contains(t, st.Message(), "failed to apply trust") + }) +} + +func TestRemoveOIDCMapping(t *testing.T) { + ctx := t.Context() + + t.Run("success removes existing trust", func(t *testing.T) { + existing := trustv1.Trust_builder{ + TenantId: new("tenant-123"), + Oidc: oidcv1.OIDC_builder{ + Issuer: new("https://issuer.example.com"), + }.Build(), + }.Build() + repo := mocktrust.NewInMemRepository(mocktrust.WithTrust(existing)) + svc := newTrust(repo) + server := oidcmapping.NewServer(svc) + + req := &oidcmappingv1.RemoveOIDCMappingRequest{TenantId: "tenant-123"} + resp, err := server.RemoveOIDCMapping(ctx, req) + require.NoError(t, err) + assert.True(t, resp.GetSuccess()) + assert.Empty(t, resp.GetMessage()) + }) + + t.Run("ErrNotFound is idempotent and returns success", func(t *testing.T) { + repo := mocktrust.NewInMemRepository( + mocktrust.WithDeleteError(serviceerr.ErrNotFound), + ) + svc := newTrust(repo) + server := oidcmapping.NewServer(svc) + + req := &oidcmappingv1.RemoveOIDCMappingRequest{TenantId: "tenant-gone"} + resp, err := server.RemoveOIDCMapping(ctx, req) + require.NoError(t, err) + assert.True(t, resp.GetSuccess()) + }) + + t.Run("other errors map to codes.Internal", func(t *testing.T) { + deleteErr := errors.New("delete failed") + repo := mocktrust.NewInMemRepository(mocktrust.WithDeleteError(deleteErr)) + svc := newTrust(repo) + server := oidcmapping.NewServer(svc) + + req := &oidcmappingv1.RemoveOIDCMappingRequest{TenantId: "tenant-boom"} + resp, err := server.RemoveOIDCMapping(ctx, req) + require.Error(t, err) + assert.NotNil(t, resp) + assert.Contains(t, resp.GetMessage(), "delete failed") + + st, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, codes.Internal, st.Code()) + assert.Contains(t, st.Message(), "failed to remove trust") + }) +} + +func TestBlockOIDCMapping(t *testing.T) { + ctx := t.Context() + + t.Run("success blocks existing trust", func(t *testing.T) { + existing := trustv1.Trust_builder{ + TenantId: new("tenant-123"), + Blocked: new(false), + Oidc: oidcv1.OIDC_builder{ + Issuer: new("https://issuer.example.com"), + }.Build(), + }.Build() + repo := mocktrust.NewInMemRepository(mocktrust.WithTrust(existing)) + svc := newTrust(repo) + server := oidcmapping.NewServer(svc) + + req := &oidcmappingv1.BlockOIDCMappingRequest{TenantId: "tenant-123"} + resp, err := server.BlockOIDCMapping(ctx, req) + require.NoError(t, err) + assert.True(t, resp.GetSuccess()) + assert.Empty(t, resp.GetMessage()) + }) + + t.Run("error maps to codes.Internal with message", func(t *testing.T) { + internalErr := errors.New("database error") + repo := mocktrust.NewInMemRepository(mocktrust.WithGetError(internalErr)) + svc := newTrust(repo) + server := oidcmapping.NewServer(svc) + + req := &oidcmappingv1.BlockOIDCMappingRequest{TenantId: "tenant-123"} + resp, err := server.BlockOIDCMapping(ctx, req) + require.Error(t, err) + assert.NotNil(t, resp) + assert.Contains(t, resp.GetMessage(), "database error") + + st, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, codes.Internal, st.Code()) + assert.Contains(t, st.Message(), "failed to block trust") + }) +} + +func TestUnblockOIDCMapping(t *testing.T) { + ctx := t.Context() + + t.Run("success unblocks blocked trust", func(t *testing.T) { + existing := trustv1.Trust_builder{ + TenantId: new("tenant-123"), + Blocked: new(true), + Oidc: oidcv1.OIDC_builder{ + Issuer: new("https://issuer.example.com"), + }.Build(), + }.Build() + repo := mocktrust.NewInMemRepository(mocktrust.WithTrust(existing)) + svc := newTrust(repo) + server := oidcmapping.NewServer(svc) + + req := &oidcmappingv1.UnblockOIDCMappingRequest{TenantId: "tenant-123"} + resp, err := server.UnblockOIDCMapping(ctx, req) + require.NoError(t, err) + assert.True(t, resp.GetSuccess()) + assert.Empty(t, resp.GetMessage()) + }) + + t.Run("error maps to codes.Internal with message", func(t *testing.T) { + internalErr := errors.New("update failed") + existing := trustv1.Trust_builder{ + TenantId: new("tenant-123"), + Blocked: new(true), + Oidc: oidcv1.OIDC_builder{ + Issuer: new("https://issuer.example.com"), + }.Build(), + }.Build() + repo := mocktrust.NewInMemRepository( + mocktrust.WithTrust(existing), + mocktrust.WithUpdateError(internalErr), + ) + svc := newTrust(repo) + server := oidcmapping.NewServer(svc) + + req := &oidcmappingv1.UnblockOIDCMappingRequest{TenantId: "tenant-123"} + resp, err := server.UnblockOIDCMapping(ctx, req) + require.Error(t, err) + assert.NotNil(t, resp) + assert.Contains(t, resp.GetMessage(), "update failed") + + st, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, codes.Internal, st.Code()) + assert.Contains(t, st.Message(), "failed to unblock trust") + }) +} diff --git a/modules/oidctrust/module.go b/modules/oidctrust/module.go index 010575fa..cc209b13 100644 --- a/modules/oidctrust/module.go +++ b/modules/oidctrust/module.go @@ -17,7 +17,7 @@ func init() { sessionmanager.RegisterModule(new(TrustModule)) } -// TrustModule is a module that implements sessionmanager.Trust interface. It's using a database providede by the +// TrustModule is a module that implements sessionmanager.Trust interface. It's using a database provided by the // [dbModule] module which implements sessionmanager.DBModule. type TrustModule struct { DBModule string `yaml:"dbModule" default:"database.module.pgxpool"` diff --git a/modules/standard/imports.go b/modules/standard/imports.go index 05c89ce9..e6d3c06c 100644 --- a/modules/standard/imports.go +++ b/modules/standard/imports.go @@ -4,6 +4,7 @@ import ( _ "github.com/openkcm/session-manager/modules/app/grpcserver" _ "github.com/openkcm/session-manager/modules/credentials/oauth2" _ "github.com/openkcm/session-manager/modules/database/pgxpool" + _ "github.com/openkcm/session-manager/modules/grpc/oidcmapping" _ "github.com/openkcm/session-manager/modules/grpc/session" _ "github.com/openkcm/session-manager/modules/grpc/trustmapping" _ "github.com/openkcm/session-manager/modules/oidctrust"