77 "time"
88
99 "github.com/go-redis/redis/v7"
10- "github.com/imdario/mergo"
1110 "github.com/labstack/echo/v4"
1211 "github.com/labstack/echo/v4/middleware"
1312 "github.com/shareed2k/go_limiter"
@@ -67,6 +66,10 @@ type (
6766 // Default:
6867 Prefix string
6968
69+ // SkipOnError
70+ // Default: false
71+ SkipOnError bool
72+
7073 // Period
7174 Period time.Duration
7275
@@ -81,6 +84,12 @@ type (
8184 // return ctx.String(defaultStatusCode, defaultMessage)
8285 // }
8386 Handler func (echo.Context ) error
87+
88+ // ErrHandler is called when a error happen inside go_limiiter lib
89+ // Default: func(c echo.Context) {
90+ // return ctx.String(defaultStatusCode, defaultMessage)
91+ // }
92+ ErrHandler func (error , echo.Context ) error
8493 }
8594)
8695
@@ -91,20 +100,58 @@ func New(rediser *redis.Client) echo.MiddlewareFunc {
91100}
92101
93102func NewWithConfig (config Config ) echo.MiddlewareFunc {
94- if err := mergo .Merge (& config , DefaultConfig ); err != nil {
95- panic (err )
96- }
97-
98103 if config .Rediser == nil {
99104 panic (errors .New ("redis client is missing" ))
100105 }
101106
107+ if config .Skipper == nil {
108+ config .Skipper = DefaultConfig .Skipper
109+ }
110+
111+ if config .Max == 0 {
112+ config .Max = DefaultConfig .Max
113+ }
114+
115+ if config .Burst == 0 {
116+ config .Burst = DefaultConfig .Burst
117+ }
118+
119+ if config .StatusCode == 0 {
120+ config .StatusCode = DefaultConfig .StatusCode
121+ }
122+
123+ if config .Message == "" {
124+ config .Message = DefaultConfig .Message
125+ }
126+
127+ if config .Algorithm == "" {
128+ config .Algorithm = DefaultConfig .Algorithm
129+ }
130+
131+ if config .Prefix == "" {
132+ config .Prefix = DefaultConfig .Prefix
133+ }
134+
135+ if config .Period == 0 {
136+ config .Period = DefaultConfig .Period
137+ }
138+
139+ if config .Key == nil {
140+ config .Key = DefaultConfig .Key
141+ }
142+
102143 if config .Handler == nil {
103144 config .Handler = func (ctx echo.Context ) error {
104145 return ctx .String (config .StatusCode , config .Message )
105146 }
106147 }
107148
149+ if config .ErrHandler == nil {
150+ config .ErrHandler = func (err error , ctx echo.Context ) error {
151+ return echo .NewHTTPError (http .StatusInternalServerError , err )
152+ }
153+ }
154+
108155 limiter := go_limiter .NewLimiter (config .Rediser )
109156 limit := & go_limiter.Limit {
110157 Period : config .Period ,
@@ -123,21 +170,23 @@ func NewWithConfig(config Config) echo.MiddlewareFunc {
123170 if err != nil {
124171 ctx .Logger ().Error (err )
125172
126- return next (ctx )
173+ if config .SkipOnError {
174+ return next (ctx )
175+ }
176+
177+ return config .ErrHandler (err , ctx )
127178 }
128179
129180 res := ctx .Response ()
130181
131182 // Check if hits exceed the max
132183 if ! result .Allowed {
133- // Call Handler func
134- err := config .Handler (ctx )
135-
136184 // Return response with Retry-After header
137185 // https://tools.ietf.org/html/rfc6584
138186 res .Header ().Set ("Retry-After" , strconv .FormatInt (time .Now ().Add (result .RetryAfter ).Unix (), 10 ))
139187
140- return err
188+ // Call Handler func
189+ return config .Handler (ctx )
141190 }
142191
143192 // We can continue, update RateLimit headers
0 commit comments