Skip to content

Commit 7a48649

Browse files
Request buffering
Adds the ability to buffer requests before forwarding them to the target server. Buffering is not enabled by default, but can be controlled through two new deploy options: - `buffer-requests` - set this flag to enable buffering (default is false) - `buffer-memory` - max size (in bytes) to buffer into memory, when buffering is enabled. Defaults to 1MB. Once this limit is exceeded for a request, the remainder of its body be buffered to a tempfile on disk. The existing `max-request-body` still limits the total size of a request body. When buffering is enabled, this effectively limits the size of the tempfiles that will be created when exceeding buffer memory. `max-request-body` now defaults to 1GB. Requests with bodies larger than `max-request-body` will result in a `413 (Request Entity Too Large)` response.
1 parent 1662d36 commit 7a48649

13 files changed

+491
-73
lines changed

integration/slowclient/main.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
package main
2+
3+
import (
4+
"fmt"
5+
"io"
6+
"net/http"
7+
"os"
8+
"strings"
9+
"time"
10+
)
11+
12+
const (
13+
ChunkSize = 1024
14+
ChunkTime = 100 * time.Millisecond
15+
Chunks = 30
16+
)
17+
18+
func main() {
19+
if len(os.Args) < 2 {
20+
fmt.Printf("Usage: %s <url>\n", os.Args[0])
21+
os.Exit(1)
22+
}
23+
24+
pr, pw := io.Pipe()
25+
26+
go func() {
27+
defer pw.Close()
28+
chunk := strings.Repeat("a", ChunkSize-1) + "\n"
29+
30+
for i := 0; i < Chunks; i++ {
31+
pw.Write([]byte(chunk))
32+
fmt.Printf("%d of %d\n", i, Chunks)
33+
time.Sleep(ChunkTime)
34+
}
35+
}()
36+
37+
req, err := http.NewRequest("POST", os.Args[1], pr)
38+
if err != nil {
39+
panic(err)
40+
}
41+
req.TransferEncoding = []string{"chunked"}
42+
resp, err := http.DefaultClient.Do(req)
43+
if err != nil {
44+
panic(err)
45+
}
46+
47+
fmt.Println("Response status:", resp.Status)
48+
}

integration/upstream/main.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,26 @@ package main
22

33
import (
44
"fmt"
5+
"io"
56
"log/slog"
67
"net/http"
78
"os"
9+
"time"
810
)
911

1012
func newHandler(host string) http.HandlerFunc {
1113
return func(w http.ResponseWriter, r *http.Request) {
14+
if r.URL.Path == "/up" {
15+
w.WriteHeader(http.StatusOK)
16+
return
17+
}
18+
19+
started := time.Now()
1220
slog.Info("Request", "host", host, "request_id", r.Header.Get("X-Request-ID"), "method", r.Method, "url", r.URL)
1321

22+
io.Copy(io.Discard, r.Body)
23+
slog.Info("Read body", "duration", time.Since(started))
24+
1425
w.Header().Add("Content-Type", "text/html")
1526
fmt.Fprintf(w, "<body>Hello from %s</body>\n", host)
1627
}

internal/cmd/deploy.go

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,22 @@ func newDeployCommand() *deployCommand {
2828
}
2929

3030
deployCommand.cmd.Flags().StringVar(&deployCommand.args.TargetURL, "target", "", "Target host to deploy")
31+
deployCommand.cmd.Flags().StringVar(&deployCommand.args.Host, "host", "", "Host to serve this target on (empty for wildcard)")
32+
3133
deployCommand.cmd.Flags().BoolVar(&deployCommand.tls, "tls", false, "Configure TLS for this target (requires a non-empty host)")
3234
deployCommand.cmd.Flags().BoolVar(&deployCommand.tlsStaging, "tls-staging", false, "Use Let's Encrypt staging environmnent for certificate provisioning")
35+
3336
deployCommand.cmd.Flags().DurationVar(&deployCommand.args.DeployTimeout, "deploy-timeout", server.DefaultDeployTimeout, "Maximum time to wait for the new target to become healthy")
3437
deployCommand.cmd.Flags().DurationVar(&deployCommand.args.DrainTimeout, "drain-timeout", server.DefaultDrainTimeout, "Maximum time to allow existing connections to drain before removing old target")
35-
deployCommand.cmd.Flags().DurationVar(&deployCommand.args.ServiceOptions.TargetTimeout, "target-timeout", server.DefaultTargetTimeout, "Maximum time to wait for the target server to respond when serving requests")
3638
deployCommand.cmd.Flags().DurationVar(&deployCommand.args.ServiceOptions.HealthCheckConfig.Interval, "health-check-interval", server.DefaultHealthCheckInterval, "Interval between health checks")
3739
deployCommand.cmd.Flags().DurationVar(&deployCommand.args.ServiceOptions.HealthCheckConfig.Timeout, "health-check-timeout", server.DefaultHealthCheckTimeout, "Time each health check must complete in")
3840
deployCommand.cmd.Flags().StringVar(&deployCommand.args.ServiceOptions.HealthCheckConfig.Path, "health-check-path", server.DefaultHealthCheckPath, "Path to check for health")
39-
deployCommand.cmd.Flags().Int64Var(&deployCommand.args.ServiceOptions.MaxRequestBodySize, "max-request-body", 0, "Max size of request body (default of 0 means unlimited)")
40-
deployCommand.cmd.Flags().StringVar(&deployCommand.args.Host, "host", "", "Host to serve this target on (empty for wildcard)")
41+
42+
deployCommand.cmd.Flags().DurationVar(&deployCommand.args.ServiceOptions.TargetTimeout, "target-timeout", server.DefaultTargetTimeout, "Maximum time to wait for the target server to respond when serving requests")
43+
44+
deployCommand.cmd.Flags().BoolVar(&deployCommand.args.ServiceOptions.BufferRequests, "buffer-requests", false, "Enable request buffering")
45+
deployCommand.cmd.Flags().Int64Var(&deployCommand.args.ServiceOptions.MaxRequestMemoryBufferSize, "buffer-memory", server.DefaultMaxRequestMemoryBufferSize, "Max size of request memory buffer")
46+
deployCommand.cmd.Flags().Int64Var(&deployCommand.args.ServiceOptions.MaxRequestBodySize, "max-request-body", server.DefaultMaxRequestBodySize, "Max size of request body")
4147

4248
deployCommand.cmd.MarkFlagRequired("target")
4349

internal/server/buffer.go

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
package server
2+
3+
import (
4+
"bytes"
5+
"errors"
6+
"io"
7+
"log/slog"
8+
"os"
9+
)
10+
11+
var (
12+
ErrMaximumSizeExceeded = errors.New("maximum size exceeded")
13+
)
14+
15+
type BufferReadCloser struct {
16+
maxBytes int64
17+
maxMemBytes int64
18+
19+
memoryBuffer bytes.Buffer
20+
diskBuffer *os.File
21+
multiReader io.Reader
22+
}
23+
24+
func NewBufferReadCloser(r io.ReadCloser, maxBytes, maxMemBytes int64) (*BufferReadCloser, error) {
25+
brc := &BufferReadCloser{
26+
maxBytes: maxBytes,
27+
maxMemBytes: maxMemBytes,
28+
}
29+
30+
err := brc.populate(r)
31+
32+
return brc, err
33+
}
34+
35+
func (b *BufferReadCloser) Read(p []byte) (n int, err error) {
36+
return b.multiReader.Read(p)
37+
}
38+
39+
func (b *BufferReadCloser) Close() error {
40+
if b.diskBuffer != nil {
41+
b.diskBuffer.Close()
42+
os.Remove(b.diskBuffer.Name())
43+
slog.Debug("Buffer: removing spill", "file", b.diskBuffer.Name())
44+
}
45+
return nil
46+
}
47+
48+
func (b *BufferReadCloser) populate(r io.ReadCloser) error {
49+
defer r.Close()
50+
51+
moreDataRemaining, err := b.populateMemoryBuffer(r)
52+
if err != nil {
53+
return err
54+
}
55+
56+
if !moreDataRemaining {
57+
b.multiReader = &b.memoryBuffer
58+
return nil
59+
}
60+
61+
err = b.populateDiskBuffer(r)
62+
if err != nil {
63+
return err
64+
}
65+
66+
b.multiReader = io.MultiReader(&b.memoryBuffer, b.diskBuffer)
67+
return nil
68+
}
69+
70+
func (b *BufferReadCloser) populateMemoryBuffer(r io.ReadCloser) (bool, error) {
71+
limitReader := io.LimitReader(r, b.maxMemBytes)
72+
copied, err := b.memoryBuffer.ReadFrom(limitReader)
73+
if err != nil {
74+
return false, err
75+
}
76+
77+
moreDataRemaining := copied == b.maxMemBytes
78+
return moreDataRemaining, nil
79+
}
80+
81+
func (b *BufferReadCloser) populateDiskBuffer(r io.ReadCloser) error {
82+
var err error
83+
84+
b.diskBuffer, err = os.CreateTemp("", "proxy-buffer")
85+
if err != nil {
86+
return err
87+
}
88+
89+
slog.Debug("Buffer: spilling request to disk", "file", b.diskBuffer.Name())
90+
91+
maxDiskBytes := b.maxBytes - b.maxMemBytes
92+
limitReader := io.LimitReader(r, maxDiskBytes)
93+
copied, err := io.Copy(b.diskBuffer, limitReader)
94+
if err != nil {
95+
return err
96+
}
97+
98+
if copied == maxDiskBytes {
99+
b.Close()
100+
return ErrMaximumSizeExceeded
101+
}
102+
103+
b.diskBuffer.Seek(0, 0)
104+
return err
105+
}

internal/server/buffer_middleware.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
package server
2+
3+
import (
4+
"log/slog"
5+
"net/http"
6+
)
7+
8+
type BufferMiddleware struct {
9+
maxBytes int64
10+
maxMemBytes int64
11+
next http.Handler
12+
}
13+
14+
func WithBufferMiddleware(maxBytes, maxMemBytes int64, next http.Handler) http.Handler {
15+
return &BufferMiddleware{
16+
maxBytes: maxBytes,
17+
maxMemBytes: maxMemBytes,
18+
next: next,
19+
}
20+
}
21+
22+
func (h *BufferMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
23+
buffer, err := NewBufferReadCloser(r.Body, h.maxBytes, h.maxMemBytes)
24+
if err != nil {
25+
if err == ErrMaximumSizeExceeded {
26+
http.Error(w, "Request too large", http.StatusRequestEntityTooLarge)
27+
} else {
28+
slog.Error("Error buffering request", "path", r.URL.Path, "error", err)
29+
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
30+
}
31+
return
32+
}
33+
34+
r.Body = buffer
35+
h.next.ServeHTTP(w, r)
36+
}
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
package server
2+
3+
import (
4+
"io"
5+
"net/http"
6+
"net/http/httptest"
7+
"strings"
8+
"testing"
9+
10+
"github.com/stretchr/testify/assert"
11+
)
12+
13+
func TestBufferMiddleware(t *testing.T) {
14+
sendRequest := func(body string) *httptest.ResponseRecorder {
15+
middleware := WithBufferMiddleware(8, 4, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
16+
io.Copy(w, r.Body)
17+
}))
18+
19+
req := httptest.NewRequest("POST", "http://app.example.com/somepath", strings.NewReader(body))
20+
rec := httptest.NewRecorder()
21+
22+
middleware.ServeHTTP(rec, req)
23+
return rec
24+
}
25+
26+
t.Run("success", func(t *testing.T) {
27+
w := sendRequest("hello")
28+
29+
assert.Equal(t, http.StatusOK, w.Result().StatusCode)
30+
assert.Equal(t, "hello", w.Body.String())
31+
})
32+
33+
t.Run("body too large", func(t *testing.T) {
34+
w := sendRequest("this request body is much too large")
35+
36+
assert.Equal(t, http.StatusRequestEntityTooLarge, w.Result().StatusCode)
37+
})
38+
}

internal/server/buffer_test.go

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
package server
2+
3+
import (
4+
"io"
5+
"strings"
6+
"testing"
7+
8+
"github.com/stretchr/testify/assert"
9+
"github.com/stretchr/testify/require"
10+
)
11+
12+
func TestBufferReadCloser_WithinMemoryLimits(t *testing.T) {
13+
r := io.NopCloser(strings.NewReader("Hello, World!"))
14+
brc, err := NewBufferReadCloser(r, 2048, 1024)
15+
16+
require.NoError(t, err)
17+
assert.Equal(t, "Hello, World!", brc.memoryBuffer.String())
18+
19+
result, err := io.ReadAll(brc)
20+
require.NoError(t, err)
21+
assert.Equal(t, "Hello, World!", string(result))
22+
}
23+
24+
func TestBufferReadCloser_ExceedsMemoryLimits(t *testing.T) {
25+
r := io.NopCloser(strings.NewReader("Hello, World!"))
26+
brc, err := NewBufferReadCloser(r, 1024, 5)
27+
28+
require.NoError(t, err)
29+
assert.Equal(t, "Hello", brc.memoryBuffer.String())
30+
31+
result, err := io.ReadAll(brc)
32+
require.NoError(t, err)
33+
assert.Equal(t, "Hello, World!", string(result))
34+
}
35+
36+
func TestBufferReadCloser_ExceedsMemoryAndDiskLimits(t *testing.T) {
37+
r := io.NopCloser(strings.NewReader("Hello, World!"))
38+
_, err := NewBufferReadCloser(r, 8, 5)
39+
40+
require.Equal(t, ErrMaximumSizeExceeded, err)
41+
}
42+
43+
func TestBufferReadCloser_EmptyReader(t *testing.T) {
44+
r := io.NopCloser(strings.NewReader(""))
45+
brc, err := NewBufferReadCloser(r, 2048, 1024)
46+
47+
require.NoError(t, err)
48+
49+
result, err := io.ReadAll(brc)
50+
require.NoError(t, err)
51+
assert.Empty(t, result)
52+
}

internal/server/router.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,15 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
8484
func (r *Router) SetServiceTarget(name string, host string, targetURL string, options ServiceOptions, deployTimeout time.Duration, drainTimeout time.Duration) error {
8585
slog.Info("Deploying", "service", name, "host", host, "target", targetURL, "tls", options.RequireTLS())
8686

87-
target, err := NewTarget(targetURL, options.HealthCheckConfig, options.TargetTimeout)
87+
targetOptions := TargetOptions{
88+
HealthCheckConfig: options.HealthCheckConfig,
89+
ResponseTimeout: options.TargetTimeout,
90+
BufferRequests: options.BufferRequests,
91+
MaxRequestMemoryBufferSize: options.MaxRequestMemoryBufferSize,
92+
MaxRequestBodySize: options.MaxRequestBodySize,
93+
}
94+
95+
target, err := NewTarget(targetURL, targetOptions)
8896
if err != nil {
8997
return err
9098
}

0 commit comments

Comments
 (0)