Skip to content

Commit ffdc4de

Browse files
author
Anders Qvist
committed
Pass context and request to error handler.
1 parent 27075a1 commit ffdc4de

File tree

6 files changed

+56
-12
lines changed

6 files changed

+56
-12
lines changed

internal/oidctesting/tests.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package oidctesting
22

33
import (
4+
"context"
45
"fmt"
56
"net/http"
67
"net/http/httptest"
@@ -320,14 +321,14 @@ func runTestErrorHandler(t *testing.T, testName string, tester tester) {
320321
}{
321322
{
322323
testDescription: "no output",
323-
errorHandler: func(desc options.ErrorDescription, err error) *options.Response { return nil },
324+
errorHandler: func(ctx context.Context, request *options.OidcError) *options.Response { return nil },
324325
expectStatusCode: http.StatusBadRequest,
325326
expectHeaders: map[string]string{},
326327
expectBodyContains: []byte{},
327328
},
328329
{
329330
testDescription: "basic propagation",
330-
errorHandler: func(desc options.ErrorDescription, err error) *options.Response {
331+
errorHandler: func(ctx context.Context, request *options.OidcError) *options.Response {
331332
return &options.Response{
332333
StatusCode: 418,
333334
Headers: map[string]string{},
@@ -342,7 +343,7 @@ func runTestErrorHandler(t *testing.T, testName string, tester tester) {
342343
},
343344
{
344345
testDescription: "additional header",
345-
errorHandler: func(desc options.ErrorDescription, err error) *options.Response {
346+
errorHandler: func(ctx context.Context, request *options.OidcError) *options.Response {
346347
return &options.Response{
347348
StatusCode: 418,
348349
Headers: map[string]string{"some": "header"},
@@ -358,7 +359,7 @@ func runTestErrorHandler(t *testing.T, testName string, tester tester) {
358359
},
359360
{
360361
testDescription: "content type",
361-
errorHandler: func(desc options.ErrorDescription, err error) *options.Response {
362+
errorHandler: func(ctx context.Context, request *options.OidcError) *options.Response {
362363
return &options.Response{
363364
StatusCode: 418,
364365
Headers: map[string]string{"content-type": "application/json"},

oidcecho/echo.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,13 @@ func onError(c echo.Context, errorHandler options.ErrorHandler, statusCode int,
2424
c.Logger().Error(err)
2525
return c.NoContent(statusCode)
2626
}
27-
response := errorHandler(description, err)
27+
error := options.OidcError{
28+
Url: c.Request().URL,
29+
Headers: c.Request().Header,
30+
Status: description,
31+
Error: err,
32+
}
33+
response := errorHandler(c.Request().Context(), &error)
2834
if response == nil {
2935
c.Logger().Error(err)
3036
return c.NoContent(statusCode)

oidcfiber/fiber.go

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package oidcfiber
22

33
import (
44
"fmt"
5+
"net/url"
56

67
"github.com/gofiber/fiber/v2"
78
"github.com/xenitab/go-oidc-middleware/internal/oidc"
@@ -23,7 +24,18 @@ func onError(c *fiber.Ctx, errorHandler options.ErrorHandler, statusCode int, de
2324
if errorHandler == nil {
2425
return c.SendStatus(statusCode)
2526
}
26-
response := errorHandler(description, err)
27+
url, _ := url.Parse(c.OriginalURL())
28+
headers := make(map[string][]string, 1)
29+
for k, v := range c.GetReqHeaders() {
30+
headers[k] = []string{v}
31+
}
32+
error := options.OidcError{
33+
Url: url,
34+
Headers: headers,
35+
Status: description,
36+
Error: err,
37+
}
38+
response := errorHandler(c.Context(), &error)
2739
if response == nil {
2840
return c.SendStatus(statusCode)
2941
}

oidcgin/gin.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,14 @@ func onError(c *gin.Context, errorHandler options.ErrorHandler, statusCode int,
2424
if errorHandler == nil {
2525
return c.AbortWithError(statusCode, err)
2626
}
27-
response := errorHandler(description, err)
27+
28+
error := options.OidcError{
29+
Url: c.Request.URL,
30+
Headers: c.Request.Header,
31+
Status: description,
32+
Error: err,
33+
}
34+
response := errorHandler(c.Request.Context(), &error)
2835
if response == nil {
2936
return c.AbortWithError(statusCode, err)
3037
}

oidchttp/http.go

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,18 @@ func New[T any](h http.Handler, claimsValidationFn options.ClaimsValidationFn[T]
2020
return toHttpHandler(h, oidcHandler.ParseToken, setters...)
2121
}
2222

23-
func onError(w http.ResponseWriter, errorHandler options.ErrorHandler, statusCode int, description options.ErrorDescription, err error) {
23+
func onError(r *http.Request, w http.ResponseWriter, errorHandler options.ErrorHandler, statusCode int, description options.ErrorDescription, err error) {
2424
if errorHandler == nil {
2525
w.WriteHeader(statusCode)
2626
return
2727
}
28-
response := errorHandler(description, err)
28+
error := options.OidcError{
29+
Url: r.URL,
30+
Headers: r.Header,
31+
Status: description,
32+
Error: err,
33+
}
34+
response := errorHandler(r.Context(), &error)
2935
if response == nil {
3036
w.WriteHeader(statusCode)
3137
return
@@ -46,13 +52,13 @@ func toHttpHandler[T any](h http.Handler, parseToken oidc.ParseTokenFunc[T], set
4652

4753
tokenString, err := oidc.GetTokenString(r.Header.Get, opts.TokenString)
4854
if err != nil {
49-
onError(w, opts.ErrorHandler, http.StatusBadRequest, options.GetTokenErrorDescription, err)
55+
onError(r, w, opts.ErrorHandler, http.StatusBadRequest, options.GetTokenErrorDescription, err)
5056
return
5157
}
5258

5359
claims, err := parseToken(ctx, tokenString)
5460
if err != nil {
55-
onError(w, opts.ErrorHandler, http.StatusUnauthorized, options.ParseTokenErrorDescription, err)
61+
onError(r, w, opts.ErrorHandler, http.StatusUnauthorized, options.ParseTokenErrorDescription, err)
5662
return
5763
}
5864

options/options.go

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,22 @@
11
package options
22

33
import (
4+
"context"
45
"net/http"
6+
"net/url"
57
"time"
68
)
79

10+
// Context information for the error handler.
11+
type OidcError struct {
12+
Url *url.URL
13+
Headers http.Header
14+
Error error
15+
Status ErrorDescription
16+
}
17+
18+
// Error handlers are expected to produce an abstract HTTP response that
19+
// the framework adapter will render.
820
type Response struct {
921
StatusCode int
1022
Headers map[string]string
@@ -36,7 +48,7 @@ type ClaimsContextKeyName string
3648
const DefaultClaimsContextKeyName ClaimsContextKeyName = "claims"
3749

3850
// ErrorHandler is called by the middleware if not nil
39-
type ErrorHandler func(description ErrorDescription, err error) *Response
51+
type ErrorHandler func(ctx context.Context, request *OidcError) *Response
4052

4153
// ErrorDescription is used to pass the description of the error to ErrorHandler
4254
type ErrorDescription string

0 commit comments

Comments
 (0)