diff --git a/README.md b/README.md index 96533fe..043da77 100644 --- a/README.md +++ b/README.md @@ -134,10 +134,12 @@ docker run --rm -v $(pwd):/app -w /app golang:1.24 go test ./... ## Cache Behavior -- **Key format**: `{slave_id}:{function_code}:{start_address}:{quantity}` -- **Read requests**: Served from cache if available and not expired -- **Write requests**: Forwarded to upstream (if allowed), exact matching cache entries invalidated -- **Request coalescing**: Multiple identical requests during a cache miss share a single upstream fetch +- **Key format**: Values are cached per register/coil as `{slave_id}:{function_code}:{address}` +- **Read requests**: Served from cache only if every register/coil in the requested range is present and not expired +- **Cache misses**: If any value in the requested range is missing or expired, the full range is fetched from upstream and decomposed into per-register/coil cache entries +- **Write requests**: Forwarded to upstream (if allowed); the written address range is then invalidated for every read function code, so overlapping cached reads cannot return stale values +- **Request coalescing**: Multiple identical range requests during a cache miss share a single upstream fetch using `{slave_id}:{function_code}:{start_address}:{quantity}` as the coalescing key +- **Stale fallback**: If enabled, expired entries are retained and can be served when upstream requests fail ## License diff --git a/SPEC.md b/SPEC.md index aa61fd1..dc6b7bf 100644 --- a/SPEC.md +++ b/SPEC.md @@ -49,6 +49,13 @@ Many Modbus devices (inverters, meters, battery systems) have limited polling ca ### 3. 
In-Memory Cache #### Cache Key Structure + +Values are cached per register/coil: +``` +{slave_id}:{function_code}:{address} +``` + +Request coalescing still uses the requested range as its key: ``` {slave_id}:{function_code}:{start_address}:{quantity} ``` @@ -56,21 +63,22 @@ Many Modbus devices (inverters, meters, battery systems) have limited polling ca #### Cache Entry ```go type CacheEntry struct { - Data []byte + Data []byte // one register (2 bytes) or one coil/input bit (1 byte: 0 or 1) Timestamp time.Time TTL time.Duration } ``` #### Cache Behavior -- **Read Operations**: Check cache first, return if valid (not expired) -- **Write Operations**: Always forward to device, invalidate exact matching cache entries (same slave_id, function_code, start_address, quantity) +- **Read Operations**: Check the per-register/coil cache first. Return from cache only if every value in the requested range is present and not expired. +- **Cache Misses**: If any value in the requested range is missing or expired, fetch the full requested range from upstream, then decompose the response into per-register/coil cache entries. +- **Write Operations**: Forward to the device when writes are allowed, then invalidate each cached register/coil in the written address range. This prevents overlapping cached read ranges from serving stale values after a write. - **TTL**: Configurable (default: 10 seconds) -- **Cleanup**: Time-based expiration (entries removed when TTL expires) -- **Staleness**: Option to serve stale data on upstream failure (default: off) +- **Cleanup**: Time-based expiration. Expired entries are removed during cleanup unless stale serving is enabled. +- **Staleness**: Option to serve stale data on upstream failure (default: off). When enabled, the periodic cleanup skips expired entries so they remain available for fallback. 
### Request Coalescing -- Identical in-flight requests are coalesced (same slave_id, function, address, quantity) +- Identical in-flight range requests are coalesced (same slave_id, function, address, quantity) - Second request arriving while first is pending will wait for and share the first's response - Prevents thundering herd on cache miss @@ -144,47 +152,82 @@ type CachingHandler struct { ```go type Cache struct { - mu sync.RWMutex - entries map[string]*CacheEntry - ttl time.Duration // default: 10 * time.Second + mu sync.RWMutex + entries map[string]*CacheEntry + defaultTTL time.Duration + keepStale bool + + // Request coalescing for identical range requests. + inflight map[string]*inflightRequest + inflightMu sync.Mutex +} + +func RegKey(slaveID, functionCode byte, address uint16) string { + return fmt.Sprintf("%d:%d:%d", slaveID, functionCode, address) } -func (c *Cache) Get(key string) ([]byte, bool) { +func RangeKey(slaveID, functionCode byte, address, quantity uint16) string { + return fmt.Sprintf("%d:%d:%d:%d", slaveID, functionCode, address, quantity) +} + +func (c *Cache) GetRange(slaveID, functionCode byte, address, quantity uint16) ([][]byte, bool) { + if quantity == 0 { + return nil, false + } + c.mu.RLock() defer c.mu.RUnlock() - - entry, ok := c.entries[key] - if !ok || time.Since(entry.Timestamp) > entry.TTL { - return nil, false + + values := make([][]byte, quantity) + for i := uint16(0); i < quantity; i++ { + entry, ok := c.entries[RegKey(slaveID, functionCode, address+i)] + if !ok || entry.IsExpired() { + return nil, false + } + values[i] = append([]byte(nil), entry.Data...) 
+ } + return values, true +} + +func (c *Cache) SetRange(slaveID, functionCode byte, address uint16, values [][]byte) { + c.mu.Lock() + defer c.mu.Unlock() + + now := time.Now() + for i, value := range values { + c.entries[RegKey(slaveID, functionCode, address+uint16(i))] = &CacheEntry{ + Data: append([]byte(nil), value...), + Timestamp: now, + TTL: c.defaultTTL, + } } - return entry.Data, true } -func (c *Cache) Set(key string, data []byte, ttl time.Duration) { +func (c *Cache) DeleteRange(slaveID, functionCode byte, address, quantity uint16) { c.mu.Lock() defer c.mu.Unlock() - - c.entries[key] = &CacheEntry{ - Data: data, - Timestamp: time.Now(), - TTL: ttl, + + for i := uint16(0); i < quantity; i++ { + delete(c.entries, RegKey(slaveID, functionCode, address+i)) } } ``` +The cache also exposes `Coalesce(ctx, rangeKey, fetch)` for request coalescing. It does not read or write cache entries directly; the proxy performs cache lookups and stores decomposed responses. + ### Request Flow 1. Client sends Modbus TCP request 2. Parse request: extract slave ID, function code, address, quantity 3. **For reads**: - - Build cache key - - Check cache → if hit & valid, return cached data - - On miss: forward to upstream device - - Store response in cache + - Check every per-register/coil cache key in the requested range + - If all values are present and valid, reassemble and return the Modbus response + - On any miss or expired value: coalesce identical in-flight range requests, then forward to upstream device + - Decompose successful upstream responses into per-register/coil cache entries - Return response to client 4. 
**For writes**: - Check readonly mode - - If allowed: forward to upstream, optionally invalidate cache + - If allowed: forward to upstream, then invalidate every cached register/coil in the written address range - Return response ## Logging diff --git a/internal/cache/cache.go b/internal/cache/cache.go index f76648e..b9126ff 100644 --- a/internal/cache/cache.go +++ b/internal/cache/cache.go @@ -20,11 +20,12 @@ func (e *Entry) IsExpired() bool { return time.Since(e.Timestamp) > e.TTL } -// Cache is a thread-safe in-memory cache with TTL. +// Cache is a thread-safe in-memory cache with TTL and per-register storage. type Cache struct { mu sync.RWMutex entries map[string]*Entry defaultTTL time.Duration + keepStale bool // when true, cleanup won't delete expired entries // For request coalescing inflight map[string]*inflightRequest @@ -41,10 +42,12 @@ type inflightRequest struct { } // New creates a new cache with the specified default TTL. -func New(defaultTTL time.Duration) *Cache { +// If keepStale is true, expired entries are retained for stale serving. +func New(defaultTTL time.Duration, keepStale bool) *Cache { c := &Cache{ entries: make(map[string]*Entry), defaultTTL: defaultTTL, + keepStale: keepStale, inflight: make(map[string]*inflightRequest), done: make(chan struct{}), } @@ -60,9 +63,14 @@ func (c *Cache) Close() { close(c.done) } -// Key generates a cache key from request parameters. -func Key(slaveID byte, functionCode byte, address uint16, quantity uint16) string { - return fmt.Sprintf("%d:%d:%d:%d", slaveID, functionCode, address, quantity) +// RegKey generates a cache key for a single register or coil. +func RegKey(slaveID byte, functionCode byte, address uint16) string { + return fmt.Sprintf("%d:%d:%d", slaveID, functionCode, address) +} + +// RangeKey generates a coalescing key for a request range. 
+func RangeKey(slaveID byte, functionCode byte, startAddr uint16, quantity uint16) string { + return fmt.Sprintf("%d:%d:%d:%d", slaveID, functionCode, startAddr, quantity) } // Get retrieves a value from the cache. @@ -100,11 +108,6 @@ func (c *Cache) GetStale(key string) ([]byte, bool) { // Set stores a value in the cache with the default TTL. func (c *Cache) Set(key string, data []byte) { - c.SetWithTTL(key, data, c.defaultTTL) -} - -// SetWithTTL stores a value in the cache with a specific TTL. -func (c *Cache) SetWithTTL(key string, data []byte, ttl time.Duration) { c.mu.Lock() defer c.mu.Unlock() @@ -115,7 +118,7 @@ func (c *Cache) SetWithTTL(key string, data []byte, ttl time.Duration) { c.entries[key] = &Entry{ Data: dataCopy, Timestamp: time.Now(), - TTL: ttl, + TTL: c.defaultTTL, } } @@ -126,17 +129,88 @@ func (c *Cache) Delete(key string) { delete(c.entries, key) } -// GetOrFetch retrieves a value from the cache or fetches it using the provided function. -// This implements request coalescing - multiple concurrent requests for the same key -// will share a single fetch operation. -// Returns the data, a boolean indicating if it was a cache hit, and any error. -func (c *Cache) GetOrFetch(ctx context.Context, key string, fetch func(context.Context) ([]byte, error)) ([]byte, bool, error) { - // Check cache first - if data, ok := c.Get(key); ok { - return data, true, nil +// GetRange retrieves all values for a contiguous register range. +// Returns the per-register/coil values and true only if ALL are cached and fresh. 
+func (c *Cache) GetRange(slaveID byte, functionCode byte, startAddr uint16, quantity uint16) ([][]byte, bool) { + if quantity == 0 { + return nil, false + } + + c.mu.RLock() + defer c.mu.RUnlock() + + values := make([][]byte, quantity) + for i := uint16(0); i < quantity; i++ { + key := RegKey(slaveID, functionCode, startAddr+i) + entry, ok := c.entries[key] + if !ok || entry.IsExpired() { + return nil, false + } + data := make([]byte, len(entry.Data)) + copy(data, entry.Data) + values[i] = data + } + return values, true +} + +// GetRangeStale retrieves all values for a contiguous register range, ignoring TTL. +// Returns the per-register/coil values and true only if ALL are present (even if expired). +func (c *Cache) GetRangeStale(slaveID byte, functionCode byte, startAddr uint16, quantity uint16) ([][]byte, bool) { + if quantity == 0 { + return nil, false + } + + c.mu.RLock() + defer c.mu.RUnlock() + + values := make([][]byte, quantity) + for i := uint16(0); i < quantity; i++ { + key := RegKey(slaveID, functionCode, startAddr+i) + entry, ok := c.entries[key] + if !ok { + return nil, false + } + data := make([]byte, len(entry.Data)) + copy(data, entry.Data) + values[i] = data + } + return values, true +} + +// SetRange stores individual values for a contiguous register range. +// All entries are stored with the same timestamp for consistency. +func (c *Cache) SetRange(slaveID byte, functionCode byte, startAddr uint16, values [][]byte) { + c.mu.Lock() + defer c.mu.Unlock() + + now := time.Now() + for i, v := range values { + key := RegKey(slaveID, functionCode, startAddr+uint16(i)) + dataCopy := make([]byte, len(v)) + copy(dataCopy, v) + c.entries[key] = &Entry{ + Data: dataCopy, + Timestamp: now, + TTL: c.defaultTTL, + } + } +} + +// DeleteRange removes all entries for a contiguous register range. 
+func (c *Cache) DeleteRange(slaveID byte, functionCode byte, startAddr uint16, quantity uint16) { + c.mu.Lock() + defer c.mu.Unlock() + + for i := uint16(0); i < quantity; i++ { + key := RegKey(slaveID, functionCode, startAddr+i) + delete(c.entries, key) } +} - // Check if there's already an in-flight request +// Coalesce ensures only one fetch runs for a given key at a time. +// Other callers with the same key wait for and share the first caller's result. +// This handles request coalescing only — it does not interact with cache storage. +func (c *Cache) Coalesce(ctx context.Context, key string, fetch func(context.Context) ([]byte, error)) ([]byte, error) { c.inflightMu.Lock() if req, ok := c.inflight[key]; ok { c.inflightMu.Unlock() @@ -144,14 +218,14 @@ func (c *Cache) GetOrFetch(ctx context.Context, key string, fetch func(context.C select { case <-req.done: if req.err != nil { - return nil, false, req.err + return nil, req.err } // Return a copy data := make([]byte, len(req.result)) copy(data, req.result) - return data, false, nil + return data, nil case <-ctx.Done(): - return nil, false, ctx.Err() + return nil, ctx.Err() } } @@ -165,22 +239,38 @@ func (c *Cache) GetOrFetch(ctx context.Context, key string, fetch func(context.C // Fetch the data data, err := fetch(ctx) - // Store result + // Store result for waiters req.result = data req.err = err - // Cache successful results - if err == nil { - c.Set(key, data) - } - // Clean up and notify waiters c.inflightMu.Lock() delete(c.inflight, key) c.inflightMu.Unlock() close(req.done) - return data, false, err + if err != nil { + return nil, err + } + + result := make([]byte, len(data)) + copy(result, data) + return result, nil +} + +// cleanupOnce runs a single cleanup pass, removing expired entries. +// Skips deletion when keepStale is true. 
+func (c *Cache) cleanupOnce() { + if c.keepStale { + return + } + c.mu.Lock() + for key, entry := range c.entries { + if entry.IsExpired() { + delete(c.entries, key) + } + } + c.mu.Unlock() } // cleanup periodically removes expired entries. @@ -193,13 +283,7 @@ func (c *Cache) cleanup() { case <-c.done: return case <-ticker.C: - c.mu.Lock() - for key, entry := range c.entries { - if entry.IsExpired() { - delete(c.entries, key) - } - } - c.mu.Unlock() + c.cleanupOnce() } } } diff --git a/internal/cache/cache_test.go b/internal/cache/cache_test.go index 8741070..2099ad8 100644 --- a/internal/cache/cache_test.go +++ b/internal/cache/cache_test.go @@ -9,7 +9,7 @@ import ( ) func TestCache_GetSet(t *testing.T) { - c := New(time.Second) + c := New(time.Second, false) defer c.Close() // Test miss @@ -29,7 +29,7 @@ func TestCache_GetSet(t *testing.T) { } func TestCache_TTL(t *testing.T) { - c := New(50 * time.Millisecond) + c := New(50*time.Millisecond, false) defer c.Close() c.Set("key1", []byte("value1")) @@ -49,7 +49,7 @@ func TestCache_TTL(t *testing.T) { } func TestCache_GetStale(t *testing.T) { - c := New(50 * time.Millisecond) + c := New(50*time.Millisecond, false) defer c.Close() c.Set("key1", []byte("value1")) @@ -68,7 +68,7 @@ func TestCache_GetStale(t *testing.T) { } func TestCache_Delete(t *testing.T) { - c := New(time.Second) + c := New(time.Second, false) defer c.Close() c.Set("key1", []byte("value1")) @@ -79,16 +79,141 @@ func TestCache_Delete(t *testing.T) { } } -func TestCache_Key(t *testing.T) { - key := Key(1, 0x03, 100, 10) +func TestRegKey(t *testing.T) { + key := RegKey(1, 0x03, 100) + expected := "1:3:100" + if key != expected { + t.Errorf("expected %s, got %s", expected, key) + } +} + +func TestRangeKey(t *testing.T) { + key := RangeKey(1, 0x03, 100, 10) expected := "1:3:100:10" if key != expected { t.Errorf("expected %s, got %s", expected, key) } } -func TestCache_GetOrFetch(t *testing.T) { - c := New(time.Second) +func TestCache_GetRange(t 
*testing.T) { + c := New(time.Second, false) + defer c.Close() + + // Store 3 registers + c.Set(RegKey(1, 0x03, 10), []byte{0x00, 0x01}) + c.Set(RegKey(1, 0x03, 11), []byte{0x00, 0x02}) + c.Set(RegKey(1, 0x03, 12), []byte{0x00, 0x03}) + + // Full range hit + values, ok := c.GetRange(1, 0x03, 10, 3) + if !ok { + t.Error("expected range hit") + } + if len(values) != 3 { + t.Fatalf("expected 3 values, got %d", len(values)) + } + for i, expected := range []byte{0x01, 0x02, 0x03} { + if values[i][1] != expected { + t.Errorf("value[%d]: expected 0x%02X, got 0x%02X", i, expected, values[i][1]) + } + } + + // Partial range miss + _, ok = c.GetRange(1, 0x03, 10, 5) + if ok { + t.Error("expected range miss (registers 13-14 not cached)") + } +} + +func TestCache_GetRangeZeroQuantityMiss(t *testing.T) { + c := New(time.Second, false) + defer c.Close() + + if _, ok := c.GetRange(1, 0x03, 10, 0); ok { + t.Error("expected zero-quantity range to miss") + } + if _, ok := c.GetRangeStale(1, 0x03, 10, 0); ok { + t.Error("expected zero-quantity stale range to miss") + } +} + +func TestCache_SetRange(t *testing.T) { + c := New(time.Second, false) + defer c.Close() + + values := [][]byte{{0x00, 0x0A}, {0x00, 0x0B}} + c.SetRange(1, 0x03, 100, values) + + // Each register should be independently accessible + data, ok := c.Get(RegKey(1, 0x03, 100)) + if !ok { + t.Error("expected hit for register 100") + } + if data[1] != 0x0A { + t.Errorf("expected 0x0A, got 0x%02X", data[1]) + } + + data, ok = c.Get(RegKey(1, 0x03, 101)) + if !ok { + t.Error("expected hit for register 101") + } + if data[1] != 0x0B { + t.Errorf("expected 0x0B, got 0x%02X", data[1]) + } +} + +func TestCache_DeleteRange(t *testing.T) { + c := New(time.Second, false) + defer c.Close() + + values := [][]byte{{0x00, 0x0A}, {0x00, 0x0B}, {0x00, 0x0C}} + c.SetRange(1, 0x03, 100, values) + + // Delete middle register + c.DeleteRange(1, 0x03, 101, 1) + + // Register 100 still cached + if _, ok := c.Get(RegKey(1, 0x03, 100)); !ok { 
+ t.Error("register 100 should still be cached") + } + // Register 101 deleted + if _, ok := c.Get(RegKey(1, 0x03, 101)); ok { + t.Error("register 101 should be deleted") + } + // Register 102 still cached + if _, ok := c.Get(RegKey(1, 0x03, 102)); !ok { + t.Error("register 102 should still be cached") + } + // Full range now misses + if _, ok := c.GetRange(1, 0x03, 100, 3); ok { + t.Error("expected range miss after deleting register 101") + } +} + +func TestCache_GetRangeStale(t *testing.T) { + c := New(50*time.Millisecond, false) + defer c.Close() + + c.SetRange(1, 0x03, 10, [][]byte{{0x00, 0x01}, {0x00, 0x02}}) + time.Sleep(100 * time.Millisecond) + + // Fresh get should miss + if _, ok := c.GetRange(1, 0x03, 10, 2); ok { + t.Error("expected range miss after TTL") + } + + // Stale get should succeed + values, ok := c.GetRangeStale(1, 0x03, 10, 2) + if !ok { + t.Error("expected stale range hit") + } + if len(values) != 2 { + t.Fatalf("expected 2 stale values, got %d", len(values)) + } +} + +func TestCache_Coalesce(t *testing.T) { + c := New(time.Second, false) defer c.Close() ctx := context.Background() @@ -98,39 +223,20 @@ func TestCache_GetOrFetch(t *testing.T) { return []byte("fetched"), nil } - // First call should fetch (cache miss) - data, hit, err := c.GetOrFetch(ctx, "key1", fetch) + data, err := c.Coalesce(ctx, "key1", fetch) if err != nil { t.Errorf("unexpected error: %v", err) } - if hit { - t.Error("expected cache miss on first call") - } if string(data) != "fetched" { t.Errorf("expected fetched, got %s", string(data)) } if fetchCount != 1 { t.Errorf("expected 1 fetch, got %d", fetchCount) } - - // Second call should hit cache - data, hit, err = c.GetOrFetch(ctx, "key1", fetch) - if err != nil { - t.Errorf("unexpected error: %v", err) - } - if !hit { - t.Error("expected cache hit on second call") - } - if string(data) != "fetched" { - t.Errorf("expected fetched, got %s", string(data)) - } - if fetchCount != 1 { - t.Errorf("expected 1 fetch (cache 
hit), got %d", fetchCount) - } } -func TestCache_RequestCoalescing(t *testing.T) { - c := New(time.Second) +func TestCache_CoalescingConcurrent(t *testing.T) { + c := New(time.Second, false) defer c.Close() ctx := context.Background() @@ -153,7 +259,7 @@ func TestCache_RequestCoalescing(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - results[0], _, errors[0] = c.GetOrFetch(ctx, "key1", fetch) + results[0], errors[0] = c.Coalesce(ctx, "key1", fetch) }() // Wait for fetch to start @@ -165,7 +271,7 @@ func TestCache_RequestCoalescing(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - results[i], _, errors[i] = c.GetOrFetch(ctx, "key1", func(ctx context.Context) ([]byte, error) { + results[i], errors[i] = c.Coalesce(ctx, "key1", func(ctx context.Context) ([]byte, error) { atomic.AddInt32(&fetchCount, 1) return []byte("should not be called"), nil }) @@ -196,7 +302,7 @@ func TestCache_RequestCoalescing(t *testing.T) { } func TestCache_ContextCancellation(t *testing.T) { - c := New(time.Second) + c := New(time.Second, false) defer c.Close() ctx, cancel := context.WithCancel(context.Background()) @@ -204,7 +310,7 @@ func TestCache_ContextCancellation(t *testing.T) { // Start a slow fetch go func() { - c.GetOrFetch(ctx, "key1", func(ctx context.Context) ([]byte, error) { + c.Coalesce(ctx, "key1", func(ctx context.Context) ([]byte, error) { close(fetchStarted) time.Sleep(time.Second) return []byte("fetched"), nil @@ -217,7 +323,7 @@ func TestCache_ContextCancellation(t *testing.T) { ctx2, cancel2 := context.WithCancel(context.Background()) cancel2() // Cancel immediately - _, _, err := c.GetOrFetch(ctx2, "key1", func(ctx context.Context) ([]byte, error) { + _, err := c.Coalesce(ctx2, "key1", func(ctx context.Context) ([]byte, error) { return []byte("should not be called"), nil }) @@ -229,7 +335,7 @@ func TestCache_ContextCancellation(t *testing.T) { } func TestCache_DataIsolation(t *testing.T) { - c := New(time.Second) + c := New(time.Second, false) defer 
c.Close() original := []byte("original") @@ -253,3 +359,54 @@ func TestCache_DataIsolation(t *testing.T) { t.Error("cache data was mutated via returned slice") } } + +func TestCache_RangeDataIsolation(t *testing.T) { + c := New(time.Second, false) + defer c.Close() + + original := [][]byte{{0x00, 0x01}, {0x00, 0x02}} + c.SetRange(1, 0x03, 0, original) + + // Mutate original + original[0][0] = 0xFF + + // Cache should be unaffected + values, ok := c.GetRange(1, 0x03, 0, 2) + if !ok { + t.Error("expected range hit") + } + if values[0][0] != 0x00 { + t.Error("cache data was mutated via original slice") + } +} + +func TestCache_KeepStale(t *testing.T) { + // With keepStale=false, cleanup removes expired entries + c := New(50*time.Millisecond, false) + c.Set("key1", []byte("value1")) + time.Sleep(100 * time.Millisecond) + + c.cleanupOnce() + + if _, ok := c.GetStale("key1"); ok { + t.Error("expected stale data to be gone after cleanup with keepStale=false") + } + c.Close() + + // With keepStale=true, cleanup skips deletion + c2 := New(50*time.Millisecond, true) + c2.Set("key1", []byte("value1")) + time.Sleep(100 * time.Millisecond) + + c2.cleanupOnce() + + // Entry should still be accessible via GetStale after cleanup + data, ok := c2.GetStale("key1") + if !ok { + t.Error("expected stale data to survive cleanup with keepStale=true") + } + if string(data) != "value1" { + t.Errorf("expected value1, got %s", string(data)) + } + c2.Close() +} diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 2589be1..505eadd 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -13,12 +13,19 @@ import ( "github.com/tma/mbproxy/internal/modbus" ) +type upstreamClient interface { + Connect() error + Close() error + Healthy() error + Execute(context.Context, *modbus.Request) ([]byte, error) +} + // Proxy is a caching Modbus proxy server. 
type Proxy struct { cfg *config.Config logger *slog.Logger server *modbus.Server - client *modbus.Client + client upstreamClient cache *cache.Cache } @@ -28,7 +35,7 @@ func New(cfg *config.Config, logger *slog.Logger) (*Proxy, error) { cfg: cfg, logger: logger, client: modbus.NewClient(cfg.Upstream, cfg.Timeout, cfg.RequestDelay, cfg.ConnectDelay, logger), - cache: cache.New(cfg.CacheTTL), + cache: cache.New(cfg.CacheTTL, cfg.CacheServeStale), } p.server = modbus.NewServer(p, logger) @@ -109,41 +116,49 @@ func (p *Proxy) HandleRequest(ctx context.Context, req *modbus.Request) ([]byte, } func (p *Proxy) handleRead(ctx context.Context, req *modbus.Request) ([]byte, error) { - key := cache.Key(req.SlaveID, req.FunctionCode, req.Address, req.Quantity) - - // Use GetOrFetch for request coalescing - data, cacheHit, err := p.cache.GetOrFetch(ctx, key, func(ctx context.Context) ([]byte, error) { - p.logger.Debug("cache miss", + // Check per-register cache + values, cacheHit := p.cache.GetRange(req.SlaveID, req.FunctionCode, req.Address, req.Quantity) + if cacheHit { + p.logger.Debug("cache hit", "slave_id", req.SlaveID, "func", fmt.Sprintf("0x%02X", req.FunctionCode), "addr", req.Address, "qty", req.Quantity, ) + return assembleResponse(req.FunctionCode, req.Quantity, values), nil + } + // Cache miss — fetch with coalescing + p.logger.Debug("cache miss", + "slave_id", req.SlaveID, + "func", fmt.Sprintf("0x%02X", req.FunctionCode), + "addr", req.Address, + "qty", req.Quantity, + ) + + rangeKey := cache.RangeKey(req.SlaveID, req.FunctionCode, req.Address, req.Quantity) + data, err := p.cache.Coalesce(ctx, rangeKey, func(ctx context.Context) ([]byte, error) { return p.client.Execute(ctx, req) }) if err != nil { // Try serving stale data if configured if p.cfg.CacheServeStale { - if stale, ok := p.cache.GetStale(key); ok { + if staleValues, ok := p.cache.GetRangeStale(req.SlaveID, req.FunctionCode, req.Address, req.Quantity); ok { p.logger.Warn("upstream error, serving stale", 
"slave_id", req.SlaveID, "error", err, ) - return stale, nil + return assembleResponse(req.FunctionCode, req.Quantity, staleValues), nil } } return nil, err } - if cacheHit { - p.logger.Debug("cache hit", - "slave_id", req.SlaveID, - "func", fmt.Sprintf("0x%02X", req.FunctionCode), - "addr", req.Address, - "qty", req.Quantity, - ) + // Decompose response and store per-register + regValues := decomposeResponse(req.FunctionCode, req.Quantity, data) + if regValues != nil { + p.cache.SetRange(req.SlaveID, req.FunctionCode, req.Address, regValues) } return data, nil @@ -176,7 +191,7 @@ func (p *Proxy) handleWrite(ctx context.Context, req *modbus.Request) ([]byte, e return nil, err } - // Invalidate exact matching cache entries for all read function codes + // Invalidate per-register cache entries for the written range p.invalidateCache(req) return resp, nil @@ -186,7 +201,7 @@ func (p *Proxy) handleWrite(ctx context.Context, req *modbus.Request) ([]byte, e } func (p *Proxy) invalidateCache(req *modbus.Request) { - // Invalidate exact matches for all read function codes that could overlap + // Invalidate per-register entries for all read function codes readFuncs := []byte{ modbus.FuncReadCoils, modbus.FuncReadDiscreteInputs, @@ -195,9 +210,93 @@ func (p *Proxy) invalidateCache(req *modbus.Request) { } for _, fc := range readFuncs { - key := cache.Key(req.SlaveID, fc, req.Address, req.Quantity) - p.cache.Delete(key) + p.cache.DeleteRange(req.SlaveID, fc, req.Address, req.Quantity) + } +} + +// Shared byte slices for coil values — safe to reuse since SetRange copies. +var ( + coilOn = []byte{1} + coilOff = []byte{0} +) + +// decomposeResponse extracts per-register/coil values from a Modbus read response. +// Response format: [funcCode, byteCount, data...] +// For registers (FC 0x03, 0x04): each register is 2 bytes. +// For coils/discrete inputs (FC 0x01, 0x02): each coil is 1 bit, stored as 1 byte (0 or 1). 
+func decomposeResponse(functionCode byte, quantity uint16, data []byte) [][]byte { + if len(data) < 2 { + return nil } + + payload := data[2:] // Skip funcCode and byteCount + + switch functionCode { + case modbus.FuncReadHoldingRegisters, modbus.FuncReadInputRegisters: + values := make([][]byte, quantity) + for i := uint16(0); i < quantity; i++ { + offset := i * 2 + if int(offset+2) > len(payload) { + return nil + } + reg := make([]byte, 2) + copy(reg, payload[offset:offset+2]) + values[i] = reg + } + return values + + case modbus.FuncReadCoils, modbus.FuncReadDiscreteInputs: + values := make([][]byte, quantity) + for i := uint16(0); i < quantity; i++ { + byteIdx := i / 8 + bitIdx := i % 8 + if int(byteIdx) >= len(payload) { + return nil + } + if payload[byteIdx]&(1<= 2 { + resp[2+i*2] = v[0] + resp[2+i*2+1] = v[1] + } + } + return resp + + case modbus.FuncReadCoils, modbus.FuncReadDiscreteInputs: + byteCount := (quantity + 7) / 8 + resp := make([]byte, 2+byteCount) + resp[0] = functionCode + resp[1] = byte(byteCount) + for i, v := range values { + if len(v) > 0 && v[0] != 0 { + byteIdx := i / 8 + bitIdx := uint(i % 8) + resp[2+byteIdx] |= 1 << bitIdx + } + } + return resp + } + + return nil } func (p *Proxy) buildFakeWriteResponse(req *modbus.Request) []byte { diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index 99364cc..e643aec 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -1,7 +1,9 @@ package proxy import ( + "bytes" "context" + "errors" "io" "log/slog" "testing" @@ -19,6 +21,12 @@ type mockClient struct { calls int } +func (m *mockClient) Connect() error { return nil } + +func (m *mockClient) Close() error { return nil } + +func (m *mockClient) Healthy() error { return nil } + func (m *mockClient) Execute(ctx context.Context, req *modbus.Request) ([]byte, error) { m.calls++ return m.response, m.err @@ -26,7 +34,8 @@ func (m *mockClient) Execute(ctx context.Context, req *modbus.Request) ([]byte, func 
TestProxy_HandleReadCacheHit(t *testing.T) { logger := slog.New(slog.NewTextHandler(io.Discard, nil)) - c := cache.New(time.Second) + c := cache.New(time.Second, false) + defer c.Close() p := &Proxy{ cfg: &config.Config{ @@ -38,15 +47,17 @@ func TestProxy_HandleReadCacheHit(t *testing.T) { cache: c, } - // Pre-populate cache - key := cache.Key(1, modbus.FuncReadHoldingRegisters, 0, 10) - c.Set(key, []byte{0x03, 0x14, 0x00, 0x01}) // Function code + byte count + data + // Pre-populate cache with per-register values + c.SetRange(1, modbus.FuncReadHoldingRegisters, 0, [][]byte{ + {0x00, 0x01}, + {0x00, 0x02}, + }) req := &modbus.Request{ SlaveID: 1, FunctionCode: modbus.FuncReadHoldingRegisters, Address: 0, - Quantity: 10, + Quantity: 2, } resp, err := p.HandleRequest(context.Background(), req) @@ -54,8 +65,122 @@ func TestProxy_HandleReadCacheHit(t *testing.T) { t.Fatalf("unexpected error: %v", err) } - if string(resp) != string([]byte{0x03, 0x14, 0x00, 0x01}) { - t.Errorf("unexpected response: %v", resp) + // Expected assembled response: funcCode + byteCount + reg0 + reg1 + expected := []byte{0x03, 0x04, 0x00, 0x01, 0x00, 0x02} + if len(resp) != len(expected) { + t.Fatalf("expected %d bytes, got %d", len(expected), len(resp)) + } + for i := range expected { + if resp[i] != expected[i] { + t.Errorf("byte %d: expected 0x%02X, got 0x%02X", i, expected[i], resp[i]) + } + } +} + +func TestProxy_HandleReadMissFetchesAndCaches(t *testing.T) { + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + c := cache.New(time.Second, false) + defer c.Close() + + upstream := &mockClient{ + response: []byte{0x03, 0x04, 0x00, 0x0A, 0x00, 0x0B}, + } + p := &Proxy{ + cfg: &config.Config{ + CacheTTL: time.Second, + CacheServeStale: false, + ReadOnly: config.ReadOnlyOn, + }, + logger: logger, + client: upstream, + cache: c, + } + + req := &modbus.Request{ + SlaveID: 1, + FunctionCode: modbus.FuncReadHoldingRegisters, + Address: 10, + Quantity: 2, + } + + resp, err := 
p.HandleRequest(context.Background(), req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + expected := []byte{0x03, 0x04, 0x00, 0x0A, 0x00, 0x0B} + if !bytes.Equal(resp, expected) { + t.Fatalf("first response: expected %v, got %v", expected, resp) + } + if upstream.calls != 1 { + t.Fatalf("expected 1 upstream call after miss, got %d", upstream.calls) + } + + values, ok := c.GetRange(1, modbus.FuncReadHoldingRegisters, 10, 2) + if !ok { + t.Fatal("expected fetched response to be cached per register") + } + if !bytes.Equal(values[0], []byte{0x00, 0x0A}) || !bytes.Equal(values[1], []byte{0x00, 0x0B}) { + t.Fatalf("unexpected cached values: %v", values) + } + + // Change the upstream response. The second request should be served from cache, + // so the upstream should not be called again and the response should stay the same. + upstream.response = []byte{0x03, 0x04, 0x00, 0xFF, 0x00, 0xFF} + resp, err = p.HandleRequest(context.Background(), req) + if err != nil { + t.Fatalf("unexpected error on cached read: %v", err) + } + if !bytes.Equal(resp, expected) { + t.Fatalf("cached response: expected %v, got %v", expected, resp) + } + if upstream.calls != 1 { + t.Fatalf("expected cached read to avoid upstream call, got %d calls", upstream.calls) + } +} + +func TestProxy_HandleReadServesStaleOnUpstreamError(t *testing.T) { + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + c := cache.New(10*time.Millisecond, true) + defer c.Close() + + c.SetRange(1, modbus.FuncReadHoldingRegisters, 20, [][]byte{ + {0x00, 0x01}, + {0x00, 0x02}, + }) + time.Sleep(20 * time.Millisecond) + + upstreamErr := errors.New("upstream unavailable") + upstream := &mockClient{err: upstreamErr} + p := &Proxy{ + cfg: &config.Config{ + CacheTTL: 10 * time.Millisecond, + CacheServeStale: true, + ReadOnly: config.ReadOnlyOn, + }, + logger: logger, + client: upstream, + cache: c, + } + + req := &modbus.Request{ + SlaveID: 1, + FunctionCode: modbus.FuncReadHoldingRegisters, + Address: 
20, + Quantity: 2, + } + + resp, err := p.HandleRequest(context.Background(), req) + if err != nil { + t.Fatalf("expected stale response, got error: %v", err) + } + if upstream.calls != 1 { + t.Fatalf("expected one failed upstream call before serving stale, got %d", upstream.calls) + } + + expected := []byte{0x03, 0x04, 0x00, 0x01, 0x00, 0x02} + if !bytes.Equal(resp, expected) { + t.Fatalf("stale response: expected %v, got %v", expected, resp) } } @@ -73,12 +198,15 @@ func TestProxy_HandleWriteReadOnlyMode(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + c := cache.New(time.Second, false) + defer c.Close() + p := &Proxy{ cfg: &config.Config{ ReadOnly: tt.mode, }, logger: logger, - cache: cache.New(time.Second), + cache: c, } req := &modbus.Request{ @@ -104,13 +232,15 @@ func TestProxy_HandleWriteReadOnlyMode(t *testing.T) { func TestProxy_HandleUnknownFunction(t *testing.T) { logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + c := cache.New(time.Second, false) + defer c.Close() p := &Proxy{ cfg: &config.Config{ ReadOnly: config.ReadOnlyOn, }, logger: logger, - cache: cache.New(time.Second), + cache: c, } req := &modbus.Request{ @@ -183,3 +313,214 @@ func TestProxy_BuildFakeWriteResponse(t *testing.T) { }) } } + +func TestDecomposeResponse_Registers(t *testing.T) { + // Response: FC 0x03, byteCount=4, reg0=0x0001, reg1=0x0002 + data := []byte{0x03, 0x04, 0x00, 0x01, 0x00, 0x02} + values := decomposeResponse(modbus.FuncReadHoldingRegisters, 2, data) + + if len(values) != 2 { + t.Fatalf("expected 2 values, got %d", len(values)) + } + if values[0][0] != 0x00 || values[0][1] != 0x01 { + t.Errorf("reg0: expected 0x0001, got 0x%02X%02X", values[0][0], values[0][1]) + } + if values[1][0] != 0x00 || values[1][1] != 0x02 { + t.Errorf("reg1: expected 0x0002, got 0x%02X%02X", values[1][0], values[1][1]) + } +} + +func TestDecomposeResponse_Coils(t *testing.T) { + // Response: FC 0x01, byteCount=2, coils 0-9 + // 0xCD = 1100_1101: coils 
0,2,3,6,7 on + // 0x01 = 0000_0001: coil 8 on + data := []byte{0x01, 0x02, 0xCD, 0x01} + values := decomposeResponse(modbus.FuncReadCoils, 10, data) + + if len(values) != 10 { + t.Fatalf("expected 10 values, got %d", len(values)) + } + + expected := []byte{1, 0, 1, 1, 0, 0, 1, 1, 1, 0} + for i, exp := range expected { + if values[i][0] != exp { + t.Errorf("coil %d: expected %d, got %d", i, exp, values[i][0]) + } + } +} + +func TestAssembleResponse_Registers(t *testing.T) { + values := [][]byte{{0x00, 0x01}, {0x00, 0x02}} + resp := assembleResponse(modbus.FuncReadHoldingRegisters, 2, values) + + expected := []byte{0x03, 0x04, 0x00, 0x01, 0x00, 0x02} + if len(resp) != len(expected) { + t.Fatalf("expected %d bytes, got %d", len(expected), len(resp)) + } + for i := range expected { + if resp[i] != expected[i] { + t.Errorf("byte %d: expected 0x%02X, got 0x%02X", i, expected[i], resp[i]) + } + } +} + +func TestAssembleResponse_Coils(t *testing.T) { + // Coils 0,2,3,6,7 on, 8 on — should produce 0xCD 0x01 + values := [][]byte{{1}, {0}, {1}, {1}, {0}, {0}, {1}, {1}, {1}, {0}} + resp := assembleResponse(modbus.FuncReadCoils, 10, values) + + expected := []byte{0x01, 0x02, 0xCD, 0x01} + if len(resp) != len(expected) { + t.Fatalf("expected %d bytes, got %d", len(expected), len(resp)) + } + for i := range expected { + if resp[i] != expected[i] { + t.Errorf("byte %d: expected 0x%02X, got 0x%02X", i, expected[i], resp[i]) + } + } +} + +func TestDecomposeAssemble_Roundtrip(t *testing.T) { + tests := []struct { + name string + funcCode byte + quantity uint16 + data []byte + }{ + { + name: "holding registers", + funcCode: modbus.FuncReadHoldingRegisters, + quantity: 3, + data: []byte{0x03, 0x06, 0x00, 0x01, 0x00, 0x02, 0x00, 0x03}, + }, + { + name: "input registers", + funcCode: modbus.FuncReadInputRegisters, + quantity: 2, + data: []byte{0x04, 0x04, 0xFF, 0xFF, 0x00, 0x00}, + }, + { + name: "coils", + funcCode: modbus.FuncReadCoils, + quantity: 10, + data: []byte{0x01, 0x02, 0xCD, 
0x01}, + }, + { + name: "discrete inputs", + funcCode: modbus.FuncReadDiscreteInputs, + quantity: 8, + data: []byte{0x02, 0x01, 0xAC}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + values := decomposeResponse(tt.funcCode, tt.quantity, tt.data) + if values == nil { + t.Fatal("decomposeResponse returned nil") + } + + reassembled := assembleResponse(tt.funcCode, tt.quantity, values) + if len(reassembled) != len(tt.data) { + t.Fatalf("length mismatch: expected %d, got %d", len(tt.data), len(reassembled)) + } + for i := range tt.data { + if reassembled[i] != tt.data[i] { + t.Errorf("byte %d: expected 0x%02X, got 0x%02X", i, tt.data[i], reassembled[i]) + } + } + }) + } +} + +func TestProxy_WriteInvalidatesOverlappingReads(t *testing.T) { + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + c := cache.New(time.Second, false) + defer c.Close() + + p := &Proxy{ + cfg: &config.Config{ + ReadOnly: config.ReadOnlyOn, + }, + logger: logger, + cache: c, + } + + // Cache registers 0-9 (simulating a previous read of range 0-9) + regs := make([][]byte, 10) + for i := range regs { + regs[i] = []byte{0x00, byte(i)} + } + c.SetRange(1, modbus.FuncReadHoldingRegisters, 0, regs) + + // Write to register 5 — should invalidate register 5 + p.invalidateCache(&modbus.Request{ + SlaveID: 1, + FunctionCode: modbus.FuncWriteSingleRegister, + Address: 5, + Quantity: 1, + }) + + // Full range 0-9 should now miss (register 5 is gone) + _, ok := c.GetRange(1, modbus.FuncReadHoldingRegisters, 0, 10) + if ok { + t.Error("expected range miss after write invalidation of register 5") + } + + // Registers 0-4 and 6-9 should still be cached individually + for i := uint16(0); i < 10; i++ { + _, ok := c.Get(cache.RegKey(1, modbus.FuncReadHoldingRegisters, i)) + if i == 5 { + if ok { + t.Error("register 5 should be invalidated") + } + } else { + if !ok { + t.Errorf("register %d should still be cached", i) + } + } + } +} + +func 
TestProxy_WriteInvalidatesMultipleRegisters(t *testing.T) { + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + c := cache.New(time.Second, false) + defer c.Close() + + p := &Proxy{ + cfg: &config.Config{ + ReadOnly: config.ReadOnlyOn, + }, + logger: logger, + cache: c, + } + + // Cache registers 0-9 + regs := make([][]byte, 10) + for i := range regs { + regs[i] = []byte{0x00, byte(i)} + } + c.SetRange(1, modbus.FuncReadHoldingRegisters, 0, regs) + + // Write to registers 3-5 (write multiple) + p.invalidateCache(&modbus.Request{ + SlaveID: 1, + FunctionCode: modbus.FuncWriteMultipleRegs, + Address: 3, + Quantity: 3, + }) + + // Registers 3,4,5 should be gone + for i := uint16(3); i <= 5; i++ { + if _, ok := c.Get(cache.RegKey(1, modbus.FuncReadHoldingRegisters, i)); ok { + t.Errorf("register %d should be invalidated", i) + } + } + + // Registers 0,1,2,6,7,8,9 should still be cached + for _, i := range []uint16{0, 1, 2, 6, 7, 8, 9} { + if _, ok := c.Get(cache.RegKey(1, modbus.FuncReadHoldingRegisters, i)); !ok { + t.Errorf("register %d should still be cached", i) + } + } +}