From bee94d64e0d8144175418c395353d49984216f19 Mon Sep 17 00:00:00 2001 From: kmitrovv Date: Tue, 30 Jun 2026 11:18:31 +0200 Subject: [PATCH] fix: retry on refused stream --- http.go | 38 ++++++-- http_retry_test.go | 216 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 245 insertions(+), 9 deletions(-) create mode 100644 http_retry_test.go diff --git a/http.go b/http.go index b044caf..c155c1f 100644 --- a/http.go +++ b/http.go @@ -23,13 +23,22 @@ func (c *Client) logRequest(req *http.Request) { for key, values := range req.Header { c.logger.Debug("%s: %s", key, strings.Join(values, ", ")) } - if req.Body != nil { - var buf bytes.Buffer - body, _ := io.ReadAll(req.Body) - buf.Write(body) - req.Body = io.NopCloser(&buf) - c.logger.Debug("\n%s", string(body)) + if req.Body == nil { + return + } + // Read via GetBody so the live body stays intact; only drain+restore when + // there's no GetBody (streaming bodies). + if req.GetBody != nil { + if rc, err := req.GetBody(); err == nil { + body, _ := io.ReadAll(rc) + rc.Close() + c.logger.Debug("\n%s", string(body)) + return + } } + body, _ := io.ReadAll(req.Body) + req.Body = io.NopCloser(bytes.NewReader(body)) + c.logger.Debug("\n%s", string(body)) } // logResponse logs the details of an HTTP response @@ -268,7 +277,7 @@ func newRequest[T any](c *Client, ctx context.Context, method, path string, para c.logger.Error("Error marshaling data: %+v, setting body to nil", err) r.Body = nil } else { - r.Body = io.NopCloser(bytes.NewReader(b)) + setRetryableBody(r, b) c.logger.Debug("Request body set with JSON: %s", string(b)) } } @@ -276,6 +285,17 @@ func newRequest[T any](c *Client, ctx context.Context, method, path string, para return r, nil } +// setRetryableBody sets an in-memory body with ContentLength and GetBody. GetBody +// lets the HTTP/2 transport retry the request on GOAWAY/REFUSED_STREAM, which it +// can't do without a rewindable body. +func setRetryableBody(r *http.Request, b []byte) { + r.ContentLength = int64(len(b)) + r.Body = io.NopCloser(bytes.NewReader(b)) + r.GetBody = func() (io.ReadCloser, error) { + return io.NopCloser(bytes.NewReader(b)), nil + } +} + func getFileContent(fileName string, fileContent io.Reader) (io.Reader, error) { if fileContent != nil { return fileContent, nil @@ -432,8 +452,8 @@ func (c *Client) createMultipartRequest(r *http.Request, data any) (*http.Reques return nil, stackWrap(err, "failed to close multipart writer") } - // Update request body and content type - r.Body = io.NopCloser(&buf) + // buf isn't mutated after this, so sharing its slice is safe. + setRetryableBody(r, buf.Bytes()) r.Header.Set("Content-Type", writer.FormDataContentType()) c.logger.Debug("Created multipart request with file: %s", fileName) diff --git a/http_retry_test.go b/http_retry_test.go new file mode 100644 index 0000000..520246c --- /dev/null +++ b/http_retry_test.go @@ -0,0 +1,216 @@ +package getstream + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "io" + "net/http" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +type getBodyTestRequest struct { + Foo string `json:"foo"` +} + +type getBodyTestResponse struct{} + +const getBodyTestJSON = `{"foo":"bar"}` + +// A JSON body must carry GetBody and an accurate ContentLength. +func TestNewRequest_JSONBody_SetsGetBodyAndContentLength(t *testing.T) { + client, _ := newClient("k", "s", WithBaseUrl("https://api.example.com")) + + req, err := newRequest(client, context.Background(), http.MethodPost, "/v1/x", nil, &getBodyTestRequest{Foo: "bar"}, nil) + require.NoError(t, err) + require.NotNil(t, req.GetBody, "JSON body must be replayable") + require.Equal(t, int64(len(getBodyTestJSON)), req.ContentLength) + + b, err := io.ReadAll(req.Body) + require.NoError(t, err) + require.Equal(t, getBodyTestJSON, string(b)) + + // GetBody must be usable repeatedly and always yield the full payload. + for i := 0; i < 3; i++ { + rc, err := req.GetBody() + require.NoError(t, err) + gb, err := io.ReadAll(rc) + require.NoError(t, err) + require.NoError(t, rc.Close()) + require.Equal(t, getBodyTestJSON, string(gb)) + } +} + +// After the live body is drained (as the first attempt writes it out), GetBody +// must still recover identical bytes for the retry. +func TestNewRequest_GetBody_RewindableAfterDrain(t *testing.T) { + client, _ := newClient("k", "s", WithBaseUrl("https://api.example.com")) + + req, err := newRequest(client, context.Background(), http.MethodPost, "/v1/x", nil, &getBodyTestRequest{Foo: "bar"}, nil) + require.NoError(t, err) + + drained, err := io.ReadAll(req.Body) + require.NoError(t, err) + require.Equal(t, getBodyTestJSON, string(drained)) + + rc, err := req.GetBody() + require.NoError(t, err) + replay, err := io.ReadAll(rc) + require.NoError(t, err) + require.Equal(t, getBodyTestJSON, string(replay)) +} + +// Backward compat: the wire bytes must stay exactly json.Marshal(data). +func TestNewRequest_JSONBody_WireBytesUnchanged(t *testing.T) { + client, _ := newClient("k", "s", WithBaseUrl("https://api.example.com")) + data := map[string]any{"field1": "value1", "field2": 2} + + req, err := newRequest(client, context.Background(), http.MethodPost, "/v1/x", nil, data, nil) + require.NoError(t, err) + + want, err := json.Marshal(data) + require.NoError(t, err) + got, err := io.ReadAll(req.Body) + require.NoError(t, err) + require.Equal(t, string(want), string(got)) + require.Equal(t, int64(len(want)), req.ContentLength) +} + +// Multipart uploads must also be replayable via GetBody. +func TestNewRequest_Multipart_SetsRetryableBody(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "f.txt") + require.NoError(t, os.WriteFile(path, []byte("hello-file-content"), 0o600)) + + client, _ := newClient("k", "s", WithBaseUrl("https://api.example.com")) + req, err := newRequest(client, context.Background(), http.MethodPost, "/upload", nil, &UploadFileRequest{File: PtrTo(path)}, nil) + require.NoError(t, err) + require.NotNil(t, req.GetBody, "multipart upload must be replayable") + require.True(t, strings.HasPrefix(req.Header.Get("Content-Type"), "multipart/form-data")) + + first, err := io.ReadAll(req.Body) + require.NoError(t, err) + require.Equal(t, int64(len(first)), req.ContentLength) + require.Contains(t, string(first), "hello-file-content") + + rc, err := req.GetBody() + require.NoError(t, err) + second, err := io.ReadAll(rc) + require.NoError(t, err) + require.Equal(t, first, second, "GetBody must reproduce identical multipart bytes") +} + +// TestNewRequest_GET_NoBody confirms GET requests are unchanged (no body, no GetBody). +func TestNewRequest_GET_NoBody(t *testing.T) { + client, _ := newClient("k", "s", WithBaseUrl("https://api.example.com")) + var data any + + req, err := newRequest(client, context.Background(), http.MethodGet, "/v1/x", nil, data, nil) + require.NoError(t, err) + require.Nil(t, req.Body) + require.Nil(t, req.GetBody) +} + +// Arbitrary streaming readers stay non-rewindable (no GetBody) by design. +func TestNewRequest_StreamingReader_NoGetBody(t *testing.T) { + client, _ := newClient("k", "s", WithBaseUrl("https://api.example.com")) + + req, err := newRequest(client, context.Background(), http.MethodPut, "/v1/x", nil, strings.NewReader("raw"), nil) + require.NoError(t, err) + require.NotNil(t, req.Body) + require.Nil(t, req.GetBody, "arbitrary streaming readers cannot be made rewindable") + + b, err := io.ReadAll(req.Body) + require.NoError(t, err) + require.Equal(t, "raw", string(b)) +} + +// logRequest must not consume the live body or drop GetBody. +func TestLogRequest_PreservesBodyAndGetBody(t *testing.T) { + client, _ := newClient("k", "s", WithBaseUrl("https://api.example.com")) + req, err := newRequest(client, context.Background(), http.MethodPost, "/v1/x", nil, &getBodyTestRequest{Foo: "bar"}, nil) + require.NoError(t, err) + + client.logRequest(req) + + body, err := io.ReadAll(req.Body) + require.NoError(t, err) + require.Equal(t, getBodyTestJSON, string(body), "logRequest must not consume the live body") + + rc, err := req.GetBody() + require.NoError(t, err) + gb, err := io.ReadAll(rc) + require.NoError(t, err) + require.Equal(t, getBodyTestJSON, string(gb)) +} + +// refusedStreamOnce models net/http's HTTP/2 retry contract: REFUSED_STREAM is +// retried only if the body can be rewound via GetBody. It consumes the body on +// the first attempt, then rewinds via GetBody (or errors if GetBody is nil). +type refusedStreamOnce struct { + attempts int + bodies [][]byte +} + +func (f *refusedStreamOnce) Do(req *http.Request) (*http.Response, error) { + f.attempts++ + + var b []byte + if req.Body != nil { + b, _ = io.ReadAll(req.Body) + _ = req.Body.Close() + } + f.bodies = append(f.bodies, b) + + if f.attempts == 1 { + if req.GetBody == nil { + return nil, errors.New("http2: Transport: cannot retry err [stream error: stream ID 1; REFUSED_STREAM; received from peer] after Request.Body was written; define Request.GetBody to avoid this error") + } + nb, err := req.GetBody() + if err != nil { + return nil, err + } + req.Body = nb + return f.Do(req) + } + + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("{}")), + Header: make(http.Header), + }, nil +} + +// A POST whose stream is refused once now succeeds, with an identical replayed body. +func TestMakeRequest_RetriesOnRefusedStream(t *testing.T) { + fake := &refusedStreamOnce{} + client, err := newClient("k", "s", WithBaseUrl("https://api.example.com"), WithHTTPClient(fake)) + require.NoError(t, err) + + _, err = MakeRequest(client, context.Background(), http.MethodPost, "/v1/x", nil, &getBodyTestRequest{Foo: "bar"}, &getBodyTestResponse{}, nil) + require.NoError(t, err) + require.Equal(t, 2, fake.attempts, "transport should have retried the refused stream once") + require.Len(t, fake.bodies, 2) + require.Equal(t, getBodyTestJSON, string(fake.bodies[0])) + require.Equal(t, fake.bodies[0], fake.bodies[1], "replayed body must be identical to the first attempt") +} + +// A bodied request without GetBody (the pre-fix shape) still fails, so the +// positive test above isn't tautological. +func TestRefusedStreamTransport_FailsWithoutGetBody(t *testing.T) { + fake := &refusedStreamOnce{} + req, err := http.NewRequest(http.MethodPost, "https://api.example.com", io.NopCloser(bytes.NewReader([]byte(getBodyTestJSON)))) + require.NoError(t, err) + require.Nil(t, req.GetBody, "guard: this request intentionally has no GetBody") + + _, err = fake.Do(req) + require.Error(t, err) + require.Contains(t, err.Error(), "define Request.GetBody") + require.Equal(t, 1, fake.attempts) +}