Skip to content

Commit 28e359e

Browse files
committed
Small refactor + tests
1 parent 5fc121e commit 28e359e

File tree

2 files changed

+58
-3
lines changed

2 files changed

+58
-3
lines changed

force-ssl-heroku.go renamed to forcessl.go

+9-3
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,17 @@ import (
55
"os"
66
)
77

8+
var getenv = os.Getenv
9+
10+
const (
11+
xForwardedProtoHeader = "x-forwarded-proto"
12+
goEnviron = "GO_ENV"
13+
)
14+
815
func ForceSsl(next http.Handler) http.Handler {
916
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
10-
if os.Getenv("GO_ENV") == "production" {
11-
if r.Header.Get("x-forwarded-proto") != "https" {
17+
if getenv(goEnviron) == "production" {
18+
if r.Header.Get(xForwardedProtoHeader) != "https" {
1219
sslUrl := "https://" + r.Host + r.RequestURI
1320
http.Redirect(w, r, sslUrl, http.StatusTemporaryRedirect)
1421
return
@@ -17,5 +24,4 @@ func ForceSsl(next http.Handler) http.Handler {
1724

1825
next.ServeHTTP(w, r)
1926
})
20-
2127
}

forcessl_test.go

+49
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
package forcesslheroku
2+
3+
import (
4+
"net/http"
5+
"net/http/httptest"
6+
"testing"
7+
)
8+
9+
var testCases = []struct {
10+
goEnv string
11+
proto string
12+
expectLoc string
13+
}{
14+
{goEnv: "production", proto: "http",
15+
expectLoc: "https://example.com/test"},
16+
{goEnv: "production", proto: "https"},
17+
{goEnv: "test", proto: "http"},
18+
{goEnv: "test", proto: "https"},
19+
}
20+
21+
func TestForceSsl(t *testing.T) {
22+
noopHandler := func(w http.ResponseWriter, r *http.Request) {}
23+
forceSsl := ForceSsl(http.HandlerFunc(noopHandler))
24+
25+
for _, tt := range testCases {
26+
getenv = func(key string) string {
27+
switch key {
28+
case goEnviron:
29+
return tt.goEnv
30+
default:
31+
return ""
32+
}
33+
}
34+
35+
t.Run(tt.goEnv+"_"+tt.proto, func(t *testing.T) {
36+
req := httptest.NewRequest("", "/test", nil)
37+
req.Header.Add(xForwardedProtoHeader, tt.proto)
38+
39+
res := httptest.NewRecorder()
40+
forceSsl.ServeHTTP(res, req)
41+
42+
if location := res.Header().Get("Location"); location != tt.expectLoc {
43+
t.Errorf("expected Location header '%s', got '%s'",
44+
tt.expectLoc, location)
45+
}
46+
})
47+
}
48+
}
49+

0 commit comments

Comments
 (0)