diff --git a/go.mod b/go.mod index 386b3931..860d6fa0 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/segmentio/encoding v0.5.4 github.com/yosida95/uritemplate/v3 v3.0.2 golang.org/x/oauth2 v0.35.0 + golang.org/x/time v0.15.0 golang.org/x/tools v0.42.0 ) diff --git a/go.sum b/go.sum index ea18cedd..377a7b11 100644 --- a/go.sum +++ b/go.sum @@ -14,5 +14,7 @@ golang.org/x/oauth2 v0.35.0 h1:Mv2mzuHuZuY2+bkyWXIHMfhNdJAdwW3FuWeCPYN5GVQ= golang.org/x/oauth2 v0.35.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U= +golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno= golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k= golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0= diff --git a/mcp/logging.go b/mcp/logging.go index b1bd82b1..e77c6e11 100644 --- a/mcp/logging.go +++ b/mcp/logging.go @@ -13,6 +13,8 @@ import ( "slices" "sync" "time" + + "golang.org/x/time/rate" ) // Logging levels. @@ -83,10 +85,10 @@ type LoggingHandler struct { // Ensures that the buffer reset is atomic with the write (see Handle). // A pointer so that clones share the mutex. See // https://github.com/golang/example/blob/master/slog-handler-guide/README.md#getting-the-mutex-right. - mu *sync.Mutex - lastMessageSent time.Time // for rate-limiting - buf *bytes.Buffer - handler slog.Handler + mu *sync.Mutex + limiter *rate.Limiter // for rate-limiting + buf *bytes.Buffer + handler slog.Handler } // ensureLogger returns l if non-nil, otherwise a discard logger. @@ -118,6 +120,9 @@ func NewLoggingHandler(ss *ServerSession, opts *LoggingHandlerOptions) *LoggingH } if opts != nil { lh.opts = *opts + if opts.MinInterval > 0 { + lh.limiter = rate.NewLimiter(rate.Every(opts.MinInterval), 1) + } } return lh } @@ -157,11 +162,7 @@ func (h *LoggingHandler) Handle(ctx context.Context, r slog.Record) error { func (h *LoggingHandler) handle(ctx context.Context, r slog.Record) error { // Observe the rate limit. - // TODO(jba): use golang.org/x/time/rate. - h.mu.Lock() - skip := time.Since(h.lastMessageSent) < h.opts.MinInterval - h.mu.Unlock() - if skip { + if h.limiter != nil && !h.limiter.Allow() { return nil } @@ -184,10 +185,6 @@ func (h *LoggingHandler) handle(ctx context.Context, r slog.Record) error { return err } - h.mu.Lock() - h.lastMessageSent = time.Now() - h.mu.Unlock() - params := &LoggingMessageParams{ Logger: h.opts.LoggerName, Level: slogLevelToMCP(r.Level),