Skip to content

Commit 3a9a71a

Browse files
committed
fix: handle driver.ErrSkip to avoid duplicate hooks execution with MySQL driver
When InterpolateParams=false is set in MySQL driver, it returns driver.ErrSkip which causes the SQL package to fall back to prepared statements, resulting in hooks being executed twice. This change handles driver.ErrSkip internally to ensure hooks are only executed once per logical operation.
1 parent 7875602 commit 3a9a71a

File tree

2 files changed

+100
-67
lines changed

2 files changed

+100
-67
lines changed

sqlhooks.go

Lines changed: 80 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,10 @@ type Conn struct {
8181
}
8282

8383
func (conn *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
84+
return conn.prepareContext(ctx, query)
85+
}
86+
87+
func (conn *Conn) prepareContext(ctx context.Context, query string) (*Stmt, error) {
8488
var (
8589
stmt driver.Stmt
8690
err error
@@ -93,7 +97,7 @@ func (conn *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt
9397
}
9498

9599
if err != nil {
96-
return stmt, err
100+
return nil, err
97101
}
98102

99103
return &Stmt{stmt, conn.hooks, query}, nil
@@ -139,21 +143,39 @@ func (conn *ExecerContext) execContext(ctx context.Context, query string, args [
139143
}
140144

141145
func (conn *ExecerContext) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
146+
return execWithHooks(ctx, query, args, conn.hooks, func(ctx context.Context) (driver.Result, error) {
147+
results, err := conn.execContext(ctx, query, args)
148+
if err == nil || !errors.Is(err, driver.ErrSkip) {
149+
return results, err
150+
}
151+
// If driver.ErrSkip is returned, we fall back to using Prepare + Statement to handle the query.
152+
// We need to avoid executing the hooks twice since they were already run in ExecContext.
153+
// This matches the behavior in database/sql when ExecContext returns ErrSkip.
154+
stmt, err := conn.prepareContext(ctx, query)
155+
if err != nil {
156+
return nil, err
157+
}
158+
defer stmt.Close()
159+
return stmt.execContext(ctx, args)
160+
})
161+
}
162+
163+
func execWithHooks(ctx context.Context, query string, args []driver.NamedValue, hooks Hooks, execer func(context.Context) (driver.Result, error)) (driver.Result, error) {
142164
var err error
143165

144166
list := namedToInterface(args)
145167

146168
// Exec `Before` Hooks
147-
if ctx, err = conn.hooks.Before(ctx, query, list...); err != nil {
169+
if ctx, err = hooks.Before(ctx, query, list...); err != nil {
148170
return nil, err
149171
}
150172

151-
results, err := conn.execContext(ctx, query, args)
173+
results, err := execer(ctx)
152174
if err != nil {
153-
return results, handlerErr(ctx, conn.hooks, err, query, list...)
175+
return results, handlerErr(ctx, hooks, err, query, list...)
154176
}
155177

156-
if _, err := conn.hooks.After(ctx, query, list...); err != nil {
178+
if _, err := hooks.After(ctx, query, list...); err != nil {
157179
return nil, err
158180
}
159181

@@ -201,21 +223,43 @@ func (conn *QueryerContext) queryContext(ctx context.Context, query string, args
201223
}
202224

203225
func (conn *QueryerContext) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
226+
return queryWithHooks(ctx, query, args, conn.hooks, func(ctx context.Context) (driver.Rows, error) {
227+
rows, err := conn.queryContext(ctx, query, args)
228+
if err == nil || !errors.Is(err, driver.ErrSkip) {
229+
return rows, err
230+
}
231+
// If driver.ErrSkip is returned, we fall back to using Prepare + Statement to handle the query.
232+
// We need to avoid executing the hooks twice since they were already run in QueryContext.
233+
// This matches the behavior in database/sql when QueryContext returns ErrSkip.
234+
stmt, err := conn.prepareContext(ctx, query)
235+
if err != nil {
236+
return nil, err
237+
}
238+
rows, err = stmt.queryContext(ctx, args)
239+
if err != nil {
240+
_ = stmt.Close()
241+
return nil, err
242+
}
243+
return &rowsWrapper{rows: rows, closeStmt: stmt}, nil
244+
})
245+
}
246+
247+
func queryWithHooks(ctx context.Context, query string, args []driver.NamedValue, hooks Hooks, queryer func(context.Context) (driver.Rows, error)) (driver.Rows, error) {
204248
var err error
205249

206250
list := namedToInterface(args)
207251

208252
// Query `Before` Hooks
209-
if ctx, err = conn.hooks.Before(ctx, query, list...); err != nil {
253+
if ctx, err = hooks.Before(ctx, query, list...); err != nil {
210254
return nil, err
211255
}
212256

213-
results, err := conn.queryContext(ctx, query, args)
257+
results, err := queryer(ctx)
214258
if err != nil {
215-
return results, handlerErr(ctx, conn.hooks, err, query, list...)
259+
return results, handlerErr(ctx, hooks, err, query, list...)
216260
}
217261

218-
if _, err := conn.hooks.After(ctx, query, list...); err != nil {
262+
if _, err := hooks.After(ctx, query, list...); err != nil {
219263
return nil, err
220264
}
221265

@@ -264,25 +308,9 @@ func (stmt *Stmt) execContext(ctx context.Context, args []driver.NamedValue) (dr
264308
}
265309

266310
func (stmt *Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
267-
var err error
268-
269-
list := namedToInterface(args)
270-
271-
// Exec `Before` Hooks
272-
if ctx, err = stmt.hooks.Before(ctx, stmt.query, list...); err != nil {
273-
return nil, err
274-
}
275-
276-
results, err := stmt.execContext(ctx, args)
277-
if err != nil {
278-
return results, handlerErr(ctx, stmt.hooks, err, stmt.query, list...)
279-
}
280-
281-
if _, err := stmt.hooks.After(ctx, stmt.query, list...); err != nil {
282-
return nil, err
283-
}
284-
285-
return results, err
311+
return execWithHooks(ctx, stmt.query, args, stmt.hooks, func(ctx context.Context) (driver.Result, error) {
312+
return stmt.execContext(ctx, args)
313+
})
286314
}
287315

288316
func (stmt *Stmt) queryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
@@ -298,25 +326,9 @@ func (stmt *Stmt) queryContext(ctx context.Context, args []driver.NamedValue) (d
298326
}
299327

300328
func (stmt *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
301-
var err error
302-
303-
list := namedToInterface(args)
304-
305-
// Exec Before Hooks
306-
if ctx, err = stmt.hooks.Before(ctx, stmt.query, list...); err != nil {
307-
return nil, err
308-
}
309-
310-
rows, err := stmt.queryContext(ctx, args)
311-
if err != nil {
312-
return rows, handlerErr(ctx, stmt.hooks, err, stmt.query, list...)
313-
}
314-
315-
if _, err := stmt.hooks.After(ctx, stmt.query, list...); err != nil {
316-
return nil, err
317-
}
318-
319-
return rows, err
329+
return queryWithHooks(ctx, stmt.query, args, stmt.hooks, func(ctx context.Context) (driver.Rows, error) {
330+
return stmt.queryContext(ctx, args)
331+
})
320332
}
321333

322334
func (stmt *Stmt) Close() error { return stmt.Stmt.Close() }
@@ -350,6 +362,27 @@ func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) {
350362
return dargs, nil
351363
}
352364

365+
type rowsWrapper struct {
366+
rows driver.Rows
367+
closeStmt driver.Stmt // if non-nil, statement to Close on close
368+
}
369+
370+
func (r *rowsWrapper) Close() error {
371+
err := r.rows.Close()
372+
if r.closeStmt != nil {
373+
_ = r.closeStmt.Close()
374+
}
375+
return err
376+
}
377+
378+
func (r *rowsWrapper) Columns() []string {
379+
return r.rows.Columns()
380+
}
381+
382+
func (r *rowsWrapper) Next(dest []driver.Value) error {
383+
return r.rows.Next(dest)
384+
}
385+
353386
/*
354387
type hooks struct {
355388
}

sqlhooks_test.go

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -68,62 +68,62 @@ func newSuite(t *testing.T, driver driver.Driver, dsn string) *suite {
6868
}
6969

7070
func (s *suite) TestHooksExecution(t *testing.T, query string, args ...interface{}) {
71-
var before, after bool
71+
var beforeCount, afterCount int
7272

7373
s.hooks.before = func(ctx context.Context, q string, a ...interface{}) (context.Context, error) {
74-
before = true
74+
beforeCount++
7575
return ctx, nil
7676
}
7777
s.hooks.after = func(ctx context.Context, q string, a ...interface{}) (context.Context, error) {
78-
after = true
78+
afterCount++
7979
return ctx, nil
8080
}
8181

8282
t.Run("Query", func(t *testing.T) {
83-
before, after = false, false
83+
beforeCount, afterCount = 0, 0
8484
_, err := s.db.Query(query, args...)
8585
require.NoError(t, err)
86-
assert.True(t, before, "Before Hook did not run for query: "+query)
87-
assert.True(t, after, "After Hook did not run for query: "+query)
86+
assert.Equal(t, 1, beforeCount, "Before Hook didn't execute only once: "+query)
87+
assert.Equal(t, 1, afterCount, "After Hook didn't execute only once: "+query)
8888
})
8989

9090
t.Run("QueryContext", func(t *testing.T) {
91-
before, after = false, false
91+
beforeCount, afterCount = 0, 0
9292
_, err := s.db.QueryContext(context.Background(), query, args...)
9393
require.NoError(t, err)
94-
assert.True(t, before, "Before Hook did not run for query: "+query)
95-
assert.True(t, after, "After Hook did not run for query: "+query)
94+
assert.Equal(t, 1, beforeCount, "Before Hook didn't execute only once: "+query)
95+
assert.Equal(t, 1, afterCount, "After Hook didn't execute only once: "+query)
9696
})
9797

9898
t.Run("Exec", func(t *testing.T) {
99-
before, after = false, false
99+
beforeCount, afterCount = 0, 0
100100
_, err := s.db.Exec(query, args...)
101101
require.NoError(t, err)
102-
assert.True(t, before, "Before Hook did not run for query: "+query)
103-
assert.True(t, after, "After Hook did not run for query: "+query)
102+
assert.Equal(t, 1, beforeCount, "Before Hook didn't execute only once: "+query)
103+
assert.Equal(t, 1, afterCount, "After Hook didn't execute only once: "+query)
104104
})
105105

106106
t.Run("ExecContext", func(t *testing.T) {
107-
before, after = false, false
107+
beforeCount, afterCount = 0, 0
108108
_, err := s.db.ExecContext(context.Background(), query, args...)
109109
require.NoError(t, err)
110-
assert.True(t, before, "Before Hook did not run for query: "+query)
111-
assert.True(t, after, "After Hook did not run for query: "+query)
110+
assert.Equal(t, 1, beforeCount, "Before Hook didn't execute only once: "+query)
111+
assert.Equal(t, 1, afterCount, "After Hook didn't execute only once: "+query)
112112
})
113113

114114
t.Run("Statements", func(t *testing.T) {
115-
before, after = false, false
115+
beforeCount, afterCount = 0, 0
116116
stmt, err := s.db.Prepare(query)
117117
require.NoError(t, err)
118118

119119
// Hooks just run when the stmt is executed (Query or Exec)
120-
assert.False(t, before, "Before Hook run before execution: "+query)
121-
assert.False(t, after, "After Hook run before execution: "+query)
120+
assert.Equal(t, 0, beforeCount, "Before Hook run before execution: "+query)
121+
assert.Equal(t, 0, afterCount, "After Hook run before execution: "+query)
122122

123123
_, err = stmt.Query(args...)
124124
require.NoError(t, err)
125-
assert.True(t, before, "Before Hook did not run for query: "+query)
126-
assert.True(t, after, "After Hook did not run for query: "+query)
125+
assert.Equal(t, 1, beforeCount, "Before Hook didn't execute only once: "+query)
126+
assert.Equal(t, 1, afterCount, "After Hook didn't execute only once: "+query)
127127
})
128128
}
129129

0 commit comments

Comments
 (0)