From 540485f3332b258a69cd3c34ad425ea41cb26e30 Mon Sep 17 00:00:00 2001 From: Thomas Maurer Date: Thu, 2 Apr 2026 20:52:04 +0200 Subject: [PATCH 1/6] refactor: switch cache to per-register storage Cache keys now address individual registers/coils instead of request ranges. This fixes stale data when writes hit registers that are part of a larger cached read range. Key changes: - RegKey(slaveID, fc, addr) for per-register storage - RangeKey(slaveID, fc, addr, qty) for request coalescing - GetRange/SetRange/DeleteRange for batch operations - Rename GetOrFetch to Coalesce (no longer interacts with cache) - Add keepStale flag to prevent cleanup from removing entries needed for stale-serve fallback --- internal/cache/cache.go | 130 +++++++++++++++----- internal/cache/cache_test.go | 222 +++++++++++++++++++++++++++++------ 2 files changed, 286 insertions(+), 66 deletions(-) diff --git a/internal/cache/cache.go b/internal/cache/cache.go index f76648e..5ff1597 100644 --- a/internal/cache/cache.go +++ b/internal/cache/cache.go @@ -20,11 +20,12 @@ func (e *Entry) IsExpired() bool { return time.Since(e.Timestamp) > e.TTL } -// Cache is a thread-safe in-memory cache with TTL. +// Cache is a thread-safe in-memory cache with TTL and per-register storage. type Cache struct { mu sync.RWMutex entries map[string]*Entry defaultTTL time.Duration + keepStale bool // when true, cleanup won't delete expired entries // For request coalescing inflight map[string]*inflightRequest @@ -41,10 +42,12 @@ type inflightRequest struct { } // New creates a new cache with the specified default TTL. -func New(defaultTTL time.Duration) *Cache { +// If keepStale is true, expired entries are retained for stale serving. +func New(defaultTTL time.Duration, keepStale bool) *Cache { c := &Cache{ entries: make(map[string]*Entry), defaultTTL: defaultTTL, + keepStale: keepStale, inflight: make(map[string]*inflightRequest), done: make(chan struct{}), } @@ -60,9 +63,14 @@ func (c *Cache) Close() { close(c.done) } -// Key generates a cache key from request parameters. -func Key(slaveID byte, functionCode byte, address uint16, quantity uint16) string { - return fmt.Sprintf("%d:%d:%d:%d", slaveID, functionCode, address, quantity) +// RegKey generates a cache key for a single register or coil. +func RegKey(slaveID byte, functionCode byte, address uint16) string { + return fmt.Sprintf("%d:%d:%d", slaveID, functionCode, address) +} + +// RangeKey generates a coalescing key for a request range. +func RangeKey(slaveID byte, functionCode byte, startAddr uint16, quantity uint16) string { + return fmt.Sprintf("%d:%d:%d:%d", slaveID, functionCode, startAddr, quantity) } // Get retrieves a value from the cache. @@ -100,11 +108,6 @@ func (c *Cache) GetStale(key string) ([]byte, bool) { // Set stores a value in the cache with the default TTL. func (c *Cache) Set(key string, data []byte) { - c.SetWithTTL(key, data, c.defaultTTL) -} - -// SetWithTTL stores a value in the cache with a specific TTL. -func (c *Cache) SetWithTTL(key string, data []byte, ttl time.Duration) { c.mu.Lock() defer c.mu.Unlock() @@ -115,7 +118,7 @@ func (c *Cache) SetWithTTL(key string, data []byte, ttl time.Duration) { c.entries[key] = &Entry{ Data: dataCopy, Timestamp: time.Now(), - TTL: ttl, + TTL: c.defaultTTL, } } @@ -126,17 +129,80 @@ func (c *Cache) Delete(key string) { delete(c.entries, key) } -// GetOrFetch retrieves a value from the cache or fetches it using the provided function. -// This implements request coalescing - multiple concurrent requests for the same key -// will share a single fetch operation. -// Returns the data, a boolean indicating if it was a cache hit, and any error. -func (c *Cache) GetOrFetch(ctx context.Context, key string, fetch func(context.Context) ([]byte, error)) ([]byte, bool, error) { - // Check cache first - if data, ok := c.Get(key); ok { - return data, true, nil +// GetRange retrieves all values for a contiguous register range. +// Returns the per-register/coil values and true only if ALL are cached and fresh. +func (c *Cache) GetRange(slaveID byte, functionCode byte, startAddr uint16, quantity uint16) ([][]byte, bool) { + c.mu.RLock() + defer c.mu.RUnlock() + + values := make([][]byte, quantity) + for i := uint16(0); i < quantity; i++ { + key := RegKey(slaveID, functionCode, startAddr+i) + entry, ok := c.entries[key] + if !ok || entry.IsExpired() { + return nil, false + } + data := make([]byte, len(entry.Data)) + copy(data, entry.Data) + values[i] = data + } + return values, true +} + +// GetRangeStale retrieves all values for a contiguous register range, ignoring TTL. +// Returns the per-register/coil values and true only if ALL are present (even if expired). +func (c *Cache) GetRangeStale(slaveID byte, functionCode byte, startAddr uint16, quantity uint16) ([][]byte, bool) { + c.mu.RLock() + defer c.mu.RUnlock() + + values := make([][]byte, quantity) + for i := uint16(0); i < quantity; i++ { + key := RegKey(slaveID, functionCode, startAddr+i) + entry, ok := c.entries[key] + if !ok { + return nil, false + } + data := make([]byte, len(entry.Data)) + copy(data, entry.Data) + values[i] = data + } + return values, true +} + +// SetRange stores individual values for a contiguous register range. +// All entries are stored with the same timestamp for consistency. +func (c *Cache) SetRange(slaveID byte, functionCode byte, startAddr uint16, values [][]byte) { + c.mu.Lock() + defer c.mu.Unlock() + + now := time.Now() + for i, v := range values { + key := RegKey(slaveID, functionCode, startAddr+uint16(i)) + dataCopy := make([]byte, len(v)) + copy(dataCopy, v) + c.entries[key] = &Entry{ + Data: dataCopy, + Timestamp: now, + TTL: c.defaultTTL, + } } +} + +// DeleteRange removes all entries for a contiguous register range. +func (c *Cache) DeleteRange(slaveID byte, functionCode byte, startAddr uint16, quantity uint16) { + c.mu.Lock() + defer c.mu.Unlock() - // Check if there's already an in-flight request + for i := uint16(0); i < quantity; i++ { + key := RegKey(slaveID, functionCode, startAddr+i) + delete(c.entries, key) + } +} + +// Coalesce ensures only one fetch runs for a given key at a time. +// Other callers with the same key wait for and share the first caller's result. +// This handles request coalescing only — it does not interact with cache storage. +func (c *Cache) Coalesce(ctx context.Context, key string, fetch func(context.Context) ([]byte, error)) ([]byte, error) { c.inflightMu.Lock() if req, ok := c.inflight[key]; ok { c.inflightMu.Unlock() @@ -144,14 +210,14 @@ func (c *Cache) GetOrFetch(ctx context.Context, key string, fetch func(context.C select { case <-req.done: if req.err != nil { - return nil, false, req.err + return nil, req.err } // Return a copy data := make([]byte, len(req.result)) copy(data, req.result) - return data, false, nil + return data, nil case <-ctx.Done(): - return nil, false, ctx.Err() + return nil, ctx.Err() } } @@ -165,22 +231,23 @@ func (c *Cache) GetOrFetch(ctx context.Context, key string, fetch func(context.C // Fetch the data data, err := fetch(ctx) - // Store result + // Store result for waiters req.result = data req.err = err - // Cache successful results - if err == nil { - c.Set(key, data) - } - // Clean up and notify waiters c.inflightMu.Lock() delete(c.inflight, key) c.inflightMu.Unlock() close(req.done) - return data, false, err + if err != nil { + return nil, err + } + + result := make([]byte, len(data)) + copy(result, data) + return result, nil } // cleanup periodically removes expired entries. @@ -193,6 +260,9 @@ func (c *Cache) cleanup() { case <-c.done: return case <-ticker.C: + if c.keepStale { + continue + } c.mu.Lock() for key, entry := range c.entries { if entry.IsExpired() { diff --git a/internal/cache/cache_test.go b/internal/cache/cache_test.go index 8741070..acbc00c 100644 --- a/internal/cache/cache_test.go +++ b/internal/cache/cache_test.go @@ -9,7 +9,7 @@ import ( ) func TestCache_GetSet(t *testing.T) { - c := New(time.Second) + c := New(time.Second, false) defer c.Close() // Test miss @@ -29,7 +29,7 @@ func TestCache_GetSet(t *testing.T) { } func TestCache_TTL(t *testing.T) { - c := New(50 * time.Millisecond) + c := New(50*time.Millisecond, false) defer c.Close() c.Set("key1", []byte("value1")) @@ -49,7 +49,7 @@ func TestCache_TTL(t *testing.T) { } func TestCache_GetStale(t *testing.T) { - c := New(50 * time.Millisecond) + c := New(50*time.Millisecond, false) defer c.Close() c.Set("key1", []byte("value1")) @@ -68,7 +68,7 @@ func TestCache_GetStale(t *testing.T) { } func TestCache_Delete(t *testing.T) { - c := New(time.Second) + c := New(time.Second, false) defer c.Close() c.Set("key1", []byte("value1")) @@ -79,16 +79,129 @@ func TestCache_Delete(t *testing.T) { } } -func TestCache_Key(t *testing.T) { - key := Key(1, 0x03, 100, 10) +func TestRegKey(t *testing.T) { + key := RegKey(1, 0x03, 100) + expected := "1:3:100" + if key != expected { + t.Errorf("expected %s, got %s", expected, key) + } +} + +func TestRangeKey(t *testing.T) { + key := RangeKey(1, 0x03, 100, 10) expected := "1:3:100:10" if key != expected { t.Errorf("expected %s, got %s", expected, key) } } -func TestCache_GetOrFetch(t *testing.T) { - c := New(time.Second) +func TestCache_GetRange(t *testing.T) { + c := New(time.Second, false) + defer c.Close() + + // Store 3 registers + c.Set(RegKey(1, 0x03, 10), []byte{0x00, 0x01}) + c.Set(RegKey(1, 0x03, 11), []byte{0x00, 0x02}) + c.Set(RegKey(1, 0x03, 12), []byte{0x00, 0x03}) + + // Full range hit + values, ok := c.GetRange(1, 0x03, 10, 3) + if !ok { + t.Error("expected range hit") + } + if len(values) != 3 { + t.Fatalf("expected 3 values, got %d", len(values)) + } + for i, expected := range []byte{0x01, 0x02, 0x03} { + if values[i][1] != expected { + t.Errorf("value[%d]: expected 0x%02X, got 0x%02X", i, expected, values[i][1]) + } + } + + // Partial range miss + _, ok = c.GetRange(1, 0x03, 10, 5) + if ok { + t.Error("expected range miss (registers 13-14 not cached)") + } +} + +func TestCache_SetRange(t *testing.T) { + c := New(time.Second, false) + defer c.Close() + + values := [][]byte{{0x00, 0x0A}, {0x00, 0x0B}} + c.SetRange(1, 0x03, 100, values) + + // Each register should be independently accessible + data, ok := c.Get(RegKey(1, 0x03, 100)) + if !ok { + t.Error("expected hit for register 100") + } + if data[1] != 0x0A { + t.Errorf("expected 0x0A, got 0x%02X", data[1]) + } + + data, ok = c.Get(RegKey(1, 0x03, 101)) + if !ok { + t.Error("expected hit for register 101") + } + if data[1] != 0x0B { + t.Errorf("expected 0x0B, got 0x%02X", data[1]) + } +} + +func TestCache_DeleteRange(t *testing.T) { + c := New(time.Second, false) + defer c.Close() + + values := [][]byte{{0x00, 0x0A}, {0x00, 0x0B}, {0x00, 0x0C}} + c.SetRange(1, 0x03, 100, values) + + // Delete middle register + c.DeleteRange(1, 0x03, 101, 1) + + // Register 100 still cached + if _, ok := c.Get(RegKey(1, 0x03, 100)); !ok { + t.Error("register 100 should still be cached") + } + // Register 101 deleted + if _, ok := c.Get(RegKey(1, 0x03, 101)); ok { + t.Error("register 101 should be deleted") + } + // Register 102 still cached + if _, ok := c.Get(RegKey(1, 0x03, 102)); !ok { + t.Error("register 102 should still be cached") + } + // Full range now misses + if _, ok := c.GetRange(1, 0x03, 100, 3); ok { + t.Error("expected range miss after deleting register 101") + } +} + +func TestCache_GetRangeStale(t *testing.T) { + c := New(50*time.Millisecond, false) + defer c.Close() + + c.SetRange(1, 0x03, 10, [][]byte{{0x00, 0x01}, {0x00, 0x02}}) + time.Sleep(100 * time.Millisecond) + + // Fresh get should miss + if _, ok := c.GetRange(1, 0x03, 10, 2); ok { + t.Error("expected range miss after TTL") + } + + // Stale get should succeed + values, ok := c.GetRangeStale(1, 0x03, 10, 2) + if !ok { + t.Error("expected stale range hit") + } + if len(values) != 2 { + t.Fatalf("expected 2 stale values, got %d", len(values)) + } +} + +func TestCache_Coalesce(t *testing.T) { + c := New(time.Second, false) defer c.Close() ctx := context.Background() @@ -98,39 +211,20 @@ func TestCache_GetOrFetch(t *testing.T) { return []byte("fetched"), nil } - // First call should fetch (cache miss) - data, hit, err := c.GetOrFetch(ctx, "key1", fetch) + data, err := c.Coalesce(ctx, "key1", fetch) if err != nil { t.Errorf("unexpected error: %v", err) } - if hit { - t.Error("expected cache miss on first call") - } if string(data) != "fetched" { t.Errorf("expected fetched, got %s", string(data)) } if fetchCount != 1 { t.Errorf("expected 1 fetch, got %d", fetchCount) } - - // Second call should hit cache - data, hit, err = c.GetOrFetch(ctx, "key1", fetch) - if err != nil { - t.Errorf("unexpected error: %v", err) - } - if !hit { - t.Error("expected cache hit on second call") - } - if string(data) != "fetched" { - t.Errorf("expected fetched, got %s", string(data)) - } - if fetchCount != 1 { - t.Errorf("expected 1 fetch (cache hit), got %d", fetchCount) - } } -func TestCache_RequestCoalescing(t *testing.T) { - c := New(time.Second) +func TestCache_CoalescingConcurrent(t *testing.T) { + c := New(time.Second, false) defer c.Close() ctx := context.Background() @@ -153,7 +247,7 @@ func TestCache_RequestCoalescing(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - results[0], _, errors[0] = c.GetOrFetch(ctx, "key1", fetch) + results[0], errors[0] = c.Coalesce(ctx, "key1", fetch) }() // Wait for fetch to start @@ -165,7 +259,7 @@ func TestCache_RequestCoalescing(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - results[i], _, errors[i] = c.GetOrFetch(ctx, "key1", func(ctx context.Context) ([]byte, error) { + results[i], errors[i] = c.Coalesce(ctx, "key1", func(ctx context.Context) ([]byte, error) { atomic.AddInt32(&fetchCount, 1) return []byte("should not be called"), nil }) @@ -196,7 +290,7 @@ func TestCache_RequestCoalescing(t *testing.T) { } func TestCache_ContextCancellation(t *testing.T) { - c := New(time.Second) + c := New(time.Second, false) defer c.Close() ctx, cancel := context.WithCancel(context.Background()) @@ -204,7 +298,7 @@ func TestCache_ContextCancellation(t *testing.T) { // Start a slow fetch go func() { - c.GetOrFetch(ctx, "key1", func(ctx context.Context) ([]byte, error) { + c.Coalesce(ctx, "key1", func(ctx context.Context) ([]byte, error) { close(fetchStarted) time.Sleep(time.Second) return []byte("fetched"), nil @@ -217,7 +311,7 @@ func TestCache_ContextCancellation(t *testing.T) { ctx2, cancel2 := context.WithCancel(context.Background()) cancel2() // Cancel immediately - _, _, err := c.GetOrFetch(ctx2, "key1", func(ctx context.Context) ([]byte, error) { + _, err := c.Coalesce(ctx2, "key1", func(ctx context.Context) ([]byte, error) { return []byte("should not be called"), nil }) @@ -229,7 +323,7 @@ func TestCache_ContextCancellation(t *testing.T) { } func TestCache_DataIsolation(t *testing.T) { - c := New(time.Second) + c := New(time.Second, false) defer c.Close() original := []byte("original") @@ -253,3 +347,59 @@ func TestCache_DataIsolation(t *testing.T) { t.Error("cache data was mutated via returned slice") } } + +func TestCache_RangeDataIsolation(t *testing.T) { + c := New(time.Second, false) + defer c.Close() + + original := [][]byte{{0x00, 0x01}, {0x00, 0x02}} + c.SetRange(1, 0x03, 0, original) + + // Mutate original + original[0][0] = 0xFF + + // Cache should be unaffected + values, ok := c.GetRange(1, 0x03, 0, 2) + if !ok { + t.Error("expected range hit") + } + if values[0][0] != 0x00 { + t.Error("cache data was mutated via original slice") + } +} + +func TestCache_KeepStale(t *testing.T) { + // With keepStale=false, cleanup removes expired entries + c := New(50*time.Millisecond, false) + c.Set("key1", []byte("value1")) + time.Sleep(100 * time.Millisecond) + + // Simulate cleanup + c.mu.Lock() + for key, entry := range c.entries { + if entry.IsExpired() { + delete(c.entries, key) + } + } + c.mu.Unlock() + + if _, ok := c.GetStale("key1"); ok { + t.Error("expected stale data to be gone after cleanup with keepStale=false") + } + c.Close() + + // With keepStale=true, expired entries survive cleanup + c2 := New(50*time.Millisecond, true) + c2.Set("key1", []byte("value1")) + time.Sleep(100 * time.Millisecond) + + // Entry should still be accessible via GetStale + data, ok := c2.GetStale("key1") + if !ok { + t.Error("expected stale data to survive with keepStale=true") + } + if string(data) != "value1" { + t.Errorf("expected value1, got %s", string(data)) + } + c2.Close() +} From 8f55cc90246569c74835b2888cbcb773ffd003e5 Mon Sep 17 00:00:00 2001 From: Thomas Maurer Date: Thu, 2 Apr 2026 20:52:10 +0200 Subject: [PATCH 2/6] refactor: update proxy for per-register cache Decompose upstream responses into per-register cache entries and reassemble from cache on hits. Write invalidation now correctly removes individual registers in the written range, fixing stale data when writes overlap with larger cached read ranges. New helpers: - decomposeResponse: extract per-register values from Modbus PDU - assembleResponse: reconstruct Modbus PDU from cached values - Roundtrip tests for all function codes (registers + coils) - Tests verifying write invalidation of overlapping reads --- internal/proxy/proxy.go | 120 ++++++++++++++--- internal/proxy/proxy_test.go | 244 +++++++++++++++++++++++++++++++++-- 2 files changed, 338 insertions(+), 26 deletions(-) diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 2589be1..aebdb68 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -28,7 +28,7 @@ func New(cfg *config.Config, logger *slog.Logger) (*Proxy, error) { cfg: cfg, logger: logger, client: modbus.NewClient(cfg.Upstream, cfg.Timeout, cfg.RequestDelay, cfg.ConnectDelay, logger), - cache: cache.New(cfg.CacheTTL), + cache: cache.New(cfg.CacheTTL, cfg.CacheServeStale), } p.server = modbus.NewServer(p, logger) @@ -109,10 +109,21 @@ func (p *Proxy) HandleRequest(ctx context.Context, req *modbus.Request) ([]byte, } func (p *Proxy) handleRead(ctx context.Context, req *modbus.Request) ([]byte, error) { - key := cache.Key(req.SlaveID, req.FunctionCode, req.Address, req.Quantity) + // Check per-register cache + values, cacheHit := p.cache.GetRange(req.SlaveID, req.FunctionCode, req.Address, req.Quantity) + if cacheHit { + p.logger.Debug("cache hit", + "slave_id", req.SlaveID, + "func", fmt.Sprintf("0x%02X", req.FunctionCode), + "addr", req.Address, + "qty", req.Quantity, + ) + return assembleResponse(req.FunctionCode, req.Quantity, values), nil + } - // Use GetOrFetch for request coalescing - data, cacheHit, err := p.cache.GetOrFetch(ctx, key, func(ctx context.Context) ([]byte, error) { + // Cache miss — fetch with coalescing + rangeKey := cache.RangeKey(req.SlaveID, req.FunctionCode, req.Address, req.Quantity) + data, err := p.cache.Coalesce(ctx, rangeKey, func(ctx context.Context) ([]byte, error) { p.logger.Debug("cache miss", "slave_id", req.SlaveID, "func", fmt.Sprintf("0x%02X", req.FunctionCode), @@ -126,24 +137,21 @@ func (p *Proxy) handleRead(ctx context.Context, req *modbus.Request) ([]byte, er if err != nil { // Try serving stale data if configured if p.cfg.CacheServeStale { - if stale, ok := p.cache.GetStale(key); ok { + if staleValues, ok := p.cache.GetRangeStale(req.SlaveID, req.FunctionCode, req.Address, req.Quantity); ok { p.logger.Warn("upstream error, serving stale", "slave_id", req.SlaveID, "error", err, ) - return stale, nil + return assembleResponse(req.FunctionCode, req.Quantity, staleValues), nil } } return nil, err } - if cacheHit { - p.logger.Debug("cache hit", - "slave_id", req.SlaveID, - "func", fmt.Sprintf("0x%02X", req.FunctionCode), - "addr", req.Address, - "qty", req.Quantity, - ) + // Decompose response and store per-register + regValues := decomposeResponse(req.FunctionCode, req.Quantity, data) + if regValues != nil { + p.cache.SetRange(req.SlaveID, req.FunctionCode, req.Address, regValues) } return data, nil @@ -176,7 +184,7 @@ func (p *Proxy) handleWrite(ctx context.Context, req *modbus.Request) ([]byte, e return nil, err } - // Invalidate exact matching cache entries for all read function codes + // Invalidate per-register cache entries for the written range p.invalidateCache(req) return resp, nil @@ -186,7 +194,7 @@ func (p *Proxy) handleWrite(ctx context.Context, req *modbus.Request) ([]byte, e } func (p *Proxy) invalidateCache(req *modbus.Request) { - // Invalidate exact matches for all read function codes that could overlap + // Invalidate per-register entries for all read function codes readFuncs := []byte{ modbus.FuncReadCoils, modbus.FuncReadDiscreteInputs, @@ -195,9 +203,87 @@ func (p *Proxy) invalidateCache(req *modbus.Request) { } for _, fc := range readFuncs { - key := cache.Key(req.SlaveID, fc, req.Address, req.Quantity) - p.cache.Delete(key) + p.cache.DeleteRange(req.SlaveID, fc, req.Address, req.Quantity) + } +} + +// decomposeResponse extracts per-register/coil values from a Modbus read response. +// Response format: [funcCode, byteCount, data...] +// For registers (FC 0x03, 0x04): each register is 2 bytes. +// For coils/discrete inputs (FC 0x01, 0x02): each coil is 1 bit, stored as 1 byte (0 or 1). +func decomposeResponse(functionCode byte, quantity uint16, data []byte) [][]byte { + if len(data) < 2 { + return nil + } + + payload := data[2:] // Skip funcCode and byteCount + + switch functionCode { + case modbus.FuncReadHoldingRegisters, modbus.FuncReadInputRegisters: + values := make([][]byte, quantity) + for i := uint16(0); i < quantity; i++ { + offset := i * 2 + if int(offset+2) > len(payload) { + return nil + } + reg := make([]byte, 2) + copy(reg, payload[offset:offset+2]) + values[i] = reg + } + return values + + case modbus.FuncReadCoils, modbus.FuncReadDiscreteInputs: + values := make([][]byte, quantity) + for i := uint16(0); i < quantity; i++ { + byteIdx := i / 8 + bitIdx := i % 8 + if int(byteIdx) >= len(payload) { + return nil + } + if payload[byteIdx]&(1<= 2 { + resp[2+i*2] = v[0] + resp[2+i*2+1] = v[1] + } + } + return resp + + case modbus.FuncReadCoils, modbus.FuncReadDiscreteInputs: + byteCount := (quantity + 7) / 8 + resp := make([]byte, 2+byteCount) + resp[0] = functionCode + resp[1] = byte(byteCount) + for i, v := range values { + if len(v) > 0 && v[0] != 0 { + byteIdx := i / 8 + bitIdx := uint(i % 8) + resp[2+byteIdx] |= 1 << bitIdx + } + } + return resp + } + + return nil } func (p *Proxy) buildFakeWriteResponse(req *modbus.Request) []byte { diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index 99364cc..f36d764 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -26,7 +26,8 @@ func (m *mockClient) Execute(ctx context.Context, req *modbus.Request) ([]byte, func TestProxy_HandleReadCacheHit(t *testing.T) { logger := slog.New(slog.NewTextHandler(io.Discard, nil)) - c := cache.New(time.Second) + c := cache.New(time.Second, false) + defer c.Close() p := &Proxy{ cfg: &config.Config{ @@ -38,15 +39,17 @@ func TestProxy_HandleReadCacheHit(t *testing.T) { cache: c, } - // Pre-populate cache - key := cache.Key(1, modbus.FuncReadHoldingRegisters, 0, 10) - c.Set(key, []byte{0x03, 0x14, 0x00, 0x01}) // Function code + byte count + data + // Pre-populate cache with per-register values + c.SetRange(1, modbus.FuncReadHoldingRegisters, 0, [][]byte{ + {0x00, 0x01}, + {0x00, 0x02}, + }) req := &modbus.Request{ SlaveID: 1, FunctionCode: modbus.FuncReadHoldingRegisters, Address: 0, - Quantity: 10, + Quantity: 2, } resp, err := p.HandleRequest(context.Background(), req) @@ -54,8 +57,15 @@ func TestProxy_HandleReadCacheHit(t *testing.T) { t.Fatalf("unexpected error: %v", err) } - if string(resp) != string([]byte{0x03, 0x14, 0x00, 0x01}) { - t.Errorf("unexpected response: %v", resp) + // Expected assembled response: funcCode + byteCount + reg0 + reg1 + expected := []byte{0x03, 0x04, 0x00, 0x01, 0x00, 0x02} + if len(resp) != len(expected) { + t.Fatalf("expected %d bytes, got %d", len(expected), len(resp)) + } + for i := range expected { + if resp[i] != expected[i] { + t.Errorf("byte %d: expected 0x%02X, got 0x%02X", i, expected[i], resp[i]) + } } } @@ -73,12 +83,15 @@ func TestProxy_HandleWriteReadOnlyMode(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + c := cache.New(time.Second, false) + defer c.Close() + p := &Proxy{ cfg: &config.Config{ ReadOnly: tt.mode, }, logger: logger, - cache: cache.New(time.Second), + cache: c, } req := &modbus.Request{ @@ -104,13 +117,15 @@ func TestProxy_HandleWriteReadOnlyMode(t *testing.T) { func TestProxy_HandleUnknownFunction(t *testing.T) { logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + c := cache.New(time.Second, false) + defer c.Close() p := &Proxy{ cfg: &config.Config{ ReadOnly: config.ReadOnlyOn, }, logger: logger, - cache: cache.New(time.Second), + cache: c, } req := &modbus.Request{ @@ -183,3 +198,214 @@ func TestProxy_BuildFakeWriteResponse(t *testing.T) { }) } } + +func TestDecomposeResponse_Registers(t *testing.T) { + // Response: FC 0x03, byteCount=4, reg0=0x0001, reg1=0x0002 + data := []byte{0x03, 0x04, 0x00, 0x01, 0x00, 0x02} + values := decomposeResponse(modbus.FuncReadHoldingRegisters, 2, data) + + if len(values) != 2 { + t.Fatalf("expected 2 values, got %d", len(values)) + } + if values[0][0] != 0x00 || values[0][1] != 0x01 { + t.Errorf("reg0: expected 0x0001, got 0x%02X%02X", values[0][0], values[0][1]) + } + if values[1][0] != 0x00 || values[1][1] != 0x02 { + t.Errorf("reg1: expected 0x0002, got 0x%02X%02X", values[1][0], values[1][1]) + } +} + +func TestDecomposeResponse_Coils(t *testing.T) { + // Response: FC 0x01, byteCount=2, coils 0-9 + // 0xCD = 1100_1101: coils 0,2,3,6,7 on + // 0x01 = 0000_0001: coil 8 on + data := []byte{0x01, 0x02, 0xCD, 0x01} + values := decomposeResponse(modbus.FuncReadCoils, 10, data) + + if len(values) != 10 { + t.Fatalf("expected 10 values, got %d", len(values)) + } + + expected := []byte{1, 0, 1, 1, 0, 0, 1, 1, 1, 0} + for i, exp := range expected { + if values[i][0] != exp { + t.Errorf("coil %d: expected %d, got %d", i, exp, values[i][0]) + } + } +} + +func TestAssembleResponse_Registers(t *testing.T) { + values := [][]byte{{0x00, 0x01}, {0x00, 0x02}} + resp := assembleResponse(modbus.FuncReadHoldingRegisters, 2, values) + + expected := []byte{0x03, 0x04, 0x00, 0x01, 0x00, 0x02} + if len(resp) != len(expected) { + t.Fatalf("expected %d bytes, got %d", len(expected), len(resp)) + } + for i := range expected { + if resp[i] != expected[i] { + t.Errorf("byte %d: expected 0x%02X, got 0x%02X", i, expected[i], resp[i]) + } + } +} + +func TestAssembleResponse_Coils(t *testing.T) { + // Coils 0,2,3,6,7 on, 8 on — should produce 0xCD 0x01 + values := [][]byte{{1}, {0}, {1}, {1}, {0}, {0}, {1}, {1}, {1}, {0}} + resp := assembleResponse(modbus.FuncReadCoils, 10, values) + + expected := []byte{0x01, 0x02, 0xCD, 0x01} + if len(resp) != len(expected) { + t.Fatalf("expected %d bytes, got %d", len(expected), len(resp)) + } + for i := range expected { + if resp[i] != expected[i] { + t.Errorf("byte %d: expected 0x%02X, got 0x%02X", i, expected[i], resp[i]) + } + } +} + +func TestDecomposeAssemble_Roundtrip(t *testing.T) { + tests := []struct { + name string + funcCode byte + quantity uint16 + data []byte + }{ + { + name: "holding registers", + funcCode: modbus.FuncReadHoldingRegisters, + quantity: 3, + data: []byte{0x03, 0x06, 0x00, 0x01, 0x00, 0x02, 0x00, 0x03}, + }, + { + name: "input registers", + funcCode: modbus.FuncReadInputRegisters, + quantity: 2, + data: []byte{0x04, 0x04, 0xFF, 0xFF, 0x00, 0x00}, + }, + { + name: "coils", + funcCode: modbus.FuncReadCoils, + quantity: 10, + data: []byte{0x01, 0x02, 0xCD, 0x01}, + }, + { + name: "discrete inputs", + funcCode: modbus.FuncReadDiscreteInputs, + quantity: 8, + data: []byte{0x02, 0x01, 0xAC}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + values := decomposeResponse(tt.funcCode, tt.quantity, tt.data) + if values == nil { + t.Fatal("decomposeResponse returned nil") + } + + reassembled := assembleResponse(tt.funcCode, tt.quantity, values) + if len(reassembled) != len(tt.data) { + t.Fatalf("length mismatch: expected %d, got %d", len(tt.data), len(reassembled)) + } + for i := range tt.data { + if reassembled[i] != tt.data[i] { + t.Errorf("byte %d: expected 0x%02X, got 0x%02X", i, tt.data[i], reassembled[i]) + } + } + }) + } +} + +func TestProxy_WriteInvalidatesOverlappingReads(t *testing.T) { + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + c := cache.New(time.Second, false) + defer c.Close() + + p := &Proxy{ + cfg: &config.Config{ + ReadOnly: config.ReadOnlyOn, + }, + logger: logger, + cache: c, + } + + // Cache registers 0-9 (simulating a previous read of range 0-9) + regs := make([][]byte, 10) + for i := range regs { + regs[i] = []byte{0x00, byte(i)} + } + c.SetRange(1, modbus.FuncReadHoldingRegisters, 0, regs) + + // Write to register 5 — should invalidate register 5 + p.invalidateCache(&modbus.Request{ + SlaveID: 1, + FunctionCode: modbus.FuncWriteSingleRegister, + Address: 5, + Quantity: 1, + }) + + // Full range 0-9 should now miss (register 5 is gone) + _, ok := c.GetRange(1, modbus.FuncReadHoldingRegisters, 0, 10) + if ok { + t.Error("expected range miss after write invalidation of register 5") + } + + // Registers 0-4 and 6-9 should still be cached individually + for i := uint16(0); i < 10; i++ { + _, ok := c.Get(cache.RegKey(1, modbus.FuncReadHoldingRegisters, i)) + if i == 5 { + if ok { + t.Error("register 5 should be invalidated") + } + } else { + if !ok { + t.Errorf("register %d should still be cached", i) + } + } + } +} + +func TestProxy_WriteInvalidatesMultipleRegisters(t *testing.T) { + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + c := cache.New(time.Second, false) + defer c.Close() + + p := &Proxy{ + cfg: &config.Config{ + ReadOnly: config.ReadOnlyOn, + }, + logger: logger, + cache: c, + } + + // Cache registers 0-9 + regs := make([][]byte, 10) + for i := range regs { + regs[i] = []byte{0x00, byte(i)} + } + c.SetRange(1, modbus.FuncReadHoldingRegisters, 0, regs) + + // Write to registers 3-5 (write multiple) + p.invalidateCache(&modbus.Request{ + SlaveID: 1, + FunctionCode: modbus.FuncWriteMultipleRegs, + Address: 3, + Quantity: 3, + }) + + // Registers 3,4,5 should be gone + for i := uint16(3); i <= 5; i++ { + if _, ok := c.Get(cache.RegKey(1, modbus.FuncReadHoldingRegisters, i)); ok { + t.Errorf("register %d should be invalidated", i) + } + } + + // Registers 0,1,2,6,7,8,9 should still be cached + for _, i := range []uint16{0, 1, 2, 6, 7, 8, 9} { + if _, ok := c.Get(cache.RegKey(1, modbus.FuncReadHoldingRegisters, i)); !ok { + t.Errorf("register %d should still be cached", i) + } + } +} From 43f09a26322a19e51ba8cb3800fc9c5397308c17 Mon Sep 17 00:00:00 2001 From: Thomas Maurer Date: Thu, 2 Apr 2026 21:10:03 +0200 Subject: [PATCH 3/6] fix: log cache miss for all callers, not just the fetcher Move the cache miss log before Coalesce so coalesced waiters also get a log entry. The upstream client already logs request completion with duration, so fetches vs coalesced waits are distinguishable. --- internal/proxy/proxy.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index aebdb68..4d3c26d 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -122,15 +122,15 @@ func (p *Proxy) handleRead(ctx context.Context, req *modbus.Request) ([]byte, er } // Cache miss — fetch with coalescing + p.logger.Debug("cache miss", + "slave_id", req.SlaveID, + "func", fmt.Sprintf("0x%02X", req.FunctionCode), + "addr", req.Address, + "qty", req.Quantity, + ) + rangeKey := cache.RangeKey(req.SlaveID, req.FunctionCode, req.Address, req.Quantity) data, err := p.cache.Coalesce(ctx, rangeKey, func(ctx context.Context) ([]byte, error) { - p.logger.Debug("cache miss", - "slave_id", req.SlaveID, - "func", fmt.Sprintf("0x%02X", req.FunctionCode), - "addr", req.Address, - "qty", req.Quantity, - ) - return p.client.Execute(ctx, req) }) From 4ede7e0d0a5d2d0f542959a19d761d1641600964 Mon Sep 17 00:00:00 2001 From: Thomas Maurer Date: Thu, 2 Apr 2026 21:35:59 +0200 Subject: [PATCH 4/6] fix: address copilot review comments - Guard GetRange/GetRangeStale against quantity=0 (false cache hit) - Use shared coilOn/coilOff slices in decomposeResponse to reduce per-coil allocations - Extract cleanupOnce so tests exercise the real keepStale guard instead of manually simulating cleanup --- internal/cache/cache.go | 34 ++++++++++++++++++++++++---------- internal/cache/cache_test.go | 17 ++++++----------- internal/proxy/proxy.go | 10 ++++++++-- 3 files changed, 38 insertions(+), 23 deletions(-) diff --git a/internal/cache/cache.go b/internal/cache/cache.go index 5ff1597..b9126ff 100644 --- a/internal/cache/cache.go +++ b/internal/cache/cache.go @@ -132,6 +132,10 @@ func (c *Cache) Delete(key string) { // GetRange retrieves all values for a contiguous register range. // Returns the per-register/coil values and true only if ALL are cached and fresh. func (c *Cache) GetRange(slaveID byte, functionCode byte, startAddr uint16, quantity uint16) ([][]byte, bool) { + if quantity == 0 { + return nil, false + } + c.mu.RLock() defer c.mu.RUnlock() @@ -152,6 +156,10 @@ func (c *Cache) GetRange(slaveID byte, functionCode byte, startAddr uint16, quan // GetRangeStale retrieves all values for a contiguous register range, ignoring TTL. // Returns the per-register/coil values and true only if ALL are present (even if expired). func (c *Cache) GetRangeStale(slaveID byte, functionCode byte, startAddr uint16, quantity uint16) ([][]byte, bool) { + if quantity == 0 { + return nil, false + } + c.mu.RLock() defer c.mu.RUnlock() @@ -250,6 +258,21 @@ func (c *Cache) Coalesce(ctx context.Context, key string, fetch func(context.Con return result, nil } +// cleanupOnce runs a single cleanup pass, removing expired entries. +// Skips deletion when keepStale is true. +func (c *Cache) cleanupOnce() { + if c.keepStale { + return + } + c.mu.Lock() + for key, entry := range c.entries { + if entry.IsExpired() { + delete(c.entries, key) + } + } + c.mu.Unlock() +} + // cleanup periodically removes expired entries. func (c *Cache) cleanup() { ticker := time.NewTicker(time.Minute) @@ -260,16 +283,7 @@ func (c *Cache) cleanup() { case <-c.done: return case <-ticker.C: - if c.keepStale { - continue - } - c.mu.Lock() - for key, entry := range c.entries { - if entry.IsExpired() { - delete(c.entries, key) - } - } - c.mu.Unlock() + c.cleanupOnce() } } } diff --git a/internal/cache/cache_test.go b/internal/cache/cache_test.go index acbc00c..445c010 100644 --- a/internal/cache/cache_test.go +++ b/internal/cache/cache_test.go @@ -374,29 +374,24 @@ func TestCache_KeepStale(t *testing.T) { c.Set("key1", []byte("value1")) time.Sleep(100 * time.Millisecond) - // Simulate cleanup - c.mu.Lock() - for key, entry := range c.entries { - if entry.IsExpired() { - delete(c.entries, key) - } - } - c.mu.Unlock() + c.cleanupOnce() if _, ok := c.GetStale("key1"); ok { t.Error("expected stale data to be gone after cleanup with keepStale=false") } c.Close() - // With keepStale=true, expired entries survive cleanup + // With keepStale=true, cleanup skips deletion c2 := New(50*time.Millisecond, true) c2.Set("key1", []byte("value1")) time.Sleep(100 * time.Millisecond) - // Entry should still be accessible via GetStale + c2.cleanupOnce() + + // Entry should still be accessible via GetStale after cleanup data, ok := c2.GetStale("key1") if !ok { - t.Error("expected stale data to survive with keepStale=true") + t.Error("expected stale data to survive cleanup with keepStale=true") } if string(data) != "value1" { t.Errorf("expected value1, got %s", string(data)) diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 4d3c26d..ffcaf73 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -207,6 +207,12 @@ func (p *Proxy) invalidateCache(req *modbus.Request) { } } +// Shared byte slices for coil values — safe to reuse since SetRange copies. +var ( + coilOn = []byte{1} + coilOff = []byte{0} +) + // decomposeResponse extracts per-register/coil values from a Modbus read response. // Response format: [funcCode, byteCount, data...] // For registers (FC 0x03, 0x04): each register is 2 bytes. @@ -241,9 +247,9 @@ func decomposeResponse(functionCode byte, quantity uint16, data []byte) [][]byte return nil } if payload[byteIdx]&(1< Date: Sat, 2 May 2026 00:31:14 +0200 Subject: [PATCH 5/6] test: cover proxy cache miss and stale fallback Update README and SPEC for the per-register cache design, then add proxy-level tests for miss -> fetch -> store -> hit and stale fallback on upstream errors. Introduce a small upstream client interface so tests can use a mock upstream without a real Modbus connection. --- README.md | 10 ++-- SPEC.md | 95 +++++++++++++++++++++-------- internal/cache/cache_test.go | 12 ++++ internal/proxy/proxy.go | 8 ++- internal/proxy/proxy_test.go | 113 +++++++++++++++++++++++++++++++++++ 5 files changed, 207 insertions(+), 31 deletions(-) diff --git a/README.md b/README.md index 96533fe..043da77 100644 --- a/README.md +++ b/README.md @@ -134,10 +134,12 @@ docker run --rm -v $(pwd):/app -w /app golang:1.24 go test ./... ## Cache Behavior -- **Key format**: `{slave_id}:{function_code}:{start_address}:{quantity}` -- **Read requests**: Served from cache if available and not expired -- **Write requests**: Forwarded to upstream (if allowed), exact matching cache entries invalidated -- **Request coalescing**: Multiple identical requests during a cache miss share a single upstream fetch +- **Key format**: values are cached per register/coil as `{slave_id}:{function_code}:{address}` +- **Read requests**: Served from cache only if every register/coil in the requested range is present and not expired +- **Cache misses**: If any value in the requested range is missing or expired, the full range is fetched from upstream and decomposed into per-register/coil cache entries +- **Write requests**: Forwarded to upstream (if allowed), then invalidate the written address range so overlapping cached reads cannot return stale values +- **Request coalescing**: Multiple identical range requests during a cache miss share a single upstream fetch using `{slave_id}:{function_code}:{start_address}:{quantity}` as the coalescing key +- **Stale fallback**: If enabled, expired entries are retained and can be served when upstream requests fail ## License diff --git a/SPEC.md b/SPEC.md index aa61fd1..dc6b7bf 100644 --- a/SPEC.md +++ b/SPEC.md @@ -49,6 +49,13 @@ Many Modbus devices (inverters, meters, battery systems) have limited polling ca ### 3. In-Memory Cache #### Cache Key Structure + +Values are cached per register/coil: +``` +{slave_id}:{function_code}:{address} +``` + +Request coalescing still uses the requested range as its key: ``` {slave_id}:{function_code}:{start_address}:{quantity} ``` @@ -56,21 +63,22 @@ Many Modbus devices (inverters, meters, battery systems) have limited polling ca #### Cache Entry ```go type CacheEntry struct { - Data []byte + Data []byte // one register (2 bytes) or one coil/input bit (1 byte: 0 or 1) Timestamp time.Time TTL time.Duration } ``` #### Cache Behavior -- **Read Operations**: Check cache first, return if valid (not expired) -- **Write Operations**: Always forward to device, invalidate exact matching cache entries (same slave_id, function_code, start_address, quantity) +- **Read Operations**: Check the per-register/coil cache first. Return from cache only if every value in the requested range is present and not expired. +- **Cache Misses**: If any value in the requested range is missing or expired, fetch the full requested range from upstream, then decompose the response into per-register/coil cache entries. +- **Write Operations**: Always forward to the device when writes are allowed, then invalidate each cached register/coil in the written address range. This prevents overlapping cached read ranges from serving stale values after frequent writes. - **TTL**: Configurable (default: 10 seconds) -- **Cleanup**: Time-based expiration (entries removed when TTL expires) -- **Staleness**: Option to serve stale data on upstream failure (default: off) +- **Cleanup**: Time-based expiration. Expired entries are removed during cleanup unless stale serving is enabled. +- **Staleness**: Option to serve stale data on upstream failure (default: off). When enabled, expired entries are retained so they remain available for fallback. ### Request Coalescing -- Identical in-flight requests are coalesced (same slave_id, function, address, quantity) +- Identical in-flight range requests are coalesced (same slave_id, function, address, quantity) - Second request arriving while first is pending will wait for and share the first's response - Prevents thundering herd on cache miss @@ -144,47 +152,82 @@ type CachingHandler struct { ```go type Cache struct { - mu sync.RWMutex - entries map[string]*CacheEntry - ttl time.Duration // default: 10 * time.Second + mu sync.RWMutex + entries map[string]*CacheEntry + defaultTTL time.Duration + keepStale bool + + // Request coalescing for identical range requests. + inflight map[string]*inflightRequest + inflightMu sync.Mutex +} + +func RegKey(slaveID, functionCode byte, address uint16) string { + return fmt.Sprintf("%d:%d:%d", slaveID, functionCode, address) } -func (c *Cache) Get(key string) ([]byte, bool) { +func RangeKey(slaveID, functionCode byte, address, quantity uint16) string { + return fmt.Sprintf("%d:%d:%d:%d", slaveID, functionCode, address, quantity) +} + +func (c *Cache) GetRange(slaveID, functionCode byte, address, quantity uint16) ([][]byte, bool) { + if quantity == 0 { + return nil, false + } + c.mu.RLock() defer c.mu.RUnlock() - - entry, ok := c.entries[key] - if !ok || time.Since(entry.Timestamp) > entry.TTL { - return nil, false + + values := make([][]byte, quantity) + for i := uint16(0); i < quantity; i++ { + entry, ok := c.entries[RegKey(slaveID, functionCode, address+i)] + if !ok || entry.IsExpired() { + return nil, false + } + values[i] = append([]byte(nil), entry.Data...) + } + return values, true +} + +func (c *Cache) SetRange(slaveID, functionCode byte, address uint16, values [][]byte) { + c.mu.Lock() + defer c.mu.Unlock() + + now := time.Now() + for i, value := range values { + c.entries[RegKey(slaveID, functionCode, address+uint16(i))] = &CacheEntry{ + Data: append([]byte(nil), value...), + Timestamp: now, + TTL: c.defaultTTL, + } } - return entry.Data, true } -func (c *Cache) Set(key string, data []byte, ttl time.Duration) { +func (c *Cache) DeleteRange(slaveID, functionCode byte, address, quantity uint16) { c.mu.Lock() defer c.mu.Unlock() - - c.entries[key] = &CacheEntry{ - Data: data, - Timestamp: time.Now(), - TTL: ttl, + + for i := uint16(0); i < quantity; i++ { + delete(c.entries, RegKey(slaveID, functionCode, address+i)) } } ``` +The cache also exposes `Coalesce(ctx, rangeKey, fetch)` for request coalescing. It does not read or write cache entries directly; the proxy performs cache lookups and stores decomposed responses. + ### Request Flow 1. Client sends Modbus TCP request 2. Parse request: extract slave ID, function code, address, quantity 3. **For reads**: - - Build cache key - - Check cache → if hit & valid, return cached data - - On miss: forward to upstream device - - Store response in cache + - Check every per-register/coil cache key in the requested range + - If all values are present and valid, reassemble and return the Modbus response + - On any miss or expired value: coalesce identical in-flight range requests, then forward to upstream device + - Decompose successful upstream responses into per-register/coil cache entries - Return response to client 4. **For writes**: - Check readonly mode - - If allowed: forward to upstream, optionally invalidate cache + - If allowed: forward to upstream, then invalidate every cached register/coil in the written address range - Return response ## Logging diff --git a/internal/cache/cache_test.go b/internal/cache/cache_test.go index 445c010..2099ad8 100644 --- a/internal/cache/cache_test.go +++ b/internal/cache/cache_test.go @@ -125,6 +125,18 @@ func TestCache_GetRange(t *testing.T) { } } +func TestCache_GetRangeZeroQuantityMiss(t *testing.T) { + c := New(time.Second, false) + defer c.Close() + + if _, ok := c.GetRange(1, 0x03, 10, 0); ok { + t.Error("expected zero-quantity range to miss") + } + if _, ok := c.GetRangeStale(1, 0x03, 10, 0); ok { + t.Error("expected zero-quantity stale range to miss") + } +} + func TestCache_SetRange(t *testing.T) { c := New(time.Second, false) defer c.Close() diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index ffcaf73..2e09228 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -13,12 +13,18 @@ import ( "github.com/tma/mbproxy/internal/modbus" ) +type upstreamClient interface { + Connect() error + Close() error + Execute(context.Context, *modbus.Request) ([]byte, error) +} + // Proxy is a caching Modbus proxy server. type Proxy struct { cfg *config.Config logger *slog.Logger server *modbus.Server - client *modbus.Client + client upstreamClient cache *cache.Cache } diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index f36d764..5a689ff 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -1,7 +1,9 @@ package proxy import ( + "bytes" "context" + "errors" "io" "log/slog" "testing" @@ -19,6 +21,10 @@ type mockClient struct { calls int } +func (m *mockClient) Connect() error { return nil } + +func (m *mockClient) Close() error { return nil } + func (m *mockClient) Execute(ctx context.Context, req *modbus.Request) ([]byte, error) { m.calls++ return m.response, m.err @@ -69,6 +75,113 @@ func TestProxy_HandleReadCacheHit(t *testing.T) { } } +func TestProxy_HandleReadMissFetchesAndCaches(t *testing.T) { + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + c := cache.New(time.Second, false) + defer c.Close() + + upstream := &mockClient{ + response: []byte{0x03, 0x04, 0x00, 0x0A, 0x00, 0x0B}, + } + p := &Proxy{ + cfg: &config.Config{ + CacheTTL: time.Second, + CacheServeStale: false, + ReadOnly: config.ReadOnlyOn, + }, + logger: logger, + client: upstream, + cache: c, + } + + req := &modbus.Request{ + SlaveID: 1, + FunctionCode: modbus.FuncReadHoldingRegisters, + Address: 10, + Quantity: 2, + } + + resp, err := p.HandleRequest(context.Background(), req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + expected := []byte{0x03, 0x04, 0x00, 0x0A, 0x00, 0x0B} + if !bytes.Equal(resp, expected) { + t.Fatalf("first response: expected %v, got %v", expected, resp) + } + if upstream.calls != 1 { + t.Fatalf("expected 1 upstream call after miss, got %d", upstream.calls) + } + + values, ok := c.GetRange(1, modbus.FuncReadHoldingRegisters, 10, 2) + if !ok { + t.Fatal("expected fetched response to be cached per register") + } + if !bytes.Equal(values[0], []byte{0x00, 0x0A}) || !bytes.Equal(values[1], []byte{0x00, 0x0B}) { + t.Fatalf("unexpected cached values: %v", values) + } + + // Change the upstream response. The second request should be served from cache, + // so the upstream should not be called again and the response should stay the same. + upstream.response = []byte{0x03, 0x04, 0x00, 0xFF, 0x00, 0xFF} + resp, err = p.HandleRequest(context.Background(), req) + if err != nil { + t.Fatalf("unexpected error on cached read: %v", err) + } + if !bytes.Equal(resp, expected) { + t.Fatalf("cached response: expected %v, got %v", expected, resp) + } + if upstream.calls != 1 { + t.Fatalf("expected cached read to avoid upstream call, got %d calls", upstream.calls) + } +} + +func TestProxy_HandleReadServesStaleOnUpstreamError(t *testing.T) { + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + c := cache.New(10*time.Millisecond, true) + defer c.Close() + + c.SetRange(1, modbus.FuncReadHoldingRegisters, 20, [][]byte{ + {0x00, 0x01}, + {0x00, 0x02}, + }) + time.Sleep(20 * time.Millisecond) + + upstreamErr := errors.New("upstream unavailable") + upstream := &mockClient{err: upstreamErr} + p := &Proxy{ + cfg: &config.Config{ + CacheTTL: 10 * time.Millisecond, + CacheServeStale: true, + ReadOnly: config.ReadOnlyOn, + }, + logger: logger, + client: upstream, + cache: c, + } + + req := &modbus.Request{ + SlaveID: 1, + FunctionCode: modbus.FuncReadHoldingRegisters, + Address: 20, + Quantity: 2, + } + + resp, err := p.HandleRequest(context.Background(), req) + if err != nil { + t.Fatalf("expected stale response, got error: %v", err) + } + if upstream.calls != 1 { + t.Fatalf("expected one failed upstream call before serving stale, got %d", upstream.calls) + } + + expected := []byte{0x03, 0x04, 0x00, 0x01, 0x00, 0x02} + if !bytes.Equal(resp, expected) { + t.Fatalf("stale response: expected %v, got %v", expected, resp) + } +} + func TestProxy_HandleWriteReadOnlyMode(t *testing.T) { logger := slog.New(slog.NewTextHandler(io.Discard, nil)) From ac66043ae96feda79561d15e876ba848c669554e Mon Sep 17 00:00:00 2001 From: Thomas Maurer Date: Sat, 2 May 2026 00:34:16 +0200 Subject: [PATCH 6/6] fix: include health check in proxy test client interface Main added Proxy.Healthy, which delegates to the upstream client. The local mockable interface needs to include Healthy so PR merge builds compile against the current base branch. --- internal/proxy/proxy.go | 1 + internal/proxy/proxy_test.go | 2 ++ 2 files changed, 3 insertions(+) diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 2e09228..505eadd 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -16,6 +16,7 @@ import ( type upstreamClient interface { Connect() error Close() error + Healthy() error Execute(context.Context, *modbus.Request) ([]byte, error) } diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index 5a689ff..e643aec 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -25,6 +25,8 @@ func (m *mockClient) Connect() error { return nil } func (m *mockClient) Close() error { return nil } +func (m *mockClient) Healthy() error { return nil } + func (m *mockClient) Execute(ctx context.Context, req *modbus.Request) ([]byte, error) { m.calls++ return m.response, m.err