From 1c08be26da3b5ea181fd791df7bf07d78f967d05 Mon Sep 17 00:00:00 2001 From: petruki <31597636+petruki@users.noreply.github.com> Date: Sun, 17 May 2026 14:01:04 -0700 Subject: [PATCH] feat: added throttle and execution logger --- README.md | 14 ++ client.go | 67 +++++++- execution_logger.go | 148 +++++++++++++++++ execution_logger_test.go | 190 ++++++++++++++++++++++ remote.go | 7 +- switcher.go | 196 +++++++++++++++++++++-- switcher_throttle_test.go | 326 ++++++++++++++++++++++++++++++++++++++ 7 files changed, 923 insertions(+), 25 deletions(-) create mode 100644 execution_logger.go create mode 100644 execution_logger_test.go create mode 100644 switcher_throttle_test.go diff --git a/README.md b/README.md index 963daa6..964f0ed 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,7 @@ A Go SDK for Switcher API [![Master CI](https://github.com/switcherapi/switcher-client-go/actions/workflows/master.yml/badge.svg)](https://github.com/switcherapi/switcher-client-go/actions/workflows/master.yml) [![Quality Gate Status](https://sonarcloud.io/api/project_badges/measure?project=switcherapi_switcher-client-go&metric=alert_status)](https://sonarcloud.io/dashboard?id=switcherapi_switcher-client-go) +![Known Vulnerabilities](https://snyk.io/test/github/switcherapi/switcher-client-go/badge.svg) [![Go Report Card](https://goreportcard.com/badge/github.com/switcherapi/switcher-client-go)](https://goreportcard.com/report/github.com/switcherapi/switcher-client-go) ![Go](https://img.shields.io/badge/go-1.25%2B-blue.svg) ![Status](https://img.shields.io/badge/status-under_development-orange.svg) @@ -325,6 +326,9 @@ client.SubscribeNotifyError(func(err error) { ## Advanced Features #### Throttling + +Throttle implements Stale-While-Revalidate behavior for feature flag evaluations, returning cached results while refreshing in the background. This is ideal for high-traffic scenarios where you want to minimize latency and avoid overwhelming the API with requests. + ```go _, err := client.GetSwitcher("FEATURE01").Throttle(time.Second).IsOn() if err != nil { @@ -332,6 +336,16 @@ if err != nil { } ``` +Throttle reuses the latest cached execution for the same switcher key and inputs. It records that cached execution even when `ContextOptions.Logger` is `false`, and when `Freeze` is enabled the cached value stays in place until `client.ClearLogger()` is called. + +```go +switcher := client.GetSwitcher("FEATURE01").Throttle(time.Second) +_, _ = switcher.IsOnWithDetails() + +logged := client.GetExecution(switcher) +fmt.Println(logged.Response.Metadata["cached"]) +``` + #### Hybrid Mode ```go _, err := client.GetSwitcher("FEATURE01").Remote().IsOn() diff --git a/client.go b/client.go index 0b829b3..a82f5b2 100644 --- a/client.go +++ b/client.go @@ -16,6 +16,9 @@ type Client struct { switchers map[string]*Switcher snapshot *Snapshot + executionLogger *executionLogger + throttleTokens chan struct{} + snapshotWatcher *snapshotWatcher snapshotAutoUpdater *snapshotAutoUpdater @@ -31,9 +34,12 @@ type Client struct { } func NewClient(ctx Context) *Client { + defaulted := ctx.withDefaults() return &Client{ - context: ctx.withDefaults(), + context: defaulted, switchers: make(map[string]*Switcher), + executionLogger: newExecutionLogger(), + throttleTokens: newThrottleTokens(defaulted.Options.ThrottleMaxWorkers), snapshotWatcher: newSnapshotWatcher(), snapshotAutoUpdater: newSnapshotAutoUpdater(), } @@ -49,6 +55,13 @@ func BuildContext(ctx Context) { client.ScheduleSnapshotAutoUpdate(0, nil) } +func (c *Client) Context() Context { + c.mu.RLock() + defer c.mu.RUnlock() + + return c.context +} + func GetSwitcher(key string) *Switcher { return defaultClient().GetSwitcher(key) } @@ -87,13 +100,6 @@ func (c *Client) GetSwitcher(key string) *Switcher { return switcher } -func (c *Client) Context() Context { - c.mu.RLock() - defer c.mu.RUnlock() - - return c.context -} - func LoadSnapshot(options *LoadSnapshotOptions) (int, error) { return defaultClient().LoadSnapshot(options) } @@ -174,6 +180,27 @@ func (c *Client) CheckSnapshot() (bool, error) { return true, nil } +func GetExecution(switcher *Switcher) ExecutionEntry { + return defaultClient().GetExecution(switcher) +} + +func (c *Client) GetExecution(switcher *Switcher) ExecutionEntry { + if switcher == nil { + return ExecutionEntry{} + } + + execution := switcher.snapshotForExecution() + return c.executionLogger.get(execution.key, execution.entries) +} + +func ClearLogger() { + defaultClient().ClearLogger() +} + +func (c *Client) ClearLogger() { + c.executionLogger.clear() +} + func SubscribeNotifyError(callback func(error)) { defaultClient().SubscribeNotifyError(callback) } @@ -195,6 +222,22 @@ func (c *Client) notifyError(err error) { } } +func (c *Client) runBackgroundTask(task func()) { + if c.throttleTokens == nil { + go task() + return + } + + go func() { + c.throttleTokens <- struct{}{} + defer func() { + <-c.throttleTokens + }() + + task() + }() +} + func defaultClient() *Client { if client := globalClient.Load(); client != nil { return client @@ -210,3 +253,11 @@ func defaultClient() *Client { return globalClient.Load() } + +func newThrottleTokens(maxWorkers int) chan struct{} { + if maxWorkers <= 0 { + return nil + } + + return make(chan struct{}, maxWorkers) +} diff --git a/execution_logger.go b/execution_logger.go new file mode 100644 index 0000000..6e87aeb --- /dev/null +++ b/execution_logger.go @@ -0,0 +1,148 @@ +package client + +import ( + "maps" + "sync" +) + +type ExecutionInput struct { + Strategy string + Input string +} + +type ExecutionEntry struct { + Key string + Inputs []ExecutionInput + Response ResultDetail +} + +type executionLogger struct { + mu sync.RWMutex + entries []ExecutionEntry +} + +func newExecutionLogger() *executionLogger { + return &executionLogger{ + entries: make([]ExecutionEntry, 0), + } +} + +func (l *executionLogger) add(key string, inputs []criteriaEntry, response ResultDetail) { + l.mu.Lock() + defer l.mu.Unlock() + + for i := range l.entries { + if executionEntryMatches(l.entries[i], key, inputs) { + l.entries = append(l.entries[:i], l.entries[i+1:]...) + break + } + } + + l.entries = append(l.entries, ExecutionEntry{ + Key: key, + Inputs: executionInputsFromCriteria(inputs), + Response: cachedResultDetail(response), + }) +} + +func (l *executionLogger) get(key string, inputs []criteriaEntry) ExecutionEntry { + l.mu.RLock() + defer l.mu.RUnlock() + + for _, entry := range l.entries { + if executionEntryMatches(entry, key, inputs) { + return cloneExecutionEntry(entry) + } + } + + return ExecutionEntry{} +} + +func (l *executionLogger) clear() { + l.mu.Lock() + defer l.mu.Unlock() + + l.entries = l.entries[:0] +} + +func executionEntryMatches(entry ExecutionEntry, key string, inputs []criteriaEntry) bool { + return entry.Key == key && executionInputsMatch(entry.Inputs, inputs) +} + +func executionInputsMatch(logged []ExecutionInput, current []criteriaEntry) bool { + if len(logged) == 0 { + return len(current) == 0 + } + + if len(current) == 0 { + return false + } + + for _, loggedInput := range logged { + found := false + for _, currentInput := range current { + if currentInput.Strategy == loggedInput.Strategy && currentInput.Input == loggedInput.Input { + found = true + break + } + } + + if !found { + return false + } + } + + return true +} + +func executionInputsFromCriteria(inputs []criteriaEntry) []ExecutionInput { + if len(inputs) == 0 { + return nil + } + + converted := make([]ExecutionInput, len(inputs)) + for i, input := range inputs { + converted[i] = ExecutionInput(input) + } + + return converted +} + +func cloneExecutionEntry(entry ExecutionEntry) ExecutionEntry { + return ExecutionEntry{ + Key: entry.Key, + Inputs: cloneExecutionInputs(entry.Inputs), + Response: cloneResultDetail(entry.Response), + } +} + +func cloneExecutionInputs(inputs []ExecutionInput) []ExecutionInput { + if len(inputs) == 0 { + return nil + } + + cloned := make([]ExecutionInput, len(inputs)) + copy(cloned, inputs) + return cloned +} + +func cloneResultDetail(result ResultDetail) ResultDetail { + return ResultDetail{ + Result: result.Result, + Reason: result.Reason, + Metadata: cloneMetadata(result.Metadata), + } +} + +func cachedResultDetail(result ResultDetail) ResultDetail { + cached := cloneResultDetail(result) + cached.Metadata["cached"] = true + return cached +} + +func cloneMetadata(metadata map[string]any) map[string]any { + cloned := make(map[string]any, len(metadata)) + maps.Copy(cloned, metadata) + + return cloned +} diff --git a/execution_logger_test.go b/execution_logger_test.go new file mode 100644 index 0000000..b299045 --- /dev/null +++ b/execution_logger_test.go @@ -0,0 +1,190 @@ +package client + +import ( + "net/http" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestExecutionLogger(t *testing.T) { + t.Run("should return an empty execution when the package API receives a nil switcher", func(t *testing.T) { + BuildContext(Context{ + Domain: "My Domain", + }) + + logged := GetExecution(nil) + + assert.Equal(t, ExecutionEntry{}, logged) + }) + + t.Run("should log remote executions and retrieve them by switcher", func(t *testing.T) { + server := newRemoteTestServer(t, remoteTestHandlers{ + authStatus: http.StatusOK, + authBody: map[string]any{"token": "[token]", "exp": time.Now().Add(time.Hour).Unix()}, + criteriaStatus: http.StatusOK, + criteriaBody: map[string]any{"result": true}, + }) + defer server.Close() + + client := NewClient(Context{ + Domain: "My Domain", + URL: server.URL, + APIKey: "[YOUR_API_KEY]", + Component: "MyApp", + Options: ContextOptions{ + Logger: true, + }, + }) + + got, err := client.GetSwitcher("MY_SWITCHER").CheckValue("user_id").IsOn() + + assert.NoError(t, err) + assert.True(t, got) + + logged := client.GetExecution(client.GetSwitcher("MY_SWITCHER").CheckValue("user_id")) + assert.Equal(t, "MY_SWITCHER", logged.Key) + assert.Equal(t, []ExecutionInput{{Strategy: StrategyValue, Input: "user_id"}}, logged.Inputs) + assert.True(t, logged.Response.Result) + assert.Equal(t, map[string]any{"cached": true}, logged.Response.Metadata) + }) + + t.Run("should return empty execution when inputs do not match", func(t *testing.T) { + server := newRemoteTestServer(t, remoteTestHandlers{ + authStatus: http.StatusOK, + authBody: map[string]any{"token": "[token]", "exp": time.Now().Add(time.Hour).Unix()}, + criteriaStatus: http.StatusOK, + criteriaBody: map[string]any{"result": true}, + }) + defer server.Close() + + client := NewClient(Context{ + Domain: "My Domain", + URL: server.URL, + APIKey: "[YOUR_API_KEY]", + Component: "MyApp", + Options: ContextOptions{ + Logger: true, + }, + }) + + _, err := client.GetSwitcher("MY_SWITCHER").CheckValue("user_id").IsOn() + assert.NoError(t, err) + + logged := client.GetExecution(client.GetSwitcher("MY_SWITCHER").CheckValue("other_id")) + assert.Equal(t, ExecutionEntry{}, logged) + }) + + t.Run("should return empty execution when the logged entry has inputs and the lookup has none", func(t *testing.T) { + server := newRemoteTestServer(t, remoteTestHandlers{ + authStatus: http.StatusOK, + authBody: map[string]any{"token": "[token]", "exp": time.Now().Add(time.Hour).Unix()}, + criteriaStatus: http.StatusOK, + criteriaBody: map[string]any{"result": true}, + }) + defer server.Close() + + client := NewClient(Context{ + Domain: "My Domain", + URL: server.URL, + APIKey: "[YOUR_API_KEY]", + Component: "MyApp", + Options: ContextOptions{ + Logger: true, + }, + }) + + _, err := client.GetSwitcher("MY_SWITCHER").CheckValue("user_id").IsOn() + assert.NoError(t, err) + + lookup := client.GetSwitcher("") + err = lookup.Prepare("MY_SWITCHER") + assert.NoError(t, err) + + logged := client.GetExecution(lookup) + assert.Equal(t, ExecutionEntry{}, logged) + }) + + t.Run("should clear logged executions", func(t *testing.T) { + server := newRemoteTestServer(t, remoteTestHandlers{ + authStatus: http.StatusOK, + authBody: map[string]any{"token": "[token]", "exp": time.Now().Add(time.Hour).Unix()}, + criteriaStatus: http.StatusOK, + criteriaBody: map[string]any{"result": true}, + }) + defer server.Close() + + client := NewClient(Context{ + Domain: "My Domain", + URL: server.URL, + APIKey: "[YOUR_API_KEY]", + Component: "MyApp", + Options: ContextOptions{ + Logger: true, + }, + }) + + _, err := client.GetSwitcher("MY_SWITCHER").IsOn() + assert.NoError(t, err) + + client.ClearLogger() + + logged := client.GetExecution(client.GetSwitcher("MY_SWITCHER")) + assert.Equal(t, ExecutionEntry{}, logged) + }) + + t.Run("should clear logged executions through the package API", func(t *testing.T) { + server := newRemoteTestServer(t, remoteTestHandlers{ + authStatus: http.StatusOK, + authBody: map[string]any{"token": "[token]", "exp": time.Now().Add(time.Hour).Unix()}, + criteriaStatus: http.StatusOK, + criteriaBody: map[string]any{"result": true}, + }) + defer server.Close() + + BuildContext(Context{ + Domain: "My Domain", + URL: server.URL, + APIKey: "[YOUR_API_KEY]", + Component: "MyApp", + Options: ContextOptions{ + Logger: true, + }, + }) + + _, err := GetSwitcher("MY_SWITCHER").IsOn() + assert.NoError(t, err) + + ClearLogger() + + logged := GetExecution(GetSwitcher("MY_SWITCHER")) + assert.Equal(t, ExecutionEntry{}, logged) + }) + + t.Run("should log local executions when logger is enabled", func(t *testing.T) { + BuildContext(Context{ + Domain: "My Domain", + Environment: "default", + Options: ContextOptions{ + Local: true, + Logger: true, + SnapshotLocation: filepath.Join("tests", "snapshots"), + }, + }) + + _, err := LoadSnapshot(nil) + assert.NoError(t, err) + + got, evalErr := GetSwitcher("FF2FOR2022").IsOn() + + assert.NoError(t, evalErr) + assert.True(t, got) + + logged := GetExecution(GetSwitcher("FF2FOR2022")) + assert.Equal(t, "FF2FOR2022", logged.Key) + assert.True(t, logged.Response.Result) + assert.Equal(t, map[string]any{"cached": true}, logged.Response.Metadata) + }) +} diff --git a/remote.go b/remote.go index 74efd21..98cedd8 100644 --- a/remote.go +++ b/remote.go @@ -112,16 +112,11 @@ func (c *Client) checkCriteria(token string, switcher *Switcher, showDetails boo query.Set("showReason", strings.ToLower(strconvFormatBool(showDetails))) query.Set("key", switcher.key) - entries := switcher.entries - if entries == nil { - entries = []criteriaEntry{} - } - response, err := c.doJSONRequest( http.MethodPost, endpoint+"?"+query.Encode(), map[string]any{ - "entry": entries, + "entry": switcher.entries, }, c.authHeaders(token), ) diff --git a/switcher.go b/switcher.go index d75175b..3983575 100644 --- a/switcher.go +++ b/switcher.go @@ -3,14 +3,27 @@ package client import ( "fmt" "strings" + "sync" + "time" ) type Switcher struct { - client *Client - key string - entries []criteriaEntry + client *Client + key string + entries []criteriaEntry + throttlePeriod time.Duration + nextRefreshAt time.Time + mu sync.RWMutex } +type executionMode uint8 + +const ( + executionModeLocal executionMode = iota + executionModeSilentLocal + executionModeRemote +) + func (s *Switcher) Validate() error { ctx := s.client.Context() missingFields := make([]string, 0, 3) @@ -42,6 +55,9 @@ func (s *Switcher) Validate() error { } func (s *Switcher) CheckValue(input string) *Switcher { + s.mu.Lock() + defer s.mu.Unlock() + s.entries = upsertEntry(s.entries, criteriaEntry{ Strategy: StrategyValue, Input: input, @@ -51,6 +67,9 @@ func (s *Switcher) CheckValue(input string) *Switcher { } func (s *Switcher) CheckNetwork(input string) *Switcher { + s.mu.Lock() + defer s.mu.Unlock() + s.entries = upsertEntry(s.entries, criteriaEntry{ Strategy: StrategyNetwork, Input: input, @@ -59,12 +78,32 @@ func (s *Switcher) CheckNetwork(input string) *Switcher { return s } +func (s *Switcher) Throttle(period time.Duration) *Switcher { + s.mu.Lock() + defer s.mu.Unlock() + + s.throttlePeriod = period + if period <= 0 { + s.nextRefreshAt = time.Time{} + return s + } + + if s.nextRefreshAt.IsZero() { + s.nextRefreshAt = time.Now().Add(period) + } + + return s +} + func (s *Switcher) Prepare(key string) error { if strings.TrimSpace(key) != "" { + s.mu.Lock() s.key = key + s.mu.Unlock() } - if err := s.Validate(); err != nil { + execution := s.snapshotForExecution() + if err := execution.Validate(); err != nil { return err } @@ -90,24 +129,83 @@ func (s *Switcher) IsOnWithDetails() (ResultDetail, error) { } func (s *Switcher) submit(showDetails bool) (ResultDetail, error) { + execution := s.snapshotForExecution() + if cached, ok := s.tryCachedResult(execution, showDetails); ok { + return cached, nil + } + + result, err := s.execute(execution, showDetails) + if err != nil { + return ResultDetail{}, err + } + + s.markFreshExecution() + return result, nil +} + +func (s *Switcher) tryCachedResult(execution *Switcher, showDetails bool) (ResultDetail, bool) { + if !execution.hasThrottle() { + return ResultDetail{}, false + } + + entry := execution.client.executionLogger.get(execution.key, execution.entries) + if entry.Key == "" { + return ResultDetail{}, false + } + + if !execution.client.Context().Options.Freeze && s.shouldScheduleRefresh(time.Now()) { + s.scheduleBackgroundRefresh(execution, showDetails) + } + + return entry.Response, true +} + +func (s *Switcher) execute(execution *Switcher, showDetails bool) (ResultDetail, error) { + mode, err := execution.resolveExecutionMode() + if err != nil { + return ResultDetail{}, err + } + + result, err := execution.executeMode(mode, showDetails) + if err != nil { + return ResultDetail{}, err + } + + execution.logResult(result) + return result, nil +} + +func (s *Switcher) resolveExecutionMode() (executionMode, error) { if s.client.Context().Options.Local { - return checkLocalCriteria(s.client.snapshotState(), s) + return executionModeLocal, nil } if err := s.Validate(); err != nil { - return ResultDetail{}, err + return executionModeRemote, err } if s.client.shouldUseLocalSilentMode() { - return checkLocalCriteria(s.client.snapshotState(), s) + return executionModeSilentLocal, nil } - token, err := s.client.ensureToken() - if err != nil { - return s.client.fallbackToSilentMode(s, err) + return executionModeRemote, nil +} + +func (s *Switcher) executeMode(mode executionMode, showDetails bool) (ResultDetail, error) { + if mode == executionModeLocal || mode == executionModeSilentLocal { + return s.executeLocal() } - if err := missingTokenError(token); err != nil { + return s.executeRemote(showDetails) +} + +func (s *Switcher) executeLocal() (ResultDetail, error) { + return checkLocalCriteria(s.client.snapshotState(), s) +} + +func (s *Switcher) executeRemote(showDetails bool) (ResultDetail, error) { + token, err := s.remoteToken() + if err != nil { return s.client.fallbackToSilentMode(s, err) } @@ -119,6 +217,82 @@ func (s *Switcher) submit(showDetails bool) (ResultDetail, error) { return result, nil } +func (s *Switcher) remoteToken() (string, error) { + token, err := s.client.ensureToken() + if err != nil { + return "", err + } + + if err := missingTokenError(token); err != nil { + return "", err + } + + return token, nil +} + +func (s *Switcher) snapshotForExecution() *Switcher { + s.mu.RLock() + defer s.mu.RUnlock() + + clonedEntries := make([]criteriaEntry, len(s.entries)) + copy(clonedEntries, s.entries) + + return &Switcher{ + client: s.client, + key: s.key, + entries: clonedEntries, + throttlePeriod: s.throttlePeriod, + nextRefreshAt: s.nextRefreshAt, + } +} + +func (s *Switcher) logResult(result ResultDetail) { + if !s.canLog() { + return + } + + s.client.executionLogger.add(s.key, s.entries, result) +} + +func (s *Switcher) canLog() bool { + return strings.TrimSpace(s.key) != "" && (s.client.Context().Options.Logger || s.hasThrottle()) +} + +func (s *Switcher) hasThrottle() bool { + return s.throttlePeriod > 0 +} + +func (s *Switcher) markFreshExecution() { + s.mu.Lock() + defer s.mu.Unlock() + + if s.throttlePeriod <= 0 { + return + } + + s.nextRefreshAt = time.Now().Add(s.throttlePeriod) +} + +func (s *Switcher) shouldScheduleRefresh(now time.Time) bool { + s.mu.Lock() + defer s.mu.Unlock() + + if s.throttlePeriod <= 0 || s.nextRefreshAt.IsZero() || !now.After(s.nextRefreshAt) { + return false + } + + s.nextRefreshAt = now.Add(s.throttlePeriod) + return true +} + +func (s *Switcher) scheduleBackgroundRefresh(execution *Switcher, showDetails bool) { + s.client.runBackgroundTask(func() { + if _, err := s.execute(execution, showDetails); err != nil { + s.client.notifyError(err) + } + }) +} + func upsertEntry(entries []criteriaEntry, next criteriaEntry) []criteriaEntry { for i := range entries { if entries[i].Strategy == next.Strategy { diff --git a/switcher_throttle_test.go b/switcher_throttle_test.go new file mode 100644 index 0000000..143b500 --- /dev/null +++ b/switcher_throttle_test.go @@ -0,0 +1,326 @@ +package client + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestSwitcherThrottle(t *testing.T) { + t.Run("should reuse cached results during the throttle window", func(t *testing.T) { + server, criteriaRequests := newThrottleTestServer(t, []map[string]any{ + {"result": true}, + }) + defer server.Close() + + client := newThrottleTestClient(server.URL, ContextOptions{ + ThrottleMaxWorkers: 1, + }) + switcher := client.GetSwitcher("MY_SWITCHER").Throttle(50 * time.Millisecond) + + first, firstErr := switcher.IsOnWithDetails() + second, secondErr := switcher.IsOnWithDetails() + + assert.NoError(t, firstErr) + assert.NoError(t, secondErr) + assert.True(t, first.Result) + assert.Equal(t, map[string]any{}, first.Metadata) + assert.True(t, second.Result) + assert.Equal(t, map[string]any{"cached": true}, second.Metadata) + assert.Equal(t, int32(1), criteriaRequests.Load()) + }) + + t.Run("should disable throttle when the period is zero", func(t *testing.T) { + server, criteriaRequests := newThrottleTestServer(t, []map[string]any{ + {"result": true}, + {"result": false}, + }) + defer server.Close() + + client := newThrottleTestClient(server.URL, ContextOptions{}) + switcher := client.GetSwitcher("MY_SWITCHER").Throttle(0) + + first, firstErr := switcher.IsOnWithDetails() + second, secondErr := switcher.IsOnWithDetails() + + assert.NoError(t, firstErr) + assert.NoError(t, secondErr) + assert.True(t, first.Result) + assert.Equal(t, map[string]any{}, first.Metadata) + assert.False(t, second.Result) + assert.Equal(t, map[string]any{}, second.Metadata) + assert.Equal(t, int32(2), criteriaRequests.Load()) + assert.Equal(t, ExecutionEntry{}, client.GetExecution(client.GetSwitcher("MY_SWITCHER"))) + }) + + t.Run("should not reset the refresh window when throttle is applied again to the cached switcher", func(t *testing.T) { + server, criteriaRequests := newThrottleTestServer(t, []map[string]any{ + {"result": true}, + {"result": false}, + }) + defer server.Close() + + client := newThrottleTestClient(server.URL, ContextOptions{ + ThrottleMaxWorkers: 1, + }) + + first, firstErr := client.GetSwitcher("MY_SWITCHER").Throttle(30 * time.Millisecond).IsOnWithDetails() + assert.NoError(t, firstErr) + assert.True(t, first.Result) + + time.Sleep(40 * time.Millisecond) + + cached, cachedErr := client.GetSwitcher("MY_SWITCHER").Throttle(30 * time.Millisecond).IsOnWithDetails() + assert.NoError(t, cachedErr) + assert.True(t, cached.Result) + assert.Equal(t, map[string]any{"cached": true}, cached.Metadata) + + assert.Eventually(t, func() bool { + return criteriaRequests.Load() == 2 + }, time.Second, 10*time.Millisecond) + + assert.Eventually(t, func() bool { + logged := client.GetExecution(client.GetSwitcher("MY_SWITCHER")) + return logged.Key == "MY_SWITCHER" && !logged.Response.Result + }, time.Second, 10*time.Millisecond) + }) + + t.Run("should refresh cached results in the background after the throttle expires", func(t *testing.T) { + server, criteriaRequests := newThrottleTestServer(t, []map[string]any{ + {"result": true}, + {"result": false}, + }) + defer server.Close() + + client := newThrottleTestClient(server.URL, ContextOptions{ + ThrottleMaxWorkers: 1, + }) + switcher := client.GetSwitcher("MY_SWITCHER").Throttle(30 * time.Millisecond) + + first, firstErr := switcher.IsOnWithDetails() + assert.NoError(t, firstErr) + assert.True(t, first.Result) + + time.Sleep(40 * time.Millisecond) + + cached, cachedErr := switcher.IsOnWithDetails() + assert.NoError(t, cachedErr) + assert.True(t, cached.Result) + assert.Equal(t, map[string]any{"cached": true}, cached.Metadata) + + assert.Eventually(t, func() bool { + return criteriaRequests.Load() == 2 + }, time.Second, 10*time.Millisecond) + + assert.Eventually(t, func() bool { + logged := client.GetExecution(client.GetSwitcher("MY_SWITCHER")) + return logged.Key == "MY_SWITCHER" && !logged.Response.Result + }, time.Second, 10*time.Millisecond) + + refreshed, refreshedErr := switcher.IsOnWithDetails() + assert.NoError(t, refreshedErr) + assert.False(t, refreshed.Result) + assert.Equal(t, map[string]any{"cached": true}, refreshed.Metadata) + }) + + t.Run("should notify subscribed errors when background refresh fails", func(t *testing.T) { + var criteriaRequests atomic.Int32 + mux := http.NewServeMux() + mux.HandleFunc("/criteria/auth", func(writer http.ResponseWriter, request *http.Request) { + assert.Equal(t, http.MethodPost, request.Method) + writeJSONResponse(t, writer, http.StatusOK, map[string]any{ + "token": "[token]", + "exp": time.Now().Add(time.Hour).Unix(), + }) + }) + mux.HandleFunc("/criteria", func(writer http.ResponseWriter, request *http.Request) { + assert.Equal(t, http.MethodPost, request.Method) + + var body map[string]any + err := json.NewDecoder(request.Body).Decode(&body) + assert.NoError(t, err) + + if criteriaRequests.Add(1) == 1 { + writeJSONResponse(t, writer, http.StatusOK, map[string]any{"result": true}) + return + } + + writeJSONResponse(t, writer, http.StatusInternalServerError, map[string]any{"error": "boom"}) + }) + + server := httptest.NewServer(mux) + defer server.Close() + + client := newThrottleTestClient(server.URL, ContextOptions{ + ThrottleMaxWorkers: 1, + }) + switcher := client.GetSwitcher("MY_SWITCHER").Throttle(30 * time.Millisecond) + + errCh := make(chan error, 1) + client.SubscribeNotifyError(func(err error) { + errCh <- err + }) + + first, firstErr := switcher.IsOnWithDetails() + assert.NoError(t, firstErr) + assert.True(t, first.Result) + + time.Sleep(40 * time.Millisecond) + + cached, cachedErr := switcher.IsOnWithDetails() + assert.NoError(t, cachedErr) + assert.True(t, cached.Result) + assert.Equal(t, map[string]any{"cached": true}, cached.Metadata) + + select { + case err := <-errCh: + var remoteCriteriaErr *RemoteCriteriaError + assert.ErrorAs(t, err, &remoteCriteriaErr) + assert.EqualError(t, err, "[check_criteria] failed with status: 500") + case <-time.After(time.Second): + t.Fatal("expected background refresh error notification") + } + + logged := client.GetExecution(client.GetSwitcher("MY_SWITCHER")) + assert.Equal(t, "MY_SWITCHER", logged.Key) + assert.True(t, logged.Response.Result) + assert.Equal(t, map[string]any{"cached": true}, logged.Response.Metadata) + }) + + t.Run("should refresh cached results in the background without a worker limit", func(t *testing.T) { + server, criteriaRequests := newThrottleTestServer(t, []map[string]any{ + {"result": true}, + {"result": false}, + }) + defer server.Close() + + client := newThrottleTestClient(server.URL, ContextOptions{}) + switcher := client.GetSwitcher("MY_SWITCHER").Throttle(30 * time.Millisecond) + + first, firstErr := switcher.IsOnWithDetails() + assert.NoError(t, firstErr) + assert.True(t, first.Result) + + time.Sleep(40 * time.Millisecond) + + cached, cachedErr := switcher.IsOnWithDetails() + assert.NoError(t, cachedErr) + assert.True(t, cached.Result) + assert.Equal(t, map[string]any{"cached": true}, cached.Metadata) + + assert.Eventually(t, func() bool { + return criteriaRequests.Load() == 2 + }, time.Second, 10*time.Millisecond) + + assert.Eventually(t, func() bool { + logged := client.GetExecution(client.GetSwitcher("MY_SWITCHER")) + return logged.Key == "MY_SWITCHER" && !logged.Response.Result + }, time.Second, 10*time.Millisecond) + }) + + t.Run("should keep serving the frozen cached result until the logger is cleared", func(t *testing.T) { + server, criteriaRequests := newThrottleTestServer(t, []map[string]any{ + {"result": true}, + {"result": false}, + }) + defer server.Close() + + client := newThrottleTestClient(server.URL, ContextOptions{ + Freeze: true, + ThrottleMaxWorkers: 1, + }) + switcher := client.GetSwitcher("MY_SWITCHER").Throttle(30 * time.Millisecond) + + first, firstErr := switcher.IsOnWithDetails() + assert.NoError(t, firstErr) + assert.True(t, first.Result) + + time.Sleep(40 * time.Millisecond) + + frozen, frozenErr := switcher.IsOnWithDetails() + assert.NoError(t, frozenErr) + assert.True(t, frozen.Result) + assert.Equal(t, map[string]any{"cached": true}, frozen.Metadata) + assert.Equal(t, int32(1), criteriaRequests.Load()) + + client.ClearLogger() + + refreshed, refreshedErr := switcher.IsOnWithDetails() + assert.NoError(t, refreshedErr) + assert.False(t, refreshed.Result) + assert.Equal(t, map[string]any{}, refreshed.Metadata) + assert.Equal(t, int32(2), criteriaRequests.Load()) + }) + + t.Run("should use cached results for throttle even when logger is disabled", func(t *testing.T) { + server, criteriaRequests := newThrottleTestServer(t, []map[string]any{ + {"result": true}, + }) + defer server.Close() + + client := newThrottleTestClient(server.URL, ContextOptions{ + Logger: false, + ThrottleMaxWorkers: 1, + }) + switcher := client.GetSwitcher("MY_SWITCHER").Throttle(50 * time.Millisecond) + + _, firstErr := switcher.IsOnWithDetails() + second, secondErr := switcher.IsOnWithDetails() + + assert.NoError(t, firstErr) + assert.NoError(t, secondErr) + assert.True(t, second.Result) + assert.Equal(t, map[string]any{"cached": true}, second.Metadata) + assert.Equal(t, int32(1), criteriaRequests.Load()) + + logged := client.GetExecution(client.GetSwitcher("MY_SWITCHER")) + assert.Equal(t, "MY_SWITCHER", logged.Key) + assert.True(t, logged.Response.Result) + assert.Equal(t, map[string]any{"cached": true}, logged.Response.Metadata) + }) +} + +func newThrottleTestClient(serverURL string, options ContextOptions) *Client { + return NewClient(Context{ + Domain: "My Domain", + URL: serverURL, + APIKey: "[YOUR_API_KEY]", + Component: "MyApp", + Options: options, + }) +} + +func newThrottleTestServer(t *testing.T, criteriaResponses []map[string]any) (*httptest.Server, *atomic.Int32) { + t.Helper() + + var criteriaRequests atomic.Int32 + mux := http.NewServeMux() + mux.HandleFunc("/criteria/auth", func(writer http.ResponseWriter, request *http.Request) { + assert.Equal(t, http.MethodPost, request.Method) + writeJSONResponse(t, writer, http.StatusOK, map[string]any{ + "token": "[token]", + "exp": time.Now().Add(time.Hour).Unix(), + }) + }) + mux.HandleFunc("/criteria", func(writer http.ResponseWriter, request *http.Request) { + assert.Equal(t, http.MethodPost, request.Method) + + var body map[string]any + err := json.NewDecoder(request.Body).Decode(&body) + assert.NoError(t, err) + + index := int(criteriaRequests.Add(1)) - 1 + if index >= len(criteriaResponses) { + index = len(criteriaResponses) - 1 + } + + writeJSONResponse(t, writer, http.StatusOK, criteriaResponses[index]) + }) + + return httptest.NewServer(mux), &criteriaRequests +}