Skip to content

Commit a777166

Browse files
committed
chore: add unit test for WithCodeResponseWriter
Signed-off-by: kevin <[email protected]>
1 parent 0be63c3 commit a777166

File tree

5 files changed

+43
-29
lines changed

5 files changed

+43
-29
lines changed

rest/engine.go

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -67,25 +67,6 @@ func (ng *engine) addRoutes(r featuredRoutes) {
6767
ng.mightUpdateTimeout(r)
6868
}
6969

70-
func buildSSERoutes(routes []Route) []Route {
71-
for i, route := range routes {
72-
h := route.Handler
73-
routes[i].Handler = func(w http.ResponseWriter, r *http.Request) {
74-
rc := http.NewResponseController(w)
75-
err := rc.SetWriteDeadline(time.Time{})
76-
if err != nil {
77-
logx.Errorf("set conn write deadline failed:%v", err)
78-
}
79-
w.Header().Set(header.ContentType, header.ContentTypeEventStream)
80-
w.Header().Set(header.CacheControl, header.CacheControlNoCache)
81-
w.Header().Set(header.Connection, header.ConnectionKeepAlive)
82-
h(w, r)
83-
}
84-
}
85-
86-
return routes
87-
}
88-
8970
func (ng *engine) appendAuthHandler(fr featuredRoutes, chn chain.Chain,
9071
verifier func(chain.Chain) chain.Chain) chain.Chain {
9172
if fr.jwt.enabled {
@@ -400,6 +381,27 @@ func (ng *engine) withNetworkTimeout() internal.StartOption {
400381
}
401382
}
402383

384+
func buildSSERoutes(routes []Route) []Route {
385+
for i, route := range routes {
386+
h := route.Handler
387+
routes[i].Handler = func(w http.ResponseWriter, r *http.Request) {
388+
// remove the default write deadline set by http.Server,
389+
// because SSE requires the connection to be kept alive indefinitely.
390+
rc := http.NewResponseController(w)
391+
if err := rc.SetWriteDeadline(time.Time{}); err != nil {
392+
logx.Errorf("set conn write deadline failed: %v", err)
393+
}
394+
395+
w.Header().Set(header.ContentType, header.ContentTypeEventStream)
396+
w.Header().Set(header.CacheControl, header.CacheControlNoCache)
397+
w.Header().Set(header.Connection, header.ConnectionKeepAlive)
398+
h(w, r)
399+
}
400+
}
401+
402+
return routes
403+
}
404+
403405
func convertMiddleware(ware Middleware) func(http.Handler) http.Handler {
404406
return func(next http.Handler) http.Handler {
405407
return ware(next.ServeHTTP)

rest/engine_test.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -578,3 +578,7 @@ func (m mockedRouter) SetNotFoundHandler(_ http.Handler) {
578578

579579
func (m mockedRouter) SetNotAllowedHandler(_ http.Handler) {
580580
}
581+
582+
func ptrOfDuration(d time.Duration) *time.Duration {
583+
return &d
584+
}

rest/internal/response/withcoderesponsewriter.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,12 @@ func (w *WithCodeResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
4949
return nil, nil, errors.New("server doesn't support hijacking")
5050
}
5151

52+
// Unwrap returns the underlying http.ResponseWriter.
53+
// This is used by http.ResponseController to unwrap the response writer.
54+
func (w *WithCodeResponseWriter) Unwrap() http.ResponseWriter {
55+
return w.Writer
56+
}
57+
5258
// Write writes bytes into w.
5359
func (w *WithCodeResponseWriter) Write(bytes []byte) (int, error) {
5460
return w.Writer.Write(bytes)
@@ -59,8 +65,3 @@ func (w *WithCodeResponseWriter) WriteHeader(code int) {
5965
w.Writer.WriteHeader(code)
6066
w.Code = code
6167
}
62-
63-
// Unwrap returns the underlying ResponseWriter.
64-
func (w *WithCodeResponseWriter) Unwrap() http.ResponseWriter {
65-
return w.Writer
66-
}

rest/internal/response/withcoderesponsewriter_test.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,15 @@ func TestWithCodeResponseWriter_Hijack(t *testing.T) {
4646
writer.Hijack()
4747
})
4848
}
49+
50+
func TestWithCodeResponseWriter_Unwrap(t *testing.T) {
51+
resp := httptest.NewRecorder()
52+
writer := NewWithCodeResponseWriter(resp)
53+
unwrapped := writer.Unwrap()
54+
assert.Equal(t, resp, unwrapped)
55+
56+
// Test with a nested WithCodeResponseWriter
57+
nestedWriter := NewWithCodeResponseWriter(writer)
58+
unwrappedNested := nestedWriter.Unwrap()
59+
assert.Equal(t, resp, unwrappedNested)
60+
}

rest/server.go

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,6 @@ func WithSignature(signature SignatureConf) RouteOption {
293293
func WithSSE() RouteOption {
294294
return func(r *featuredRoutes) {
295295
r.sse = true
296-
r.timeout = ptrOfDuration(0)
297296
}
298297
}
299298

@@ -335,10 +334,6 @@ func handleError(err error) {
335334
panic(err)
336335
}
337336

338-
func ptrOfDuration(d time.Duration) *time.Duration {
339-
return &d
340-
}
341-
342337
func validateSecret(secret string) {
343338
if len(secret) < 8 {
344339
panic("secret's length can't be less than 8")

0 commit comments

Comments
 (0)