Skip to content

Commit 27075a1

Browse files
author
Anders Qvist
committed
Error handlers are responsible for HTTP response.
1 parent b13ec3c commit 27075a1

File tree

6 files changed

+153
-78
lines changed

6 files changed

+153
-78
lines changed

internal/oidctesting/tests.go

Lines changed: 80 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@ import (
44
"fmt"
55
"net/http"
66
"net/http/httptest"
7-
"strings"
8-
"sync"
97
"testing"
108

119
"github.com/stretchr/testify/require"
@@ -313,65 +311,91 @@ func runTestErrorHandler(t *testing.T, testName string, tester tester) {
313311
op := optest.NewTesting(t)
314312
defer op.Close(t)
315313

316-
var info struct {
317-
sync.RWMutex
318-
description options.ErrorDescription
319-
err error
320-
}
321-
322-
setInfo := func(description options.ErrorDescription, err error) {
323-
info.Lock()
324-
info.description = description
325-
info.err = err
326-
info.Unlock()
327-
}
328-
329-
getInfo := func() (description options.ErrorDescription, err error) {
330-
info.RLock()
331-
defer info.RUnlock()
332-
return info.description, info.err
333-
}
334-
335-
errorHandler := func(description options.ErrorDescription, err error) {
336-
t.Logf("Description: %s\tError: %v", description, err)
337-
setInfo(description, err)
338-
}
339-
340-
opts := []options.Option{
341-
options.WithIssuer(op.GetURL(t)),
342-
options.WithRequiredAudience("test-client"),
343-
options.WithRequiredTokenType("JWT+AT"),
344-
options.WithErrorHandler(errorHandler),
314+
cases := []struct {
315+
testDescription string
316+
errorHandler options.ErrorHandler
317+
expectStatusCode int
318+
expectHeaders map[string]string
319+
expectBodyContains []byte
320+
}{
321+
{
322+
testDescription: "no output",
323+
errorHandler: func(desc options.ErrorDescription, err error) *options.Response { return nil },
324+
expectStatusCode: http.StatusBadRequest,
325+
expectHeaders: map[string]string{},
326+
expectBodyContains: []byte{},
327+
},
328+
{
329+
testDescription: "basic propagation",
330+
errorHandler: func(desc options.ErrorDescription, err error) *options.Response {
331+
return &options.Response{
332+
StatusCode: 418,
333+
Headers: map[string]string{},
334+
Body: []byte("badness"),
335+
}
336+
},
337+
expectStatusCode: http.StatusTeapot,
338+
expectHeaders: map[string]string{
339+
"Content-Type": "application/octet-stream",
340+
},
341+
expectBodyContains: []byte("bad"),
342+
},
343+
{
344+
testDescription: "additional header",
345+
errorHandler: func(desc options.ErrorDescription, err error) *options.Response {
346+
return &options.Response{
347+
StatusCode: 418,
348+
Headers: map[string]string{"some": "header"},
349+
Body: []byte("badness"),
350+
}
351+
},
352+
expectStatusCode: http.StatusTeapot,
353+
expectHeaders: map[string]string{
354+
"Some": "header",
355+
"Content-Type": "application/octet-stream",
356+
},
357+
expectBodyContains: []byte{},
358+
},
359+
{
360+
testDescription: "content type",
361+
errorHandler: func(desc options.ErrorDescription, err error) *options.Response {
362+
return &options.Response{
363+
StatusCode: 418,
364+
Headers: map[string]string{"content-type": "application/json"},
365+
Body: []byte("{}"),
366+
}
367+
},
368+
expectStatusCode: http.StatusTeapot,
369+
expectHeaders: map[string]string{
370+
"Content-Type": "application/json",
371+
},
372+
expectBodyContains: []byte("{}"),
373+
},
345374
}
375+
for i := range cases {
376+
c := cases[i]
377+
t.Logf("Test iteration %d: %s", i, c.testDescription)
378+
opts := []options.Option{
379+
options.WithIssuer(op.GetURL(t)),
380+
options.WithRequiredAudience("test-client"),
381+
options.WithRequiredTokenType("JWT+AT"),
382+
options.WithErrorHandler(c.errorHandler),
383+
}
346384

347-
oidcHandler, err := oidc.NewHandler[TestClaims](nil, opts...)
348-
require.NoError(t, err)
349-
350-
handler := tester.ToHandlerFn(oidcHandler.ParseToken, opts...)
351-
352-
// Test without token
353-
reqNoAuth := httptest.NewRequest(http.MethodGet, "/", nil)
354-
recNoAuth := httptest.NewRecorder()
355-
handler.ServeHTTP(recNoAuth, reqNoAuth)
356-
357-
require.Equal(t, http.StatusBadRequest, recNoAuth.Result().StatusCode)
385+
oidcHandler, err := oidc.NewHandler[TestClaims](nil, opts...)
386+
require.NoError(t, err)
358387

359-
d, e := getInfo()
388+
handler := tester.ToHandlerFn(oidcHandler.ParseToken, opts...)
360389

361-
if !strings.Contains(t.Name(), "OidcEchoJwt") {
362-
require.Equal(t, options.GetTokenErrorDescription, d)
363-
require.EqualError(t, e, "unable to extract token: Authorization header empty")
390+
req := httptest.NewRequest(http.MethodGet, "/", nil)
391+
res := httptest.NewRecorder()
392+
handler.ServeHTTP(res, req)
393+
require.Equal(t, c.expectStatusCode, res.Result().StatusCode)
394+
for k, v := range c.expectHeaders {
395+
require.Equal(t, []string{v}, res.Result().Header[k])
396+
}
397+
require.Subset(t, res.Body.Bytes(), c.expectBodyContains)
364398
}
365-
366-
// Test with fake token
367-
token := op.GetToken(t)
368-
token.AccessToken = "foobar"
369-
testHttpWithAuthenticationFailure(t, token, handler)
370-
371-
d, e = getInfo()
372-
373-
require.Equal(t, options.ParseTokenErrorDescription, d)
374-
require.EqualError(t, e, "unable to parse token signature: invalid compact serialization format: invalid number of segments")
375399
})
376400
}
377401

oidcecho/echo.go

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,23 @@ func New[T any](claimsValidationFn options.ClaimsValidationFn[T], setters ...opt
1919
return toEchoMiddleware(h.ParseToken, setters...)
2020
}
2121

22-
func onError(errorHandler options.ErrorHandler, description options.ErrorDescription, err error) {
23-
if errorHandler != nil {
24-
errorHandler(description, err)
22+
func onError(c echo.Context, errorHandler options.ErrorHandler, statusCode int, description options.ErrorDescription, err error) error {
23+
if errorHandler == nil {
24+
c.Logger().Error(err)
25+
return c.NoContent(statusCode)
2526
}
27+
response := errorHandler(description, err)
28+
if response == nil {
29+
c.Logger().Error(err)
30+
return c.NoContent(statusCode)
31+
}
32+
for k, v := range response.Headers {
33+
c.Response().Header().Set(k, v)
34+
}
35+
c.Response().Header().Set(echo.HeaderContentType, response.ContentType())
36+
c.Response().WriteHeader(response.StatusCode)
37+
_, err = c.Response().Write(response.Body)
38+
return err
2639
}
2740

2841
func toEchoMiddleware[T any](parseToken oidc.ParseTokenFunc[T], setters ...options.Option) echo.MiddlewareFunc {
@@ -34,14 +47,12 @@ func toEchoMiddleware[T any](parseToken oidc.ParseTokenFunc[T], setters ...optio
3447

3548
tokenString, err := oidc.GetTokenString(c.Request().Header.Get, opts.TokenString)
3649
if err != nil {
37-
onError(opts.ErrorHandler, options.GetTokenErrorDescription, err)
38-
return echo.ErrBadRequest
50+
return onError(c, opts.ErrorHandler, echo.ErrBadRequest.Code, options.GetTokenErrorDescription, err)
3951
}
4052

4153
claims, err := parseToken(ctx, tokenString)
4254
if err != nil {
43-
onError(opts.ErrorHandler, options.ParseTokenErrorDescription, err)
44-
return echo.ErrUnauthorized
55+
return onError(c, opts.ErrorHandler, echo.ErrUnauthorized.Code, options.ParseTokenErrorDescription, err)
4556
}
4657
c.Set(string(opts.ClaimsContextKeyName), claims)
4758
return next(c)

oidcfiber/fiber.go

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,18 @@ func New[T any](claimsValidationFn options.ClaimsValidationFn[T], setters ...opt
2020
}
2121

2222
func onError(c *fiber.Ctx, errorHandler options.ErrorHandler, statusCode int, description options.ErrorDescription, err error) error {
23-
if errorHandler != nil {
24-
errorHandler(description, err)
23+
if errorHandler == nil {
24+
return c.SendStatus(statusCode)
2525
}
26-
27-
return c.SendStatus(statusCode)
26+
response := errorHandler(description, err)
27+
if response == nil {
28+
return c.SendStatus(statusCode)
29+
}
30+
for k, v := range response.Headers {
31+
c.Response().Header.Set(k, v)
32+
}
33+
c.Set("Content-Type", response.ContentType())
34+
return c.Status(response.StatusCode).Send(response.Body)
2835
}
2936

3037
func toFiberHandler[T any](parseToken oidc.ParseTokenFunc[T], setters ...options.Option) fiber.Handler {

oidcgin/gin.go

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,19 @@ func New[T any](claimsValidationFn options.ClaimsValidationFn[T], setters ...opt
2020
return toGinHandler(oidcHandler.ParseToken, setters...)
2121
}
2222

23-
func onError(c *gin.Context, errorHandler options.ErrorHandler, statusCode int, description options.ErrorDescription, err error) {
24-
if errorHandler != nil {
25-
errorHandler(description, err)
23+
func onError(c *gin.Context, errorHandler options.ErrorHandler, statusCode int, description options.ErrorDescription, err error) error {
24+
if errorHandler == nil {
25+
return c.AbortWithError(statusCode, err)
2626
}
27-
28-
//nolint:errcheck // false positive
29-
c.AbortWithError(statusCode, err)
27+
response := errorHandler(description, err)
28+
if response == nil {
29+
return c.AbortWithError(statusCode, err)
30+
}
31+
for k, v := range response.Headers {
32+
c.Header(k, v)
33+
}
34+
c.Data(response.StatusCode, response.ContentType(), response.Body)
35+
return nil
3036
}
3137

3238
func toGinHandler[T any](parseToken oidc.ParseTokenFunc[T], setters ...options.Option) gin.HandlerFunc {

oidchttp/http.go

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,21 @@ func New[T any](h http.Handler, claimsValidationFn options.ClaimsValidationFn[T]
2121
}
2222

2323
func onError(w http.ResponseWriter, errorHandler options.ErrorHandler, statusCode int, description options.ErrorDescription, err error) {
24-
if errorHandler != nil {
25-
errorHandler(description, err)
24+
if errorHandler == nil {
25+
w.WriteHeader(statusCode)
26+
return
2627
}
27-
28-
w.WriteHeader(statusCode)
28+
response := errorHandler(description, err)
29+
if response == nil {
30+
w.WriteHeader(statusCode)
31+
return
32+
}
33+
for k, v := range response.Headers {
34+
w.Header().Add(k, v)
35+
}
36+
w.Header().Set("Content-Type", response.ContentType())
37+
w.WriteHeader(response.StatusCode)
38+
w.Write(response.Body)
2939
}
3040

3141
func toHttpHandler[T any](h http.Handler, parseToken oidc.ParseTokenFunc[T], setters ...options.Option) http.Handler {

options/options.go

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,23 @@ import (
55
"time"
66
)
77

8+
type Response struct {
9+
StatusCode int
10+
Headers map[string]string
11+
Body []byte
12+
}
13+
14+
// Return the content-type header from this response, or "applicatin/octet-stream"
15+
// as per HTTP standard.
16+
func (r *Response) ContentType() string {
17+
for k, v := range r.Headers {
18+
if http.CanonicalHeaderKey(k) == "Content-Type" {
19+
return v
20+
}
21+
}
22+
return "application/octet-stream"
23+
}
24+
825
// ClaimsValidationFn is a generic function to validate calims.
926
// If an error is returned, the claims failed the validation.
1027
// If `nil` is provided instead of a function when configuration the handler,
@@ -19,7 +36,7 @@ type ClaimsContextKeyName string
1936
const DefaultClaimsContextKeyName ClaimsContextKeyName = "claims"
2037

2138
// ErrorHandler is called by the middleware if not nil
22-
type ErrorHandler func(description ErrorDescription, err error)
39+
type ErrorHandler func(description ErrorDescription, err error) *Response
2340

2441
// ErrorDescription is used to pass the description of the error to ErrorHandler
2542
type ErrorDescription string

0 commit comments

Comments
 (0)