diff --git a/bench_test.go b/bench_test.go index a4440bc1c..da78d65c8 100644 --- a/bench_test.go +++ b/bench_test.go @@ -163,6 +163,41 @@ func BenchmarkMinimalPgConnPreparedSelect(b *testing.B) { } } +func BenchmarkMinimalPgConnPreparedStatementDescriptionSelect(b *testing.B) { + conn := mustConnect(b, mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))) + defer closeConn(b, conn) + + pgConn := conn.PgConn() + + psd, err := pgConn.Prepare(context.Background(), "ps1", "select $1::int8", nil) + if err != nil { + b.Fatal(err) + } + + encodedBytes := make([]byte, 8) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + + rr := pgConn.ExecPreparedStatementDescription(context.Background(), psd, [][]byte{encodedBytes}, []int16{1}, []int16{1}) + if err != nil { + b.Fatal(err) + } + + for rr.NextRow() { + for i := range rr.Values() { + if !bytes.Equal(rr.Values()[0], encodedBytes) { + b.Fatalf("unexpected values: %s %s", rr.Values()[i], encodedBytes) + } + } + } + _, err = rr.Close() + if err != nil { + b.Fatal(err) + } + } +} + func BenchmarkPointerPointerWithNullValues(b *testing.B) { conn := mustConnect(b, mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))) defer closeConn(b, conn) @@ -1282,6 +1317,51 @@ func BenchmarkSelectRowsPgConnExecPrepared(b *testing.B) { } } +func BenchmarkSelectRowsPgConnExecPreparedStatementDescription(b *testing.B) { + conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(b, conn) + + rowCounts := getSelectRowsCounts(b) + + psd, err := conn.PgConn().Prepare(context.Background(), "ps1", "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", nil) + if err != nil { + b.Fatal(err) + } + + for _, rowCount := range rowCounts { + b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) { + formats := []struct { + name string + code int16 + }{ + {"text", pgx.TextFormatCode}, + {"binary - mostly", pgx.BinaryFormatCode}, + } + for _, format := range formats { + b.Run(format.name, func(b *testing.B) { + for i := 0; i < b.N; i++ { + rr := conn.PgConn().ExecPreparedStatementDescription( + context.Background(), + psd, + [][]byte{[]byte(strconv.FormatInt(rowCount, 10))}, + nil, + []int16{format.code, pgx.TextFormatCode, pgx.TextFormatCode, pgx.TextFormatCode, format.code, format.code, format.code, format.code, format.code}, + ) + for rr.NextRow() { + rr.Values() + } + + _, err := rr.Close() + if err != nil { + b.Fatal(err) + } + } + }) + } + }) + } +} + type queryRecorder struct { conn net.Conn writeBuf []byte diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index 5c77a836e..84133f292 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -22,6 +22,7 @@ import ( "github.com/jackc/pgx/v5/pgconn/ctxwatch" "github.com/jackc/pgx/v5/pgconn/internal/bgreader" "github.com/jackc/pgx/v5/pgproto3" + "github.com/jackc/pgx/v5/pgtype" ) const ( @@ -1165,7 +1166,7 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [] pgConn.frontend.SendParse(&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}) pgConn.frontend.SendBind(&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}) - pgConn.execExtendedSuffix(result) + pgConn.execExtendedSuffix(result, nil, nil) return result } @@ -1190,7 +1191,37 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa pgConn.frontend.SendBind(&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}) - pgConn.execExtendedSuffix(result) + pgConn.execExtendedSuffix(result, nil, nil) + + return result +} + +// ExecPreparedStatementDescription enqueues the execution of a prepared statement via the PostgreSQL extended query +// protocol. +// +// This differs from ExecPrepared in that it takes a *StatementDescription instead of just the prepared statement name. +// Because it has the *StatementDescription it can avoid the Describe Portal message that ExecPrepared must send to get +// the result column descriptions. +// +// paramValues are the parameter values. It must be encoded in the format given by paramFormats. +// +// paramFormats is a slice of format codes determining for each paramValue column whether it is encoded in text or +// binary format. If paramFormats is nil all params are text format. ExecPrepared will panic if len(paramFormats) is not +// 0, 1, or len(paramValues). +// +// resultFormats is a slice of format codes determining for each result column whether it is encoded in text or binary +// format. If resultFormats is nil all results will be in text format. +// +// ResultReader must be closed before PgConn can be used again. +func (pgConn *PgConn) ExecPreparedStatementDescription(ctx context.Context, statementDescription *StatementDescription, paramValues [][]byte, paramFormats, resultFormats []int16) *ResultReader { + result := pgConn.execExtendedPrefix(ctx, paramValues) + if result.closed { + return result + } + + pgConn.frontend.SendBind(&pgproto3.Bind{PreparedStatement: statementDescription.Name, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}) + + pgConn.execExtendedSuffix(result, statementDescription, resultFormats) return result } @@ -1230,8 +1261,10 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by return result } -func (pgConn *PgConn) execExtendedSuffix(result *ResultReader) { - pgConn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'P'}) +func (pgConn *PgConn) execExtendedSuffix(result *ResultReader, statementDescription *StatementDescription, resultFormats []int16) { + if statementDescription == nil { + pgConn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'P'}) + } pgConn.frontend.SendExecute(&pgproto3.Execute{}) pgConn.frontend.SendSync(&pgproto3.Sync{}) @@ -1245,7 +1278,7 @@ func (pgConn *PgConn) execExtendedSuffix(result *ResultReader) { return } - result.readUntilRowDescription() + result.readUntilRowDescription(statementDescription, resultFormats) } // CopyTo executes the copy command sql and copies the results to w. @@ -1662,13 +1695,36 @@ func (rr *ResultReader) Close() (CommandTag, error) { // readUntilRowDescription ensures the ResultReader's fieldDescriptions are loaded. It does not return an error as any // error will be stored in the ResultReader. -func (rr *ResultReader) readUntilRowDescription() { +func (rr *ResultReader) readUntilRowDescription(statementDescription *StatementDescription, resultFormats []int16) { for !rr.commandConcluded { // Peek before receive to avoid consuming a DataRow if the result set does not include a RowDescription method. // This should never happen under normal pgconn usage, but it is possible if SendBytes and ReceiveResults are // manually used to construct a query that does not issue a describe statement. msg, _ := rr.pgConn.peekMessage() if _, ok := msg.(*pgproto3.DataRow); ok { + if statementDescription != nil { + rr.fieldDescriptions = statementDescription.Fields + // Adjust field descriptions for resultFormats + if len(resultFormats) == 0 { + // No format codes provided, default to text format + for i := range rr.fieldDescriptions { + rr.fieldDescriptions[i].Format = pgtype.TextFormatCode + } + } else if len(resultFormats) == 1 { + // Single format code applies to all columns + for i := range rr.fieldDescriptions { + rr.fieldDescriptions[i].Format = resultFormats[0] + } + } else if len(resultFormats) == len(rr.fieldDescriptions) { + // One format code per column + for i := range rr.fieldDescriptions { + rr.fieldDescriptions[i].Format = resultFormats[i] + } + } else { + // This should be impossible to reach as the mismatch would have been caught earlier. + rr.concludeCommand(CommandTag{}, fmt.Errorf("mismatched result format codes length")) + } + } return } diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index 3663bc904..e22f81c48 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -1439,6 +1439,146 @@ func TestConnExecPreparedEmptySQL(t *testing.T) { ensureConnValid(t, pgConn) } +func TestConnExecPreparedStatementDescription(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + psd, err := pgConn.Prepare(ctx, "ps1", "select $1::text as msg", nil) + require.NoError(t, err) + require.NotNil(t, psd) + assert.Len(t, psd.ParamOIDs, 1) + assert.Len(t, psd.Fields, 1) + + result := pgConn.ExecPreparedStatementDescription(ctx, psd, [][]byte{[]byte("Hello, world")}, nil, nil) + require.Len(t, result.FieldDescriptions(), 1) + assert.Equal(t, "msg", result.FieldDescriptions()[0].Name) + + rowCount := 0 + for result.NextRow() { + rowCount += 1 + assert.Equal(t, "Hello, world", string(result.Values()[0])) + } + assert.Equal(t, 1, rowCount) + commandTag, err := result.Close() + assert.Equal(t, "SELECT 1", commandTag.String()) + assert.NoError(t, err) + + ensureConnValid(t, pgConn) +} + +type byteCounterConn struct { + conn net.Conn + bytesRead int + bytesWritten int +} + +func (cbn *byteCounterConn) Read(b []byte) (n int, err error) { + n, err = cbn.conn.Read(b) + cbn.bytesRead += n + return n, err +} + +func (cbn *byteCounterConn) Write(b []byte) (n int, err error) { + n, err = cbn.conn.Write(b) + cbn.bytesWritten += n + return n, err +} + +func (cbn *byteCounterConn) Close() error { + return cbn.conn.Close() +} + +func (cbn *byteCounterConn) LocalAddr() net.Addr { + return cbn.conn.LocalAddr() +} + +func (cbn *byteCounterConn) RemoteAddr() net.Addr { + return cbn.conn.RemoteAddr() +} + +func (cbn *byteCounterConn) SetDeadline(t time.Time) error { + return cbn.conn.SetDeadline(t) +} + +func (cbn *byteCounterConn) SetReadDeadline(t time.Time) error { + return cbn.conn.SetReadDeadline(t) +} + +func (cbn *byteCounterConn) SetWriteDeadline(t time.Time) error { + return cbn.conn.SetWriteDeadline(t) +} + +func TestConnExecPreparedStatementDescriptionNetworkUsage(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + + var counterConn *byteCounterConn + config.AfterNetConnect = func(ctx context.Context, config *pgconn.Config, conn net.Conn) (net.Conn, error) { + counterConn = &byteCounterConn{conn: conn} + return counterConn, nil + } + + pgConn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + defer closeConn(t, pgConn) + require.NotNil(t, counterConn) + + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server uses different number of bytes for same operations") + } + + psd, err := pgConn.Prepare(ctx, "ps1", "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", nil) + require.NoError(t, err) + require.NotNil(t, psd) + assert.Len(t, psd.ParamOIDs, 1) + assert.Len(t, psd.Fields, 9) + + counterConn.bytesWritten = 0 + counterConn.bytesRead = 0 + + result := pgConn.ExecPrepared(ctx, + psd.Name, + [][]byte{[]byte("1")}, + nil, + []int16{pgx.BinaryFormatCode, pgx.TextFormatCode, pgx.TextFormatCode, pgx.TextFormatCode, pgx.BinaryFormatCode, pgx.BinaryFormatCode, pgx.BinaryFormatCode, pgx.BinaryFormatCode, pgx.BinaryFormatCode}, + ).Read() + require.NoError(t, result.Err) + withDescribeBytesWritten := counterConn.bytesWritten + withDescribeBytesRead := counterConn.bytesRead + + counterConn.bytesWritten = 0 + counterConn.bytesRead = 0 + + result = pgConn.ExecPreparedStatementDescription( + ctx, + psd, + [][]byte{[]byte("1")}, + nil, + []int16{pgx.BinaryFormatCode, pgx.TextFormatCode, pgx.TextFormatCode, pgx.TextFormatCode, pgx.BinaryFormatCode, pgx.BinaryFormatCode, pgx.BinaryFormatCode, pgx.BinaryFormatCode, pgx.BinaryFormatCode}, + ).Read() + require.NoError(t, result.Err) + noDescribeBytesWritten := counterConn.bytesWritten + noDescribeBytesRead := counterConn.bytesRead + + assert.Equal(t, 61, withDescribeBytesWritten) + assert.Equal(t, 54, noDescribeBytesWritten) + assert.Equal(t, 391, withDescribeBytesRead) + assert.Equal(t, 153, noDescribeBytesRead) + + ensureConnValid(t, pgConn) +} + func TestConnExecBatch(t *testing.T) { t.Parallel()