Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 29 additions & 9 deletions http.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,22 @@ func (c *Client) logRequest(req *http.Request) {
for key, values := range req.Header {
c.logger.Debug("%s: %s", key, strings.Join(values, ", "))
}
if req.Body != nil {
var buf bytes.Buffer
body, _ := io.ReadAll(req.Body)
buf.Write(body)
req.Body = io.NopCloser(&buf)
c.logger.Debug("\n%s", string(body))
if req.Body == nil {
return
}
// Read via GetBody so the live body stays intact; only drain+restore when
// there's no GetBody (streaming bodies).
if req.GetBody != nil {
if rc, err := req.GetBody(); err == nil {
body, _ := io.ReadAll(rc)
rc.Close()
c.logger.Debug("\n%s", string(body))
return
}
}
body, _ := io.ReadAll(req.Body)
req.Body = io.NopCloser(bytes.NewReader(body))
c.logger.Debug("\n%s", string(body))
}

// logResponse logs the details of an HTTP response
Expand Down Expand Up @@ -268,14 +277,25 @@ func newRequest[T any](c *Client, ctx context.Context, method, path string, para
c.logger.Error("Error marshaling data: %+v, setting body to nil", err)
r.Body = nil
} else {
r.Body = io.NopCloser(bytes.NewReader(b))
setRetryableBody(r, b)
c.logger.Debug("Request body set with JSON: %s", string(b))
}
}

return r, nil
}

// setRetryableBody sets an in-memory body with ContentLength and GetBody. GetBody
// lets the HTTP/2 transport retry the request on GOAWAY/REFUSED_STREAM, which it
// can't do without a rewindable body.
func setRetryableBody(r *http.Request, b []byte) {
r.ContentLength = int64(len(b))
r.Body = io.NopCloser(bytes.NewReader(b))
r.GetBody = func() (io.ReadCloser, error) {
return io.NopCloser(bytes.NewReader(b)), nil
}
}

func getFileContent(fileName string, fileContent io.Reader) (io.Reader, error) {
if fileContent != nil {
return fileContent, nil
Expand Down Expand Up @@ -432,8 +452,8 @@ func (c *Client) createMultipartRequest(r *http.Request, data any) (*http.Reques
return nil, stackWrap(err, "failed to close multipart writer")
}

// Update request body and content type
r.Body = io.NopCloser(&buf)
// buf isn't mutated after this, so sharing its slice is safe.
setRetryableBody(r, buf.Bytes())
r.Header.Set("Content-Type", writer.FormDataContentType())

c.logger.Debug("Created multipart request with file: %s", fileName)
Expand Down
216 changes: 216 additions & 0 deletions http_retry_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
package getstream

import (
"bytes"
"context"
"encoding/json"
"errors"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"testing"

"github.com/stretchr/testify/require"
)

type getBodyTestRequest struct {
Foo string `json:"foo"`
}

type getBodyTestResponse struct{}

const getBodyTestJSON = `{"foo":"bar"}`

// A JSON body must carry GetBody and an accurate ContentLength.
func TestNewRequest_JSONBody_SetsGetBodyAndContentLength(t *testing.T) {
client, _ := newClient("k", "s", WithBaseUrl("https://api.example.com"))

req, err := newRequest(client, context.Background(), http.MethodPost, "/v1/x", nil, &getBodyTestRequest{Foo: "bar"}, nil)
require.NoError(t, err)
require.NotNil(t, req.GetBody, "JSON body must be replayable")
require.Equal(t, int64(len(getBodyTestJSON)), req.ContentLength)

b, err := io.ReadAll(req.Body)
require.NoError(t, err)
require.Equal(t, getBodyTestJSON, string(b))

// GetBody must be usable repeatedly and always yield the full payload.
for i := 0; i < 3; i++ {
rc, err := req.GetBody()
require.NoError(t, err)
gb, err := io.ReadAll(rc)
require.NoError(t, err)
require.NoError(t, rc.Close())
require.Equal(t, getBodyTestJSON, string(gb))
}
}

// After the live body is drained (as the first attempt writes it out), GetBody
// must still recover identical bytes for the retry.
func TestNewRequest_GetBody_RewindableAfterDrain(t *testing.T) {
client, _ := newClient("k", "s", WithBaseUrl("https://api.example.com"))

req, err := newRequest(client, context.Background(), http.MethodPost, "/v1/x", nil, &getBodyTestRequest{Foo: "bar"}, nil)
require.NoError(t, err)

drained, err := io.ReadAll(req.Body)
require.NoError(t, err)
require.Equal(t, getBodyTestJSON, string(drained))

rc, err := req.GetBody()
require.NoError(t, err)
replay, err := io.ReadAll(rc)
require.NoError(t, err)
require.Equal(t, getBodyTestJSON, string(replay))
}

// Backward compat: the wire bytes must stay exactly json.Marshal(data).
func TestNewRequest_JSONBody_WireBytesUnchanged(t *testing.T) {
client, _ := newClient("k", "s", WithBaseUrl("https://api.example.com"))
data := map[string]any{"field1": "value1", "field2": 2}

req, err := newRequest(client, context.Background(), http.MethodPost, "/v1/x", nil, data, nil)
require.NoError(t, err)

want, err := json.Marshal(data)
require.NoError(t, err)
got, err := io.ReadAll(req.Body)
require.NoError(t, err)
require.Equal(t, string(want), string(got))
require.Equal(t, int64(len(want)), req.ContentLength)
}

// Multipart uploads must also be replayable via GetBody.
func TestNewRequest_Multipart_SetsRetryableBody(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "f.txt")
require.NoError(t, os.WriteFile(path, []byte("hello-file-content"), 0o600))

client, _ := newClient("k", "s", WithBaseUrl("https://api.example.com"))
req, err := newRequest(client, context.Background(), http.MethodPost, "/upload", nil, &UploadFileRequest{File: PtrTo(path)}, nil)
require.NoError(t, err)
require.NotNil(t, req.GetBody, "multipart upload must be replayable")
require.True(t, strings.HasPrefix(req.Header.Get("Content-Type"), "multipart/form-data"))

first, err := io.ReadAll(req.Body)
require.NoError(t, err)
require.Equal(t, int64(len(first)), req.ContentLength)
require.Contains(t, string(first), "hello-file-content")

rc, err := req.GetBody()
require.NoError(t, err)
second, err := io.ReadAll(rc)
require.NoError(t, err)
require.Equal(t, first, second, "GetBody must reproduce identical multipart bytes")
}

// TestNewRequest_GET_NoBody confirms GET requests are unchanged (no body, no GetBody).
func TestNewRequest_GET_NoBody(t *testing.T) {
client, _ := newClient("k", "s", WithBaseUrl("https://api.example.com"))
var data any

req, err := newRequest(client, context.Background(), http.MethodGet, "/v1/x", nil, data, nil)
require.NoError(t, err)
require.Nil(t, req.Body)
require.Nil(t, req.GetBody)
}

// Arbitrary streaming readers stay non-rewindable (no GetBody) by design.
func TestNewRequest_StreamingReader_NoGetBody(t *testing.T) {
client, _ := newClient("k", "s", WithBaseUrl("https://api.example.com"))

req, err := newRequest(client, context.Background(), http.MethodPut, "/v1/x", nil, strings.NewReader("raw"), nil)
require.NoError(t, err)
require.NotNil(t, req.Body)
require.Nil(t, req.GetBody, "arbitrary streaming readers cannot be made rewindable")

b, err := io.ReadAll(req.Body)
require.NoError(t, err)
require.Equal(t, "raw", string(b))
}

// logRequest must not consume the live body or drop GetBody.
func TestLogRequest_PreservesBodyAndGetBody(t *testing.T) {
client, _ := newClient("k", "s", WithBaseUrl("https://api.example.com"))
req, err := newRequest(client, context.Background(), http.MethodPost, "/v1/x", nil, &getBodyTestRequest{Foo: "bar"}, nil)
require.NoError(t, err)

client.logRequest(req)

body, err := io.ReadAll(req.Body)
require.NoError(t, err)
require.Equal(t, getBodyTestJSON, string(body), "logRequest must not consume the live body")

rc, err := req.GetBody()
require.NoError(t, err)
gb, err := io.ReadAll(rc)
require.NoError(t, err)
require.Equal(t, getBodyTestJSON, string(gb))
}

// refusedStreamOnce models net/http's HTTP/2 retry contract: REFUSED_STREAM is
// retried only if the body can be rewound via GetBody. It consumes the body on
// the first attempt, then rewinds via GetBody (or errors if GetBody is nil).
type refusedStreamOnce struct {
attempts int
bodies [][]byte
}

func (f *refusedStreamOnce) Do(req *http.Request) (*http.Response, error) {
f.attempts++

var b []byte
if req.Body != nil {
b, _ = io.ReadAll(req.Body)
_ = req.Body.Close()
}
f.bodies = append(f.bodies, b)

if f.attempts == 1 {
if req.GetBody == nil {
return nil, errors.New("http2: Transport: cannot retry err [stream error: stream ID 1; REFUSED_STREAM; received from peer] after Request.Body was written; define Request.GetBody to avoid this error")
}
nb, err := req.GetBody()
if err != nil {
return nil, err
}
req.Body = nb
return f.Do(req)
}

return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader("{}")),
Header: make(http.Header),
}, nil
}

// A POST whose stream is refused once now succeeds, with an identical replayed body.
func TestMakeRequest_RetriesOnRefusedStream(t *testing.T) {
fake := &refusedStreamOnce{}
client, err := newClient("k", "s", WithBaseUrl("https://api.example.com"), WithHTTPClient(fake))
require.NoError(t, err)

_, err = MakeRequest(client, context.Background(), http.MethodPost, "/v1/x", nil, &getBodyTestRequest{Foo: "bar"}, &getBodyTestResponse{}, nil)
require.NoError(t, err)
require.Equal(t, 2, fake.attempts, "transport should have retried the refused stream once")
require.Len(t, fake.bodies, 2)
require.Equal(t, getBodyTestJSON, string(fake.bodies[0]))
require.Equal(t, fake.bodies[0], fake.bodies[1], "replayed body must be identical to the first attempt")
}

// A bodied request without GetBody (the pre-fix shape) still fails, so the
// positive test above isn't tautological.
func TestRefusedStreamTransport_FailsWithoutGetBody(t *testing.T) {
fake := &refusedStreamOnce{}
req, err := http.NewRequest(http.MethodPost, "https://api.example.com", io.NopCloser(bytes.NewReader([]byte(getBodyTestJSON))))
require.NoError(t, err)
require.Nil(t, req.GetBody, "guard: this request intentionally has no GetBody")

_, err = fake.Do(req)
require.Error(t, err)
require.Contains(t, err.Error(), "define Request.GetBody")
require.Equal(t, 1, fake.attempts)
}
Loading