@@ -81,6 +81,10 @@ type Conn struct {
8181}
8282
8383func (conn * Conn ) PrepareContext (ctx context.Context , query string ) (driver.Stmt , error ) {
84+ return conn .prepareContext (ctx , query )
85+ }
86+
87+ func (conn * Conn ) prepareContext (ctx context.Context , query string ) (* Stmt , error ) {
8488 var (
8589 stmt driver.Stmt
8690 err error
@@ -93,7 +97,7 @@ func (conn *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt
9397 }
9498
9599 if err != nil {
96- return stmt , err
100+ return nil , err
97101 }
98102
99103 return & Stmt {stmt , conn .hooks , query }, nil
@@ -139,21 +143,39 @@ func (conn *ExecerContext) execContext(ctx context.Context, query string, args [
139143}
140144
141145func (conn * ExecerContext ) ExecContext (ctx context.Context , query string , args []driver.NamedValue ) (driver.Result , error ) {
146+ return execWithHooks (ctx , query , args , conn .hooks , func (ctx context.Context ) (driver.Result , error ) {
147+ results , err := conn .execContext (ctx , query , args )
148+ if err == nil || ! errors .Is (err , driver .ErrSkip ) {
149+ return results , err
150+ }
151+ // If driver.ErrSkip is returned, we fall back to using Prepare + Statement to handle the query.
152+ // We need to avoid executing the hooks twice since they were already run in ExecContext.
153+ // This matches the behavior in database/sql when ExecContext returns ErrSkip.
154+ stmt , err := conn .prepareContext (ctx , query )
155+ if err != nil {
156+ return nil , err
157+ }
158+ defer stmt .Close ()
159+ return stmt .execContext (ctx , args )
160+ })
161+ }
162+
163+ func execWithHooks (ctx context.Context , query string , args []driver.NamedValue , hooks Hooks , execer func (context.Context ) (driver.Result , error )) (driver.Result , error ) {
142164 var err error
143165
144166 list := namedToInterface (args )
145167
146168 // Exec `Before` Hooks
147- if ctx , err = conn . hooks .Before (ctx , query , list ... ); err != nil {
169+ if ctx , err = hooks .Before (ctx , query , list ... ); err != nil {
148170 return nil , err
149171 }
150172
151- results , err := conn . execContext (ctx , query , args )
173+ results , err := execer (ctx )
152174 if err != nil {
153- return results , handlerErr (ctx , conn . hooks , err , query , list ... )
175+ return results , handlerErr (ctx , hooks , err , query , list ... )
154176 }
155177
156- if _ , err := conn . hooks .After (ctx , query , list ... ); err != nil {
178+ if _ , err := hooks .After (ctx , query , list ... ); err != nil {
157179 return nil , err
158180 }
159181
@@ -201,21 +223,43 @@ func (conn *QueryerContext) queryContext(ctx context.Context, query string, args
201223}
202224
203225func (conn * QueryerContext ) QueryContext (ctx context.Context , query string , args []driver.NamedValue ) (driver.Rows , error ) {
226+ return queryWithHooks (ctx , query , args , conn .hooks , func (ctx context.Context ) (driver.Rows , error ) {
227+ rows , err := conn .queryContext (ctx , query , args )
228+ if err == nil || ! errors .Is (err , driver .ErrSkip ) {
229+ return rows , err
230+ }
231+ // If driver.ErrSkip is returned, we fall back to using Prepare + Statement to handle the query.
232+ // We need to avoid executing the hooks twice since they were already run in QueryContext.
233+ // This matches the behavior in database/sql when QueryContext returns ErrSkip.
234+ stmt , err := conn .prepareContext (ctx , query )
235+ if err != nil {
236+ return nil , err
237+ }
238+ rows , err = stmt .queryContext (ctx , args )
239+ if err != nil {
240+ _ = stmt .Close ()
241+ return nil , err
242+ }
243+ return & rowsWrapper {rows : rows , closeStmt : stmt }, nil
244+ })
245+ }
246+
247+ func queryWithHooks (ctx context.Context , query string , args []driver.NamedValue , hooks Hooks , queryer func (context.Context ) (driver.Rows , error )) (driver.Rows , error ) {
204248 var err error
205249
206250 list := namedToInterface (args )
207251
208252 // Query `Before` Hooks
209- if ctx , err = conn . hooks .Before (ctx , query , list ... ); err != nil {
253+ if ctx , err = hooks .Before (ctx , query , list ... ); err != nil {
210254 return nil , err
211255 }
212256
213- results , err := conn . queryContext (ctx , query , args )
257+ results , err := queryer (ctx )
214258 if err != nil {
215- return results , handlerErr (ctx , conn . hooks , err , query , list ... )
259+ return results , handlerErr (ctx , hooks , err , query , list ... )
216260 }
217261
218- if _ , err := conn . hooks .After (ctx , query , list ... ); err != nil {
262+ if _ , err := hooks .After (ctx , query , list ... ); err != nil {
219263 return nil , err
220264 }
221265
@@ -264,25 +308,9 @@ func (stmt *Stmt) execContext(ctx context.Context, args []driver.NamedValue) (dr
264308}
265309
266310func (stmt * Stmt ) ExecContext (ctx context.Context , args []driver.NamedValue ) (driver.Result , error ) {
267- var err error
268-
269- list := namedToInterface (args )
270-
271- // Exec `Before` Hooks
272- if ctx , err = stmt .hooks .Before (ctx , stmt .query , list ... ); err != nil {
273- return nil , err
274- }
275-
276- results , err := stmt .execContext (ctx , args )
277- if err != nil {
278- return results , handlerErr (ctx , stmt .hooks , err , stmt .query , list ... )
279- }
280-
281- if _ , err := stmt .hooks .After (ctx , stmt .query , list ... ); err != nil {
282- return nil , err
283- }
284-
285- return results , err
311+ return execWithHooks (ctx , stmt .query , args , stmt .hooks , func (ctx context.Context ) (driver.Result , error ) {
312+ return stmt .execContext (ctx , args )
313+ })
286314}
287315
288316func (stmt * Stmt ) queryContext (ctx context.Context , args []driver.NamedValue ) (driver.Rows , error ) {
@@ -298,25 +326,9 @@ func (stmt *Stmt) queryContext(ctx context.Context, args []driver.NamedValue) (d
298326}
299327
300328func (stmt * Stmt ) QueryContext (ctx context.Context , args []driver.NamedValue ) (driver.Rows , error ) {
301- var err error
302-
303- list := namedToInterface (args )
304-
305- // Exec Before Hooks
306- if ctx , err = stmt .hooks .Before (ctx , stmt .query , list ... ); err != nil {
307- return nil , err
308- }
309-
310- rows , err := stmt .queryContext (ctx , args )
311- if err != nil {
312- return rows , handlerErr (ctx , stmt .hooks , err , stmt .query , list ... )
313- }
314-
315- if _ , err := stmt .hooks .After (ctx , stmt .query , list ... ); err != nil {
316- return nil , err
317- }
318-
319- return rows , err
329+ return queryWithHooks (ctx , stmt .query , args , stmt .hooks , func (ctx context.Context ) (driver.Rows , error ) {
330+ return stmt .queryContext (ctx , args )
331+ })
320332}
321333
322334func (stmt * Stmt ) Close () error { return stmt .Stmt .Close () }
@@ -350,6 +362,27 @@ func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) {
350362 return dargs , nil
351363}
352364
365+ type rowsWrapper struct {
366+ rows driver.Rows
367+ closeStmt driver.Stmt // if non-nil, statement to Close on close
368+ }
369+
370+ func (r * rowsWrapper ) Close () error {
371+ err := r .rows .Close ()
372+ if r .closeStmt != nil {
373+ _ = r .closeStmt .Close ()
374+ }
375+ return err
376+ }
377+
378+ func (r * rowsWrapper ) Columns () []string {
379+ return r .rows .Columns ()
380+ }
381+
382+ func (r * rowsWrapper ) Next (dest []driver.Value ) error {
383+ return r .rows .Next (dest )
384+ }
385+
353386/*
354387type hooks struct {
355388}
0 commit comments