Skip to content

Commit 88c83b7

Browse files
author
Bittrance
authored
Merge pull request #252 from XenitAB/errorhandler-does-http
Error handlers are responsible for HTTP response
2 parents b13ec3c + fa3ed69 commit 88c83b7

File tree

7 files changed

+227
-83
lines changed

7 files changed

+227
-83
lines changed

README.md

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -319,11 +319,35 @@ oidcHandler := oidcgin.New(
319319

320320
### Custom error handler
321321

322-
It is possible to add a custom function to handle errors. It will not be possible to change anything using it, but you will be able to add logic for logging as an example.
322+
It is possible to add a custom function to handle errors. The error handler can return an `options.Response` which will be rendered by the middleware. Returning `nil` will result in a default 400/401 error.
323323

324324
```go
325-
errorHandler := func(description options.ErrorDescription, err error) {
326-
fmt.Printf("Description: %s\tError: %v\n", description, err)
325+
type Message struct {
326+
Message string `json:"message"`
327+
Url string `json:"url"`
328+
}
329+
330+
func errorHandler(ctx context.Context, oidcErr *options.OidcError) *options.Response {
331+
message := Message{
332+
Message: string(oidcErr.Status),
333+
Url: oidcErr.Url.String(),
334+
}
335+
var headers map[string]string
336+
json, err := json.Marshal(message)
337+
if err != nil {
338+
headers["Content-Type"] = "text/plain"
339+
return &options.Response{
340+
StatusCode: 500,
341+
Headers: headers,
342+
Body: []byte("Internal encoding failure\r\n"),
343+
}
344+
}
345+
headers["Content-Type"] = "text/plain"
346+
return &options.Response{
347+
StatusCode: 418,
348+
Headers: headers,
349+
Body: json,
350+
}
327351
}
328352

329353
oidcHandler := oidcgin.New(
@@ -334,6 +358,8 @@ oidcHandler := oidcgin.New(
334358
)
335359
```
336360

361+
This error handling interface was changed in v0.0.42. The previous interface was `func(description ErrorDescription, err error)`. In order to retain the same behavior, you need to update your error handler to read `desctiption` and `err` from `oidcErr` and return `nil`.
362+
337363
### Testing with the middleware enabled
338364

339365
There's a small package that simulates an OpenID Provider that can be used with tests.

internal/oidctesting/tests.go

Lines changed: 81 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
package oidctesting
22

33
import (
4+
"context"
45
"fmt"
56
"net/http"
67
"net/http/httptest"
7-
"strings"
8-
"sync"
98
"testing"
109

1110
"github.com/stretchr/testify/require"
@@ -313,65 +312,91 @@ func runTestErrorHandler(t *testing.T, testName string, tester tester) {
313312
op := optest.NewTesting(t)
314313
defer op.Close(t)
315314

316-
var info struct {
317-
sync.RWMutex
318-
description options.ErrorDescription
319-
err error
320-
}
321-
322-
setInfo := func(description options.ErrorDescription, err error) {
323-
info.Lock()
324-
info.description = description
325-
info.err = err
326-
info.Unlock()
327-
}
328-
329-
getInfo := func() (description options.ErrorDescription, err error) {
330-
info.RLock()
331-
defer info.RUnlock()
332-
return info.description, info.err
333-
}
334-
335-
errorHandler := func(description options.ErrorDescription, err error) {
336-
t.Logf("Description: %s\tError: %v", description, err)
337-
setInfo(description, err)
338-
}
339-
340-
opts := []options.Option{
341-
options.WithIssuer(op.GetURL(t)),
342-
options.WithRequiredAudience("test-client"),
343-
options.WithRequiredTokenType("JWT+AT"),
344-
options.WithErrorHandler(errorHandler),
315+
cases := []struct {
316+
testDescription string
317+
errorHandler options.ErrorHandler
318+
expectStatusCode int
319+
expectHeaders map[string]string
320+
expectBodyContains []byte
321+
}{
322+
{
323+
testDescription: "no output",
324+
errorHandler: func(ctx context.Context, oidcErr *options.OidcError) *options.Response { return nil },
325+
expectStatusCode: http.StatusBadRequest,
326+
expectHeaders: map[string]string{},
327+
expectBodyContains: []byte{},
328+
},
329+
{
330+
testDescription: "basic propagation",
331+
errorHandler: func(ctx context.Context, oidcErr *options.OidcError) *options.Response {
332+
return &options.Response{
333+
StatusCode: 418,
334+
Headers: map[string]string{},
335+
Body: []byte("badness"),
336+
}
337+
},
338+
expectStatusCode: http.StatusTeapot,
339+
expectHeaders: map[string]string{
340+
"Content-Type": "application/octet-stream",
341+
},
342+
expectBodyContains: []byte("bad"),
343+
},
344+
{
345+
testDescription: "additional header",
346+
errorHandler: func(ctx context.Context, oidcErr *options.OidcError) *options.Response {
347+
return &options.Response{
348+
StatusCode: 418,
349+
Headers: map[string]string{"some": "header"},
350+
Body: []byte("badness"),
351+
}
352+
},
353+
expectStatusCode: http.StatusTeapot,
354+
expectHeaders: map[string]string{
355+
"Some": "header",
356+
"Content-Type": "application/octet-stream",
357+
},
358+
expectBodyContains: []byte{},
359+
},
360+
{
361+
testDescription: "content type",
362+
errorHandler: func(ctx context.Context, oidcErr *options.OidcError) *options.Response {
363+
return &options.Response{
364+
StatusCode: 418,
365+
Headers: map[string]string{"content-type": "application/json"},
366+
Body: []byte("{}"),
367+
}
368+
},
369+
expectStatusCode: http.StatusTeapot,
370+
expectHeaders: map[string]string{
371+
"Content-Type": "application/json",
372+
},
373+
expectBodyContains: []byte("{}"),
374+
},
345375
}
376+
for i := range cases {
377+
c := cases[i]
378+
t.Logf("Test iteration %d: %s", i, c.testDescription)
379+
opts := []options.Option{
380+
options.WithIssuer(op.GetURL(t)),
381+
options.WithRequiredAudience("test-client"),
382+
options.WithRequiredTokenType("JWT+AT"),
383+
options.WithErrorHandler(c.errorHandler),
384+
}
346385

347-
oidcHandler, err := oidc.NewHandler[TestClaims](nil, opts...)
348-
require.NoError(t, err)
349-
350-
handler := tester.ToHandlerFn(oidcHandler.ParseToken, opts...)
351-
352-
// Test without token
353-
reqNoAuth := httptest.NewRequest(http.MethodGet, "/", nil)
354-
recNoAuth := httptest.NewRecorder()
355-
handler.ServeHTTP(recNoAuth, reqNoAuth)
356-
357-
require.Equal(t, http.StatusBadRequest, recNoAuth.Result().StatusCode)
386+
oidcHandler, err := oidc.NewHandler[TestClaims](nil, opts...)
387+
require.NoError(t, err)
358388

359-
d, e := getInfo()
389+
handler := tester.ToHandlerFn(oidcHandler.ParseToken, opts...)
360390

361-
if !strings.Contains(t.Name(), "OidcEchoJwt") {
362-
require.Equal(t, options.GetTokenErrorDescription, d)
363-
require.EqualError(t, e, "unable to extract token: Authorization header empty")
391+
req := httptest.NewRequest(http.MethodGet, "/", nil)
392+
res := httptest.NewRecorder()
393+
handler.ServeHTTP(res, req)
394+
require.Equal(t, c.expectStatusCode, res.Result().StatusCode)
395+
for k, v := range c.expectHeaders {
396+
require.Equal(t, []string{v}, res.Result().Header[k])
397+
}
398+
require.Subset(t, res.Body.Bytes(), c.expectBodyContains)
364399
}
365-
366-
// Test with fake token
367-
token := op.GetToken(t)
368-
token.AccessToken = "foobar"
369-
testHttpWithAuthenticationFailure(t, token, handler)
370-
371-
d, e = getInfo()
372-
373-
require.Equal(t, options.ParseTokenErrorDescription, d)
374-
require.EqualError(t, e, "unable to parse token signature: invalid compact serialization format: invalid number of segments")
375400
})
376401
}
377402

oidcecho/echo.go

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,29 @@ func New[T any](claimsValidationFn options.ClaimsValidationFn[T], setters ...opt
1919
return toEchoMiddleware(h.ParseToken, setters...)
2020
}
2121

22-
func onError(errorHandler options.ErrorHandler, description options.ErrorDescription, err error) {
23-
if errorHandler != nil {
24-
errorHandler(description, err)
22+
func onError(c echo.Context, errorHandler options.ErrorHandler, statusCode int, description options.ErrorDescription, err error) error {
23+
if errorHandler == nil {
24+
c.Logger().Error(err)
25+
return c.NoContent(statusCode)
2526
}
27+
oidcErr := options.OidcError{
28+
Url: c.Request().URL,
29+
Headers: c.Request().Header,
30+
Status: description,
31+
Error: err,
32+
}
33+
response := errorHandler(c.Request().Context(), &oidcErr)
34+
if response == nil {
35+
c.Logger().Error(err)
36+
return c.NoContent(statusCode)
37+
}
38+
for k, v := range response.Headers {
39+
c.Response().Header().Set(k, v)
40+
}
41+
c.Response().Header().Set(echo.HeaderContentType, response.ContentType())
42+
c.Response().WriteHeader(response.StatusCode)
43+
_, err = c.Response().Write(response.Body)
44+
return err
2645
}
2746

2847
func toEchoMiddleware[T any](parseToken oidc.ParseTokenFunc[T], setters ...options.Option) echo.MiddlewareFunc {
@@ -34,14 +53,12 @@ func toEchoMiddleware[T any](parseToken oidc.ParseTokenFunc[T], setters ...optio
3453

3554
tokenString, err := oidc.GetTokenString(c.Request().Header.Get, opts.TokenString)
3655
if err != nil {
37-
onError(opts.ErrorHandler, options.GetTokenErrorDescription, err)
38-
return echo.ErrBadRequest
56+
return onError(c, opts.ErrorHandler, echo.ErrBadRequest.Code, options.GetTokenErrorDescription, err)
3957
}
4058

4159
claims, err := parseToken(ctx, tokenString)
4260
if err != nil {
43-
onError(opts.ErrorHandler, options.ParseTokenErrorDescription, err)
44-
return echo.ErrUnauthorized
61+
return onError(c, opts.ErrorHandler, echo.ErrUnauthorized.Code, options.ParseTokenErrorDescription, err)
4562
}
4663
c.Set(string(opts.ClaimsContextKeyName), claims)
4764
return next(c)

oidcfiber/fiber.go

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package oidcfiber
22

33
import (
44
"fmt"
5+
"net/url"
56

67
"github.com/gofiber/fiber/v2"
78
"github.com/xenitab/go-oidc-middleware/internal/oidc"
@@ -20,11 +21,29 @@ func New[T any](claimsValidationFn options.ClaimsValidationFn[T], setters ...opt
2021
}
2122

2223
func onError(c *fiber.Ctx, errorHandler options.ErrorHandler, statusCode int, description options.ErrorDescription, err error) error {
23-
if errorHandler != nil {
24-
errorHandler(description, err)
24+
if errorHandler == nil {
25+
return c.SendStatus(statusCode)
2526
}
26-
27-
return c.SendStatus(statusCode)
27+
url, _ := url.Parse(c.OriginalURL())
28+
headers := make(map[string][]string, 1)
29+
for k, v := range c.GetReqHeaders() {
30+
headers[k] = []string{v}
31+
}
32+
oidcErr := options.OidcError{
33+
Url: url,
34+
Headers: headers,
35+
Status: description,
36+
Error: err,
37+
}
38+
response := errorHandler(c.Context(), &oidcErr)
39+
if response == nil {
40+
return c.SendStatus(statusCode)
41+
}
42+
for k, v := range response.Headers {
43+
c.Response().Header.Set(k, v)
44+
}
45+
c.Set("Content-Type", response.ContentType())
46+
return c.Status(response.StatusCode).Send(response.Body)
2847
}
2948

3049
func toFiberHandler[T any](parseToken oidc.ParseTokenFunc[T], setters ...options.Option) fiber.Handler {

oidcgin/gin.go

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,26 @@ func New[T any](claimsValidationFn options.ClaimsValidationFn[T], setters ...opt
2020
return toGinHandler(oidcHandler.ParseToken, setters...)
2121
}
2222

23-
func onError(c *gin.Context, errorHandler options.ErrorHandler, statusCode int, description options.ErrorDescription, err error) {
24-
if errorHandler != nil {
25-
errorHandler(description, err)
23+
func onError(c *gin.Context, errorHandler options.ErrorHandler, statusCode int, description options.ErrorDescription, err error) error {
24+
if errorHandler == nil {
25+
return c.AbortWithError(statusCode, err)
2626
}
2727

28-
//nolint:errcheck // false positive
29-
c.AbortWithError(statusCode, err)
28+
oidcErr := options.OidcError{
29+
Url: c.Request.URL,
30+
Headers: c.Request.Header,
31+
Status: description,
32+
Error: err,
33+
}
34+
response := errorHandler(c.Request.Context(), &oidcErr)
35+
if response == nil {
36+
return c.AbortWithError(statusCode, err)
37+
}
38+
for k, v := range response.Headers {
39+
c.Header(k, v)
40+
}
41+
c.Data(response.StatusCode, response.ContentType(), response.Body)
42+
return nil
3043
}
3144

3245
func toGinHandler[T any](parseToken oidc.ParseTokenFunc[T], setters ...options.Option) gin.HandlerFunc {

oidchttp/http.go

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,28 @@ func New[T any](h http.Handler, claimsValidationFn options.ClaimsValidationFn[T]
2020
return toHttpHandler(h, oidcHandler.ParseToken, setters...)
2121
}
2222

23-
func onError(w http.ResponseWriter, errorHandler options.ErrorHandler, statusCode int, description options.ErrorDescription, err error) {
24-
if errorHandler != nil {
25-
errorHandler(description, err)
23+
func onError(r *http.Request, w http.ResponseWriter, errorHandler options.ErrorHandler, statusCode int, description options.ErrorDescription, err error) {
24+
if errorHandler == nil {
25+
w.WriteHeader(statusCode)
26+
return
2627
}
27-
28-
w.WriteHeader(statusCode)
28+
oidcErr := options.OidcError{
29+
Url: r.URL,
30+
Headers: r.Header,
31+
Status: description,
32+
Error: err,
33+
}
34+
response := errorHandler(r.Context(), &oidcErr)
35+
if response == nil {
36+
w.WriteHeader(statusCode)
37+
return
38+
}
39+
for k, v := range response.Headers {
40+
w.Header().Add(k, v)
41+
}
42+
w.Header().Set("Content-Type", response.ContentType())
43+
w.WriteHeader(response.StatusCode)
44+
w.Write(response.Body)
2945
}
3046

3147
func toHttpHandler[T any](h http.Handler, parseToken oidc.ParseTokenFunc[T], setters ...options.Option) http.Handler {
@@ -36,13 +52,13 @@ func toHttpHandler[T any](h http.Handler, parseToken oidc.ParseTokenFunc[T], set
3652

3753
tokenString, err := oidc.GetTokenString(r.Header.Get, opts.TokenString)
3854
if err != nil {
39-
onError(w, opts.ErrorHandler, http.StatusBadRequest, options.GetTokenErrorDescription, err)
55+
onError(r, w, opts.ErrorHandler, http.StatusBadRequest, options.GetTokenErrorDescription, err)
4056
return
4157
}
4258

4359
claims, err := parseToken(ctx, tokenString)
4460
if err != nil {
45-
onError(w, opts.ErrorHandler, http.StatusUnauthorized, options.ParseTokenErrorDescription, err)
61+
onError(r, w, opts.ErrorHandler, http.StatusUnauthorized, options.ParseTokenErrorDescription, err)
4662
return
4763
}
4864

0 commit comments

Comments
 (0)