From f41de3943d8938debcdb5c5ce8c18aeefb38caf4 Mon Sep 17 00:00:00 2001 From: General Kroll Date: Thu, 2 Apr 2026 19:11:53 +1100 Subject: [PATCH 01/11] pg-wire-extended-query-support Summary: - Support for the `postgres` extended query pattern. --- docs/pg_wire_stackql_backend_migration.md | 217 ++++++++++++++++++++++ 1 file changed, 217 insertions(+) create mode 100644 docs/pg_wire_stackql_backend_migration.md diff --git a/docs/pg_wire_stackql_backend_migration.md b/docs/pg_wire_stackql_backend_migration.md new file mode 100644 index 00000000..95c134d4 --- /dev/null +++ b/docs/pg_wire_stackql_backend_migration.md @@ -0,0 +1,217 @@ +# Migrating a stackql backend to extended query support + +This document describes how to add extended query protocol support to a stackql `ISQLBackend` implementation, replacing the default stubs provided by `psql-wire`. + +## Background + +The `psql-wire` library auto-detects whether an `ISQLBackend` also implements `IExtendedQueryBackend` via a type assertion in `connection.go`: + +```go +if eb, ok := sqlBackend.(sqlbackend.IExtendedQueryBackend); ok { + extBackend = eb +} else if sqlBackend != nil { + extBackend = sqlbackend.NewDefaultExtendedQueryBackend(sqlBackend) +} +``` + +If the backend does not implement `IExtendedQueryBackend`, a `DefaultExtendedQueryBackend` wraps it. This default delegates `HandleExecute` to `HandleSimpleQuery` and stubs out everything else. Client libraries like pgx can connect and run unparameterised queries through this path. + +No factory or wiring changes are needed to opt in. Adding the methods to the existing struct is sufficient. 
+ +## The IExtendedQueryBackend interface + +```go +type IExtendedQueryBackend interface { + HandleParse(ctx context.Context, stmtName string, query string, paramOIDs []uint32) ([]uint32, error) + HandleBind(ctx context.Context, portalName string, stmtName string, paramFormats []int16, paramValues [][]byte, resultFormats []int16) error + HandleDescribeStatement(ctx context.Context, stmtName string, query string, paramOIDs []uint32) ([]uint32, []sqldata.ISQLColumn, error) + HandleDescribePortal(ctx context.Context, portalName string, stmtName string, query string, paramOIDs []uint32) ([]sqldata.ISQLColumn, error) + HandleExecute(ctx context.Context, portalName string, stmtName string, query string, paramFormats []int16, paramValues [][]byte, resultFormats []int16, maxRows int32) (sqldata.ISQLResultStream, error) + HandleCloseStatement(ctx context.Context, stmtName string) error + HandleClosePortal(ctx context.Context, portalName string) error +} +``` + +## Migration steps + +### Step 1: Add no-op stubs to the existing backend + +Locate the struct in stackql that implements `ISQLBackend` (the one with `HandleSimpleQuery`, `SplitCompoundQuery`, and `GetDebugStr`). Add the seven methods below. These are direct copies of the `DefaultExtendedQueryBackend` behaviour, so robot tests should produce identical results. 
+ +```go +func (sb *YourBackend) HandleParse(ctx context.Context, stmtName string, query string, paramOIDs []uint32) ([]uint32, error) { + return paramOIDs, nil +} + +func (sb *YourBackend) HandleBind(ctx context.Context, portalName string, stmtName string, paramFormats []int16, paramValues [][]byte, resultFormats []int16) error { + return nil +} + +func (sb *YourBackend) HandleDescribeStatement(ctx context.Context, stmtName string, query string, paramOIDs []uint32) ([]uint32, []sqldata.ISQLColumn, error) { + return paramOIDs, nil, nil +} + +func (sb *YourBackend) HandleDescribePortal(ctx context.Context, portalName string, stmtName string, query string, paramOIDs []uint32) ([]sqldata.ISQLColumn, error) { + return nil, nil +} + +func (sb *YourBackend) HandleExecute(ctx context.Context, portalName string, stmtName string, query string, paramFormats []int16, paramValues [][]byte, resultFormats []int16, maxRows int32) (sqldata.ISQLResultStream, error) { + return sb.HandleSimpleQuery(ctx, query) +} + +func (sb *YourBackend) HandleCloseStatement(ctx context.Context, stmtName string) error { + return nil +} + +func (sb *YourBackend) HandleClosePortal(ctx context.Context, portalName string) error { + return nil +} +``` + +**Verification**: run the full robot test suite. Behaviour should be unchanged. + +### Step 2: Add a compile-time interface check + +Near the top of the file, add: + +```go +var _ sqlbackend.IExtendedQueryBackend = (*YourBackend)(nil) +``` + +This ensures the compiler catches missing methods if the interface changes. + +### Step 3: Implement HandleExecute with parameter substitution + +`HandleExecute` receives the query string and bound parameter values. The parameters arrive as: + +- `paramFormats []int16` — one entry per parameter: `0` = text, `1` = binary. May be empty (all text) or length 1 (applies to all). +- `paramValues [][]byte` — raw bytes for each parameter. `nil` entry = SQL NULL. 
+- `resultFormats []int16` — requested result column formats (currently safe to ignore; the wire library encodes as text). + +#### Option A: String interpolation (simplest) + +Replace positional parameters (`$1`, `$2`, ...) with their text values, applying appropriate quoting, then delegate to `HandleSimpleQuery`: + +```go +func (sb *YourBackend) HandleExecute(ctx context.Context, portalName string, stmtName string, query string, paramFormats []int16, paramValues [][]byte, resultFormats []int16, maxRows int32) (sqldata.ISQLResultStream, error) { + resolved := substituteParams(query, paramFormats, paramValues) + return sb.HandleSimpleQuery(ctx, resolved) +} +``` + +Where `substituteParams` replaces `$N` tokens with the corresponding text value. Rules: + +- `paramValues[i] == nil` → substitute `NULL` (no quotes). +- Otherwise use the text representation from `paramValues[i]`. Quote string values with single quotes and escape embedded single quotes by doubling them (`'` → `''`). +- If `paramFormats` is empty, treat every parameter as text (format `0`); if it has length 1, apply that single format code to every parameter; otherwise it holds one format code per parameter. Only text-format values can be substituted directly — binary-format (`1`) values must first be decoded according to the parameter's type. + +#### Option B: Native parameterisation + +If the stackql execution engine supports parameterised queries, pass the values through directly. This avoids quoting issues and is more correct long-term. + +**Verification**: write a test that connects with pgx, runs a parameterised query, and checks the results: + +```go +rows, err := conn.Query(ctx, "SELECT $1::text, $2::int", "hello", 42) +``` + +### Step 4: Implement HandleDescribeStatement + +Client libraries call Describe after Parse to learn the result column types before any rows arrive. This allows typed scanning (e.g., pgx allocates `int32` vs `string` targets). 
+ +Return parameter OIDs and result column metadata: + +```go +func (sb *YourBackend) HandleDescribeStatement(ctx context.Context, stmtName string, query string, paramOIDs []uint32) ([]uint32, []sqldata.ISQLColumn, error) { + columns, err := sb.planQuery(query) // derive columns from query planner / schema + if err != nil { + return nil, nil, err + } + return paramOIDs, columns, nil +} +``` + +Each `ISQLColumn` requires: +- `GetName()` — column name +- `GetObjectID()` — PostgreSQL type OID (e.g., `25` for text, `23` for int4, `16` for bool) +- `GetWidth()` — column width in bytes (use `-1` if variable) +- `GetTableId()`, `GetAttrNum()` — can be `0` if not applicable +- `GetTypeModifier()` — usually `-1` +- `GetFormat()` — `"text"` for text format + +If the query planner cannot derive columns (e.g., for DDL), return `nil` columns — the wire library sends `NoData`, which is valid. + +**Verification**: use pgx to prepare a statement and check that `FieldDescriptions()` returns the expected column metadata. + +### Step 5: Implement HandleDescribePortal + +Similar to `HandleDescribeStatement`, but for a bound portal. The portal already has its parameters bound, so column metadata may be more precise. In many cases this can delegate to the same logic: + +```go +func (sb *YourBackend) HandleDescribePortal(ctx context.Context, portalName string, stmtName string, query string, paramOIDs []uint32) ([]sqldata.ISQLColumn, error) { + _, columns, err := sb.HandleDescribeStatement(ctx, stmtName, query, paramOIDs) + return columns, err +} +``` + +### Step 6: Implement HandleParse with type resolution + +If stackql can resolve unspecified parameter types (OID = 0) from the query, do so in `HandleParse`. Otherwise, the current pass-through is fine — clients that send OID 0 will format parameters as text, which works with string interpolation. 
+ +```go +func (sb *YourBackend) HandleParse(ctx context.Context, stmtName string, query string, paramOIDs []uint32) ([]uint32, error) { + resolved, err := sb.resolveParamTypes(query, paramOIDs) + if err != nil { + return nil, err + } + return resolved, nil +} +``` + +### Step 7: Implement HandleBind with validation + +If stackql can validate parameter values at bind time, do so here. Errors returned from `HandleBind` are reported to the client before execution, and the connection enters error recovery mode (messages discarded until Sync). This gives the client a chance to re-bind with corrected values. + +### Step 8: Implement HandleCloseStatement / HandleClosePortal + +If stackql caches query plans or intermediate state, release them here. If not, the no-ops from step 1 are correct. + +## Error recovery + +The wire library handles error recovery automatically. If any `IExtendedQueryBackend` method returns an error: + +1. An `ErrorResponse` is sent to the client. +2. All subsequent messages are discarded until the client sends `Sync`. +3. `Sync` sends `ReadyForQuery('E')` (failed transaction status). +4. The client can then retry or issue new commands. + +Backend methods should return errors freely. They do not need to manage connection state. + +## Testing strategy + +Each step should be verified independently: + +1. **Step 1 (stubs)**: full robot test suite — no regressions. +2. **Step 3 (execute)**: pgx test with parameterised `SELECT $1::text` — returns correct value. +3. **Step 4 (describe)**: pgx `Prepare` + check `FieldDescriptions()` — column names and OIDs match. +4. **Step 5 (describe portal)**: pgx query with `QueryRow` — typed scan works without explicit type hints. +5. **Step 6 (parse)**: pgx query with untyped parameters — server resolves types, client formats correctly. + +For each step, a pgx integration test is the most realistic validation since pgx exercises the full Parse → Describe → Bind → Execute → Sync pipeline. 
+ +## Common OIDs for reference + +| Type | OID | Go type | +|---------|------|---------------| +| bool | 16 | bool | +| int2 | 21 | int16 | +| int4 | 23 | int32 | +| int8 | 20 | int64 | +| float4 | 700 | float32 | +| float8 | 701 | float64 | +| text | 25 | string | +| varchar | 1043 | string | +| json | 114 | string/[]byte | +| jsonb | 3802 | string/[]byte | + +These are defined in `github.com/lib/pq/oid` as `oid.T_bool`, `oid.T_int4`, etc. From 3c1875b076157f57132c98d0e8cec0d7419c15fa Mon Sep 17 00:00:00 2001 From: General Kroll Date: Fri, 3 Apr 2026 00:30:12 +1100 Subject: [PATCH 02/11] pg-wire-extended-query-support Summary: - Support for `postgres` extended queries. --- internal/stackql/driver/driver.go | 76 +++++++++++++++++++++++++++++-- 1 file changed, 73 insertions(+), 3 deletions(-) diff --git a/internal/stackql/driver/driver.go b/internal/stackql/driver/driver.go index 96d63745..a2a3ae22 100644 --- a/internal/stackql/driver/driver.go +++ b/internal/stackql/driver/driver.go @@ -4,6 +4,9 @@ import ( "bytes" "context" "fmt" + "regexp" + "strconv" + "strings" "github.com/stackql/any-sdk/pkg/logging" "github.com/stackql/any-sdk/public/sqlengine" @@ -19,9 +22,10 @@ import ( ) var ( - _ StackQLDriver = &basicStackQLDriver{} - _ sqlbackend.SQLBackendFactory = &basicStackQLDriverFactory{} - _ StackQLDriverFactory = &basicStackQLDriverFactory{} + _ StackQLDriver = &basicStackQLDriver{} + _ sqlbackend.IExtendedQueryBackend = &basicStackQLDriver{} + _ sqlbackend.SQLBackendFactory = &basicStackQLDriverFactory{} + _ StackQLDriverFactory = &basicStackQLDriverFactory{} ) type StackQLDriverFactory interface { @@ -194,6 +198,72 @@ func NewStackQLDriver(handlerCtx handler.HandlerContext) (StackQLDriver, error) }, nil } +func (dr *basicStackQLDriver) HandleParse( + ctx context.Context, stmtName string, query string, paramOIDs []uint32, +) ([]uint32, error) { + return paramOIDs, nil +} + +func (dr *basicStackQLDriver) HandleBind( + ctx context.Context, portalName string, 
stmtName string, + paramFormats []int16, paramValues [][]byte, resultFormats []int16, +) error { + return nil +} + +func (dr *basicStackQLDriver) HandleDescribeStatement( + ctx context.Context, stmtName string, query string, paramOIDs []uint32, +) ([]uint32, []sqldata.ISQLColumn, error) { + return paramOIDs, nil, nil +} + +func (dr *basicStackQLDriver) HandleDescribePortal( + ctx context.Context, portalName string, stmtName string, query string, paramOIDs []uint32, +) ([]sqldata.ISQLColumn, error) { + return nil, nil +} + +func (dr *basicStackQLDriver) HandleExecute( + ctx context.Context, portalName string, stmtName string, query string, + paramFormats []int16, paramValues [][]byte, resultFormats []int16, maxRows int32, +) (sqldata.ISQLResultStream, error) { + resolved := substituteParams(query, paramFormats, paramValues) + return dr.HandleSimpleQuery(ctx, resolved) +} + +var paramPlaceholderRegex = regexp.MustCompile(`\$(\d+)`) + +// substituteParams replaces $1, $2, ... placeholders with their bound values. +// NULL parameters are substituted as the literal NULL. +// String values are single-quote escaped. 
+func substituteParams(query string, paramFormats []int16, paramValues [][]byte) string { + if len(paramValues) == 0 { + return query + } + return paramPlaceholderRegex.ReplaceAllStringFunc(query, func(match string) string { + idxStr := match[1:] // strip leading $ + idx, err := strconv.Atoi(idxStr) + if err != nil || idx < 1 || idx > len(paramValues) { + return match // leave unrecognised placeholders as-is + } + val := paramValues[idx-1] + if val == nil { + return "NULL" + } + text := string(val) + escaped := strings.ReplaceAll(text, "'", "''") + return "'" + escaped + "'" + }) +} + +func (dr *basicStackQLDriver) HandleCloseStatement(ctx context.Context, stmtName string) error { + return nil +} + +func (dr *basicStackQLDriver) HandleClosePortal(ctx context.Context, portalName string) error { + return nil +} + func (dr *basicStackQLDriver) processQueryOrQueries( handlerCtx handler.HandlerContext, ) ([]internaldto.ExecutorOutput, bool) { From 662e9e121b8c669c449a13ecee53b701d6320f5c Mon Sep 17 00:00:00 2001 From: General Kroll Date: Fri, 3 Apr 2026 10:19:25 +1100 Subject: [PATCH 03/11] next-stage --- internal/stackql/driver/driver.go | 47 +++++++++++++++++++- internal/stackql/plan/plan.go | 28 +++++++++--- internal/stackql/planbuilder/entrypoint.go | 5 +++ internal/stackql/planbuilder/plan_builder.go | 5 +++ 4 files changed, 76 insertions(+), 9 deletions(-) diff --git a/internal/stackql/driver/driver.go b/internal/stackql/driver/driver.go index a2a3ae22..f71b02ec 100644 --- a/internal/stackql/driver/driver.go +++ b/internal/stackql/driver/driver.go @@ -14,7 +14,9 @@ import ( "github.com/stackql/stackql/internal/stackql/acid/tsm_physio" "github.com/stackql/stackql/internal/stackql/handler" "github.com/stackql/stackql/internal/stackql/internal_data_transfer/internaldto" + "github.com/stackql/stackql/internal/stackql/planbuilder" "github.com/stackql/stackql/internal/stackql/responsehandler" + "github.com/stackql/stackql/internal/stackql/typing" 
"github.com/stackql/stackql/internal/stackql/util" "github.com/stackql/stackql/pkg/txncounter" @@ -214,13 +216,54 @@ func (dr *basicStackQLDriver) HandleBind( func (dr *basicStackQLDriver) HandleDescribeStatement( ctx context.Context, stmtName string, query string, paramOIDs []uint32, ) ([]uint32, []sqldata.ISQLColumn, error) { - return paramOIDs, nil, nil + columns, err := dr.describeColumns(query) + if err != nil { + return nil, nil, err + } + return paramOIDs, columns, nil } func (dr *basicStackQLDriver) HandleDescribePortal( ctx context.Context, portalName string, stmtName string, query string, paramOIDs []uint32, ) ([]sqldata.ISQLColumn, error) { - return nil, nil + return dr.describeColumns(query) +} + +// describeColumns builds a query plan (without executing) and extracts +// result column metadata. For non-SELECT statements the plan carries no +// column metadata, so nil is returned and the wire library sends NoData. +func (dr *basicStackQLDriver) describeColumns(query string) ([]sqldata.ISQLColumn, error) { + clonedCtx := dr.handlerCtx.Clone() + clonedCtx.SetQuery(query) + clonedCtx.SetRawQuery(query) + pb := planbuilder.NewPlanBuilder(nil) + qPlan, err := pb.BuildPlanFromContext(clonedCtx) + if err != nil || qPlan == nil { + return nil, nil //nolint:nilerr // plan failure → NoData is acceptable + } + colMeta := qPlan.GetColumnMetadata() + if len(colMeta) == 0 { + return nil, nil + } + return columnMetadataToSQLColumns(colMeta), nil +} + +// columnMetadataToSQLColumns converts internal column metadata to wire protocol columns. 
+func columnMetadataToSQLColumns(cols []typing.ColumnMetadata) []sqldata.ISQLColumn { + table := sqldata.NewSQLTable(0, "") + result := make([]sqldata.ISQLColumn, len(cols)) + for i, col := range cols { + result[i] = sqldata.NewSQLColumn( + table, + col.GetIdentifier(), + 0, + uint32(col.GetColumnOID()), + -1, + 0, + "text", + ) + } + return result } func (dr *basicStackQLDriver) HandleExecute( diff --git a/internal/stackql/plan/plan.go b/internal/stackql/plan/plan.go index 4a99e8a1..ee48d76a 100644 --- a/internal/stackql/plan/plan.go +++ b/internal/stackql/plan/plan.go @@ -5,6 +5,7 @@ import ( "github.com/stackql/stackql/internal/stackql/acid/binlog" "github.com/stackql/stackql/internal/stackql/primitivegraph" + "github.com/stackql/stackql/internal/stackql/typing" "github.com/stackql/stackql-parser/go/vt/sqlparser" ) @@ -30,6 +31,10 @@ type Plan interface { // Get the undo log entry. GetUndoLog() (binlog.LogEntry, bool) + // Column metadata from query planning (available after plan build, before execution). 
+ GetColumnMetadata() []typing.ColumnMetadata + SetColumnMetadata(columns []typing.ColumnMetadata) + // Setters SetType(t sqlparser.StatementType) SetStatement(statement sqlparser.Statement) @@ -59,13 +64,14 @@ type standardPlan struct { // Stores BindVars needed to be provided as part of expression rewriting sqlparser.BindVarNeeds - ExecCount uint64 // Count of times this plan was executed - ExecTime time.Duration // Total execution time - ShardQueries uint64 // Total number of shard queries - Rows uint64 // Total number of rows - Errors uint64 // Total number of errors - isCacheable bool - isReadOnly bool + ExecCount uint64 // Count of times this plan was executed + ExecTime time.Duration // Total execution time + ShardQueries uint64 // Total number of shard queries + Rows uint64 // Total number of rows + Errors uint64 // Total number of errors + isCacheable bool + isReadOnly bool + columnMetadata []typing.ColumnMetadata } func NewPlan( @@ -170,3 +176,11 @@ func (p *standardPlan) IsCacheable() bool { func (p *standardPlan) SetCacheable(isCacheable bool) { p.isCacheable = isCacheable } + +func (p *standardPlan) GetColumnMetadata() []typing.ColumnMetadata { + return p.columnMetadata +} + +func (p *standardPlan) SetColumnMetadata(columns []typing.ColumnMetadata) { + p.columnMetadata = columns +} diff --git a/internal/stackql/planbuilder/entrypoint.go b/internal/stackql/planbuilder/entrypoint.go index 57e00be6..168bc2b7 100644 --- a/internal/stackql/planbuilder/entrypoint.go +++ b/internal/stackql/planbuilder/entrypoint.go @@ -181,6 +181,11 @@ func (pb *standardPlanBuilder) BuildPlanFromContext(handlerCtx handler.HandlerCo qPlan.SetInstructions(pGBuilder.getPlanGraphHolder()) + // Extract column metadata from the plan for extended query protocol Describe support. 
+ if selCtx := pGBuilder.getRootPrimitiveGenerator().GetPrimitiveComposer().GetSelectPreparedStatementCtx(); selCtx != nil { + qPlan.SetColumnMetadata(selCtx.GetNonControlColumns()) + } + if qPlan.GetInstructions() != nil { err = qPlan.GetInstructions().GetPrimitiveGraph().Optimise() if err != nil { diff --git a/internal/stackql/planbuilder/plan_builder.go b/internal/stackql/planbuilder/plan_builder.go index bb6796aa..f24ee372 100644 --- a/internal/stackql/planbuilder/plan_builder.go +++ b/internal/stackql/planbuilder/plan_builder.go @@ -50,6 +50,7 @@ func isPlanCacheEnabled() bool { type planGraphBuilder interface { setRootPrimitiveGenerator(primitivegenerator.PrimitiveGenerator) + getRootPrimitiveGenerator() primitivegenerator.PrimitiveGenerator pgInternal(planbuilderinput.PlanBuilderInput) error createInstructionFor(planbuilderinput.PlanBuilderInput) error nop(planbuilderinput.PlanBuilderInput) error @@ -78,6 +79,10 @@ func (pgb *standardPlanGraphBuilder) setRootPrimitiveGenerator( pgb.rootPrimitiveGenerator = primitiveGenerator } +func (pgb *standardPlanGraphBuilder) getRootPrimitiveGenerator() primitivegenerator.PrimitiveGenerator { + return pgb.rootPrimitiveGenerator +} + func (pgb *standardPlanGraphBuilder) getPlanGraphHolder() primitivegraph.PrimitiveGraphHolder { return pgb.planGraphHolder } From eabe3eb9ecd1dd0d62d3e786843854123e860c2b Mon Sep 17 00:00:00 2001 From: General Kroll Date: Fri, 3 Apr 2026 11:35:41 +1100 Subject: [PATCH 04/11] next-stage --- .golangci.yml | 4 ++ internal/stackql/driver/driver.go | 47 ++----------------- internal/stackql/planbuilder/entrypoint.go | 1 + internal/stackql/queryshape/queryshape.go | 52 ++++++++++++++++++++++ 4 files changed, 60 insertions(+), 44 deletions(-) create mode 100644 internal/stackql/queryshape/queryshape.go diff --git a/.golangci.yml b/.golangci.yml index a22a0655..0808cd07 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -196,6 +196,10 @@ linters: - lll - revive path: mcp_client\/cmd\/.*\.go + - 
linters: + - revive + - unparam + path: internal\/stackql\/driver\/.*\.go - linters: - revive path: internal\/stackql\/acid\/tsm_physio\/.*\.go diff --git a/internal/stackql/driver/driver.go b/internal/stackql/driver/driver.go index f71b02ec..e2c742f2 100644 --- a/internal/stackql/driver/driver.go +++ b/internal/stackql/driver/driver.go @@ -14,9 +14,8 @@ import ( "github.com/stackql/stackql/internal/stackql/acid/tsm_physio" "github.com/stackql/stackql/internal/stackql/handler" "github.com/stackql/stackql/internal/stackql/internal_data_transfer/internaldto" - "github.com/stackql/stackql/internal/stackql/planbuilder" + "github.com/stackql/stackql/internal/stackql/queryshape" "github.com/stackql/stackql/internal/stackql/responsehandler" - "github.com/stackql/stackql/internal/stackql/typing" "github.com/stackql/stackql/internal/stackql/util" "github.com/stackql/stackql/pkg/txncounter" @@ -216,54 +215,14 @@ func (dr *basicStackQLDriver) HandleBind( func (dr *basicStackQLDriver) HandleDescribeStatement( ctx context.Context, stmtName string, query string, paramOIDs []uint32, ) ([]uint32, []sqldata.ISQLColumn, error) { - columns, err := dr.describeColumns(query) - if err != nil { - return nil, nil, err - } + columns := queryshape.InferResultColumns(dr.handlerCtx, query) return paramOIDs, columns, nil } func (dr *basicStackQLDriver) HandleDescribePortal( ctx context.Context, portalName string, stmtName string, query string, paramOIDs []uint32, ) ([]sqldata.ISQLColumn, error) { - return dr.describeColumns(query) -} - -// describeColumns builds a query plan (without executing) and extracts -// result column metadata. For non-SELECT statements the plan carries no -// column metadata, so nil is returned and the wire library sends NoData. 
-func (dr *basicStackQLDriver) describeColumns(query string) ([]sqldata.ISQLColumn, error) { - clonedCtx := dr.handlerCtx.Clone() - clonedCtx.SetQuery(query) - clonedCtx.SetRawQuery(query) - pb := planbuilder.NewPlanBuilder(nil) - qPlan, err := pb.BuildPlanFromContext(clonedCtx) - if err != nil || qPlan == nil { - return nil, nil //nolint:nilerr // plan failure → NoData is acceptable - } - colMeta := qPlan.GetColumnMetadata() - if len(colMeta) == 0 { - return nil, nil - } - return columnMetadataToSQLColumns(colMeta), nil -} - -// columnMetadataToSQLColumns converts internal column metadata to wire protocol columns. -func columnMetadataToSQLColumns(cols []typing.ColumnMetadata) []sqldata.ISQLColumn { - table := sqldata.NewSQLTable(0, "") - result := make([]sqldata.ISQLColumn, len(cols)) - for i, col := range cols { - result[i] = sqldata.NewSQLColumn( - table, - col.GetIdentifier(), - 0, - uint32(col.GetColumnOID()), - -1, - 0, - "text", - ) - } - return result + return queryshape.InferResultColumns(dr.handlerCtx, query), nil } func (dr *basicStackQLDriver) HandleExecute( diff --git a/internal/stackql/planbuilder/entrypoint.go b/internal/stackql/planbuilder/entrypoint.go index 168bc2b7..9c1b3e93 100644 --- a/internal/stackql/planbuilder/entrypoint.go +++ b/internal/stackql/planbuilder/entrypoint.go @@ -182,6 +182,7 @@ func (pb *standardPlanBuilder) BuildPlanFromContext(handlerCtx handler.HandlerCo qPlan.SetInstructions(pGBuilder.getPlanGraphHolder()) // Extract column metadata from the plan for extended query protocol Describe support. 
+ //nolint:lll // acceptable if selCtx := pGBuilder.getRootPrimitiveGenerator().GetPrimitiveComposer().GetSelectPreparedStatementCtx(); selCtx != nil { qPlan.SetColumnMetadata(selCtx.GetNonControlColumns()) } diff --git a/internal/stackql/queryshape/queryshape.go b/internal/stackql/queryshape/queryshape.go new file mode 100644 index 00000000..8fa6ebf4 --- /dev/null +++ b/internal/stackql/queryshape/queryshape.go @@ -0,0 +1,52 @@ +package queryshape + +import ( + "github.com/stackql/psql-wire/pkg/sqldata" + "github.com/stackql/stackql/internal/stackql/handler" + "github.com/stackql/stackql/internal/stackql/planbuilder" + "github.com/stackql/stackql/internal/stackql/typing" +) + +// InferResultColumns analyses a SQL query and returns result column metadata +// without executing the query. It builds a query plan to resolve column +// projections and types, but does not run the plan. +// +// Returns nil when column metadata cannot be derived (e.g. DDL, mutations, +// or queries that fail planning). +func InferResultColumns( + handlerCtx handler.HandlerContext, + query string, +) []sqldata.ISQLColumn { + clonedCtx := handlerCtx.Clone() + clonedCtx.SetQuery(query) + clonedCtx.SetRawQuery(query) + pb := planbuilder.NewPlanBuilder(nil) + qPlan, err := pb.BuildPlanFromContext(clonedCtx) + if err != nil || qPlan == nil { + return nil + } + colMeta := qPlan.GetColumnMetadata() + if len(colMeta) == 0 { + return nil + } + return ColumnMetadataToSQLColumns(colMeta) +} + +// ColumnMetadataToSQLColumns converts internal column metadata to +// wire protocol ISQLColumn objects. 
+func ColumnMetadataToSQLColumns(cols []typing.ColumnMetadata) []sqldata.ISQLColumn { + table := sqldata.NewSQLTable(0, "") + result := make([]sqldata.ISQLColumn, len(cols)) + for i, col := range cols { + result[i] = sqldata.NewSQLColumn( + table, + col.GetIdentifier(), + 0, + uint32(col.GetColumnOID()), + -1, + 0, + "text", + ) + } + return result +} From ec471cae033851b2a8a7eafa9f16b2da671473f8 Mon Sep 17 00:00:00 2001 From: General Kroll Date: Fri, 3 Apr 2026 12:37:23 +1100 Subject: [PATCH 05/11] better-aot-schema-analysis --- .golangci.yml | 4 + internal/stackql/driver/driver.go | 8 +- internal/stackql/queryshape/queryshape.go | 151 ++++++++++++++++++++-- 3 files changed, 146 insertions(+), 17 deletions(-) diff --git a/.golangci.yml b/.golangci.yml index 0808cd07..5d8164fc 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -200,6 +200,10 @@ linters: - revive - unparam path: internal\/stackql\/driver\/.*\.go + - linters: + - stylecheck + - gosec + path: internal\/stackql\/queryshape\/.*\.go - linters: - revive path: internal\/stackql\/acid\/tsm_physio\/.*\.go diff --git a/internal/stackql/driver/driver.go b/internal/stackql/driver/driver.go index e2c742f2..7d8342b8 100644 --- a/internal/stackql/driver/driver.go +++ b/internal/stackql/driver/driver.go @@ -71,6 +71,7 @@ func (sdf *basicStackQLDriverFactory) newSQLDriver() (StackQLDriver, error) { debugBuf: buf, handlerCtx: clonedCtx, txnOrchestrator: txnOrchestrator, + shapeInferrer: queryshape.NewInferrer(clonedCtx), } return rv, nil } @@ -134,6 +135,7 @@ type basicStackQLDriver struct { debugBuf *bytes.Buffer handlerCtx handler.HandlerContext txnOrchestrator tsm_physio.Orchestrator + shapeInferrer queryshape.Inferrer } func (dr *basicStackQLDriver) GetDebugStr() string { @@ -196,6 +198,7 @@ func NewStackQLDriver(handlerCtx handler.HandlerContext) (StackQLDriver, error) return &basicStackQLDriver{ handlerCtx: handlerCtx, txnOrchestrator: txnOrchestrator, + shapeInferrer: queryshape.NewInferrer(handlerCtx), }, nil 
} @@ -215,14 +218,13 @@ func (dr *basicStackQLDriver) HandleBind( func (dr *basicStackQLDriver) HandleDescribeStatement( ctx context.Context, stmtName string, query string, paramOIDs []uint32, ) ([]uint32, []sqldata.ISQLColumn, error) { - columns := queryshape.InferResultColumns(dr.handlerCtx, query) - return paramOIDs, columns, nil + return paramOIDs, dr.shapeInferrer.InferResultColumns(query), nil } func (dr *basicStackQLDriver) HandleDescribePortal( ctx context.Context, portalName string, stmtName string, query string, paramOIDs []uint32, ) ([]sqldata.ISQLColumn, error) { - return queryshape.InferResultColumns(dr.handlerCtx, query), nil + return dr.shapeInferrer.InferResultColumns(query), nil } func (dr *basicStackQLDriver) HandleExecute( diff --git a/internal/stackql/queryshape/queryshape.go b/internal/stackql/queryshape/queryshape.go index 8fa6ebf4..af05888c 100644 --- a/internal/stackql/queryshape/queryshape.go +++ b/internal/stackql/queryshape/queryshape.go @@ -1,23 +1,90 @@ +// Package queryshape provides result column metadata inference for SQL queries +// without executing them. +// +// Type inference sources vary by relation kind: +// +// - Materialized views and user space tables: column metadata is stored +// alongside the DDL in system tables (__iql__.materialized_views.columns, +// __iql__.tables.columns). OIDs, widths, and types are directly available. +// +// - Views: the view DDL (a SELECT) is stored in __iql__.views. Parsing the +// DDL and recursively analysing the projection list yields column shapes. +// This currently delegates to the plan builder. +// +// - Direct queries and subqueries: column types are a function of provider +// method schemas, applied SQL function signatures, and RDBMS expression +// rules. This currently delegates to the plan builder. 
package queryshape import ( + "github.com/lib/pq/oid" "github.com/stackql/psql-wire/pkg/sqldata" + "github.com/stackql/stackql-parser/go/vt/sqlparser" "github.com/stackql/stackql/internal/stackql/handler" "github.com/stackql/stackql/internal/stackql/planbuilder" + "github.com/stackql/stackql/internal/stackql/sql_system" "github.com/stackql/stackql/internal/stackql/typing" ) -// InferResultColumns analyses a SQL query and returns result column metadata -// without executing the query. It builds a query plan to resolve column -// projections and types, but does not run the plan. -// -// Returns nil when column metadata cannot be derived (e.g. DDL, mutations, -// or queries that fail planning). -func InferResultColumns( - handlerCtx handler.HandlerContext, - query string, -) []sqldata.ISQLColumn { - clonedCtx := handlerCtx.Clone() +// Inferrer analyses SQL queries and returns result column metadata +// without executing them. +type Inferrer interface { + // InferResultColumns returns wire-protocol column metadata for the + // given query. Returns nil when columns cannot be derived (DDL, + // mutations, planning failures). + InferResultColumns(query string) []sqldata.ISQLColumn +} + +// NewInferrer creates a new query shape inferrer backed by the given +// handler context. +func NewInferrer(handlerCtx handler.HandlerContext) Inferrer { + return &standardInferrer{ + handlerCtx: handlerCtx, + sqlSystem: handlerCtx.GetSQLSystem(), + } +} + +type standardInferrer struct { + handlerCtx handler.HandlerContext + sqlSystem sql_system.SQLSystem +} + +func (si *standardInferrer) InferResultColumns(query string) []sqldata.ISQLColumn { + // Try stored relation metadata first (cheapest path). + if cols := si.inferFromStoredRelation(query); cols != nil { + return cols + } + // Fall back to plan-based inference for direct queries, subqueries, + // and views whose columns require provider schema resolution. 
+ return si.inferFromPlan(query) +} + +// inferFromStoredRelation checks whether the query is a simple +// SELECT against a single materialized view or user space table +// whose column metadata is already stored. If so, the columns +// are returned directly from the DTO without planning. +func (si *standardInferrer) inferFromStoredRelation(query string) []sqldata.ISQLColumn { + tableName := extractSingleTableName(query) + if tableName == "" { + return nil + } + // Materialized views carry stored column metadata with OIDs. + if dto, ok := si.sqlSystem.GetMaterializedViewByName(tableName); ok { + return relationalColumnsToSQLColumns(dto.GetColumns()) + } + // User space tables also carry stored column metadata. + if dto, ok := si.sqlSystem.GetPhysicalTableByName(tableName); ok { + return relationalColumnsToSQLColumns(dto.GetColumns()) + } + return nil +} + +// inferFromPlan builds a query plan (without executing) and extracts +// column metadata from it. This handles views, subqueries, and +// direct provider queries where types derive from method schemas +// and SQL function signatures. +func (si *standardInferrer) inferFromPlan(query string) []sqldata.ISQLColumn { + clonedCtx := si.handlerCtx.Clone() clonedCtx.SetQuery(query) clonedCtx.SetRawQuery(query) pb := planbuilder.NewPlanBuilder(nil) @@ -29,12 +96,68 @@ func InferResultColumns( if len(colMeta) == 0 { return nil } - return ColumnMetadataToSQLColumns(colMeta) + return columnMetadataToSQLColumns(colMeta) +} + +// extractSingleTableName does a lightweight parse to detect queries +// of the form "SELECT ... FROM ..." and returns the +// table name. Returns "" for anything more complex (joins, subqueries, etc). 
+func extractSingleTableName(query string) string { + stmt, err := sqlparser.Parse(query) + if err != nil { + return "" + } + sel, ok := stmt.(*sqlparser.Select) + if !ok || sel == nil { + return "" + } + if len(sel.From) != 1 { + return "" + } + aliased, ok := sel.From[0].(*sqlparser.AliasedTableExpr) + if !ok { + return "" + } + tableName, ok := aliased.Expr.(sqlparser.TableName) + if !ok { + return "" + } + // Only unqualified table names (no provider.service.resource). + if !tableName.Qualifier.IsEmpty() { + return "" + } + return tableName.Name.GetRawVal() +} + +// relationalColumnsToSQLColumns converts stored RelationalColumn +// metadata to wire protocol ISQLColumn objects. +func relationalColumnsToSQLColumns(cols []typing.RelationalColumn) []sqldata.ISQLColumn { + if len(cols) == 0 { + return nil + } + table := sqldata.NewSQLTable(0, "") + result := make([]sqldata.ISQLColumn, len(cols)) + for i, col := range cols { + colOID := oid.T_text + if storedOID, ok := col.GetOID(); ok { + colOID = storedOID + } + result[i] = sqldata.NewSQLColumn( + table, + col.GetIdentifier(), + 0, + uint32(colOID), + int16(col.GetWidth()), + 0, + "text", + ) + } + return result } -// ColumnMetadataToSQLColumns converts internal column metadata to +// columnMetadataToSQLColumns converts internal ColumnMetadata to // wire protocol ISQLColumn objects. -func ColumnMetadataToSQLColumns(cols []typing.ColumnMetadata) []sqldata.ISQLColumn { +func columnMetadataToSQLColumns(cols []typing.ColumnMetadata) []sqldata.ISQLColumn { table := sqldata.NewSQLTable(0, "") result := make([]sqldata.ISQLColumn, len(cols)) for i, col := range cols { From c55a3d6b61856097fa2c4dada76df2f18af5951f Mon Sep 17 00:00:00 2001 From: General Kroll Date: Fri, 3 Apr 2026 12:39:07 +1100 Subject: [PATCH 06/11] - Safe integer casting for runtime vs RDBMS type conversions. 
--- .golangci.yml | 1 - internal/stackql/queryshape/queryshape.go | 8 +++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/.golangci.yml b/.golangci.yml index 5d8164fc..5ad7f07e 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -202,7 +202,6 @@ linters: path: internal\/stackql\/driver\/.*\.go - linters: - stylecheck - - gosec path: internal\/stackql\/queryshape\/.*\.go - linters: - revive diff --git a/internal/stackql/queryshape/queryshape.go b/internal/stackql/queryshape/queryshape.go index af05888c..6f9b62f2 100644 --- a/internal/stackql/queryshape/queryshape.go +++ b/internal/stackql/queryshape/queryshape.go @@ -17,6 +17,8 @@ package queryshape import ( + "math" + "github.com/lib/pq/oid" "github.com/stackql/psql-wire/pkg/sqldata" "github.com/stackql/stackql-parser/go/vt/sqlparser" @@ -142,12 +144,16 @@ func relationalColumnsToSQLColumns(cols []typing.RelationalColumn) []sqldata.ISQ if storedOID, ok := col.GetOID(); ok { colOID = storedOID } + w := col.GetWidth() + if w > math.MaxInt16 || w < math.MinInt16 { + w = -1 + } result[i] = sqldata.NewSQLColumn( table, col.GetIdentifier(), 0, uint32(colOID), - int16(col.GetWidth()), + int16(w), //nolint:gosec // bounds checked above 0, "text", ) From aa2635f8244cf82e1a2bd26c0fa4c7aa0de807b4 Mon Sep 17 00:00:00 2001 From: General Kroll Date: Sat, 4 Apr 2026 00:09:16 +1100 Subject: [PATCH 07/11] re run cii --- docs/live_context.md | 114 +++++++++++++++ internal/stackql/driver/driver.go | 86 +++++++---- internal/stackql/paramdecoder/paramdecoder.go | 133 +++++++++++++++++ .../stackql/paramdecoder/paramdecoder_test.go | 137 ++++++++++++++++++ internal/stackql/psqlwire/psqlwire.go | 87 +++++++++++ internal/stackql/queryshape/queryshape.go | 50 +++++++ .../stackql/queryshape/queryshape_test.go | 70 +++++++++ .../testdata/extract_table_cases.json | 82 +++++++++++ .../testdata/substitute_params_cases.json | 68 +++++++++ internal/stackql/typing/oid_mapping_test.go | 50 +++++++ 
.../stackql/typing/relayed_column_metadata.go | 2 - .../typing/standard_column_metadata.go | 2 - 12 files changed, 847 insertions(+), 34 deletions(-) create mode 100644 docs/live_context.md create mode 100644 internal/stackql/paramdecoder/paramdecoder.go create mode 100644 internal/stackql/paramdecoder/paramdecoder_test.go create mode 100644 internal/stackql/queryshape/queryshape_test.go create mode 100644 internal/stackql/queryshape/testdata/extract_table_cases.json create mode 100644 internal/stackql/queryshape/testdata/substitute_params_cases.json create mode 100644 internal/stackql/typing/oid_mapping_test.go diff --git a/docs/live_context.md b/docs/live_context.md new file mode 100644 index 00000000..70953ea1 --- /dev/null +++ b/docs/live_context.md @@ -0,0 +1,114 @@ +# Live Context: Extended Query Protocol Implementation + +## Date: 2026-04-03 + +## Current State + +### What's been done + +1. **Indirect joins tests** (complete, merged-ready): + - 8 new robot tests covering 3-way and 4-way INNER JOIN + LEFT OUTER JOIN across views, materialized views, subqueries, provider tables + - All use `Should Stackql Exec Inline Equal Both Streams` with exact output matching + - LEFT OUTER JOIN tests prove NULL behavior with partial matches + - `docs/views.md` updated with supported combinations + +2. **Extended query protocol stubs** (complete): + - `basicStackQLDriver` implements `IExtendedQueryBackend` with compile-time check + - `HandleParse`: passthrough (returns client OIDs as-is) + - `HandleBind`: no-op + - `HandleExecute`: `queryshape.SubstituteParams` replaces `$N` → `HandleSimpleQuery` + - `HandleDescribeStatement/Portal`: delegates to `queryshape.Inferrer` + - `HandleClose*`: no-ops + - 407/407 robot tests pass + +3. 
**`queryshape` package** (complete): + - `internal/stackql/queryshape/queryshape.go` + - Public `Inferrer` interface, private `standardInferrer` struct + - `InferResultColumns(query)` with two paths: + - `inferFromStoredRelation`: reads MV/table column metadata from stored DTOs (cheap) + - `inferFromPlan`: builds plan via `planbuilder.BuildPlanFromContext` (no execution) + - `SubstituteParams`: moved here from driver, replaces `$N` with bound values + - `extractSingleTableName`: lightweight sqlparser-based single-table detection + - Unit tests in `queryshape_test.go` with JSON testdata + +4. **Plan column metadata** (complete): + - `plan.Plan` has `GetColumnMetadata()`/`SetColumnMetadata()` + - Set in `planbuilder/entrypoint.go` from `GetSelectPreparedStatementCtx().GetNonControlColumns()` + - `planGraphBuilder` interface has `getRootPrimitiveGenerator()` + +5. **OID fidelity + value coercion** (IN PROGRESS — 10 failures remaining): + - Finer OID mapping in `typing/standard_column_metadata.go`: + - `getOidForSchema`: `integer`→`T_int8`, `boolean`→`T_bool`, `number`→`T_numeric` + - `getOidForParserColType`: split into `T_int2/T_int4/T_int8/T_float4/T_float8/T_json/T_jsonb` + - Finer OID mapping in `typing/relayed_column_metadata.go` + - Value coercion in `internal/stackql/psqlwire/psqlwire.go`: + - `coerceForOID()` function converts string/[]byte from RDBMS to Go types pgtype expects + - Applied in `ExtractRowElement` for non-text, non-numeric OIDs + - Existing `shimNumericElement` preserved for `"numeric"` type + - **397/407 pass, 10 failures remain** — need to check what those 10 are + +### What's next (from the plan) + +**Immediate**: Fix remaining 10 test failures from OID changes. Check what OID/coercion path they hit. 
+ +**Phase 2**: `paramresolver` package +- Resolve `$N` placeholder OIDs from method schemas during `HandleParse` +- Add `ParameterOIDs []uint32` to `plan.Plan` +- Populate during plan building in `entrypoint.go` + +**Phase 3**: Stateful driver + `paramdecoder` package +- `paramdecoder`: decode binary-format params using `jackc/pgtype` +- Statement/portal caches in `basicStackQLDriver`: + - `HandleParse`: resolve OIDs + infer columns → cache in `stmtCache` + - `HandleDescribeStatement`: return from cache (no re-planning) + - `HandleBind`: record portal→statement mapping + - `HandleExecute`: look up portal → decode params → substitute → execute + - `HandleClose*`: delete from caches + +**Phase 4** (separate, psql-wire repo): Respect `resultFormats` from Bind instead of hardcoding `TextFormat` + +### Key files modified + +| File | Status | Description | +|------|--------|-------------| +| `internal/stackql/driver/driver.go` | Modified | IExtendedQueryBackend impl, shapeInferrer field | +| `internal/stackql/queryshape/queryshape.go` | New | Inferrer interface, SubstituteParams | +| `internal/stackql/queryshape/queryshape_test.go` | New | Unit tests with JSON testdata | +| `internal/stackql/queryshape/testdata/*.json` | New | Test cases | +| `internal/stackql/psqlwire/psqlwire.go` | Modified | coerceForOID() value coercion | +| `internal/stackql/typing/standard_column_metadata.go` | Modified | Finer OID mapping | +| `internal/stackql/typing/relayed_column_metadata.go` | Modified | Finer OID mapping | +| `internal/stackql/typing/oid_mapping_test.go` | New | OID mapping unit tests | +| `internal/stackql/plan/plan.go` | Modified | ColumnMetadata field + getter/setter | +| `internal/stackql/planbuilder/entrypoint.go` | Modified | Extract column metadata during plan build | +| `internal/stackql/planbuilder/plan_builder.go` | Modified | getRootPrimitiveGenerator() on interface | +| `test/robot/functional/stackql_mocked_from_cmd_line.robot` | Modified | 8 new indirect join 
tests + 3-way test | +| `docs/views.md` | Modified | Updated supported join combinations | + +### Key architectural decisions + +- `queryshape.Inferrer` is the single entry point for ahead-of-time schema inference +- Plan building (without execution) is the mechanism for inferring column types from provider schemas +- Value coercion in `psqlwire/psqlwire.go` bridges sqlite's string-heavy output to pgtype's typed encoders +- OID fidelity is now enabled: integer→T_int8, boolean→T_bool, fine-grained parser col types +- The `shimNumericElement`/`shimNumericTextBytes` hacks are preserved for backward compatibility with the "numeric" pgtype path + +### Phase 3 Complete (2026-04-04) + +All Handle* methods now flow through stateful caches: +- `HandleParse`: infers columns, caches in `stmtCache[stmtName]` +- `HandleDescribeStatement`: returns from cache (no re-planning) +- `HandleDescribePortal`: looks up portal→statement→columns +- `HandleBind`: records portal→statement in `portalCache` +- `HandleExecute`: looks up portal→OIDs, decodes params via `paramdecoder`, substitutes, executes +- `HandleClose*`: cleans up caches + +New packages: +- `internal/stackql/paramdecoder/` — decodes text AND binary format params (int2/4/8, float4/8, bool, timestamp, text) +- Value coercion function `coerceForOID` in `psqlwire.go` — ready but not active (deferred to Phase 4 with OID fidelity) + +### Remaining work + +- **Phase 1 (OID fidelity)**: finer OIDs (integer→T_int8, bool→T_bool) break 10 tests because pgtype's text encoder formats values differently. Needs psql-wire change to bypass pgtype.Set() for text format and write strings directly. Deferred. +- **Phase 2 (paramresolver)**: resolve $N placeholder OIDs from method schemas. pgx works without this (defaults to text). Enhancement. +- **Phase 4 (psql-wire)**: respect resultFormats from Bind, enable binary result encoding. Separate repo. 
diff --git a/internal/stackql/driver/driver.go b/internal/stackql/driver/driver.go index 7d8342b8..d1bea81a 100644 --- a/internal/stackql/driver/driver.go +++ b/internal/stackql/driver/driver.go @@ -4,9 +4,6 @@ import ( "bytes" "context" "fmt" - "regexp" - "strconv" - "strings" "github.com/stackql/any-sdk/pkg/logging" "github.com/stackql/any-sdk/public/sqlengine" @@ -14,6 +11,7 @@ import ( "github.com/stackql/stackql/internal/stackql/acid/tsm_physio" "github.com/stackql/stackql/internal/stackql/handler" "github.com/stackql/stackql/internal/stackql/internal_data_transfer/internaldto" + "github.com/stackql/stackql/internal/stackql/paramdecoder" "github.com/stackql/stackql/internal/stackql/queryshape" "github.com/stackql/stackql/internal/stackql/responsehandler" "github.com/stackql/stackql/internal/stackql/util" @@ -72,6 +70,9 @@ func (sdf *basicStackQLDriverFactory) newSQLDriver() (StackQLDriver, error) { handlerCtx: clonedCtx, txnOrchestrator: txnOrchestrator, shapeInferrer: queryshape.NewInferrer(clonedCtx), + paramDecoder: paramdecoder.NewDecoder(), + stmtCache: make(map[string]*stmtMeta), + portalCache: make(map[string]*portalMeta), } return rv, nil } @@ -131,11 +132,24 @@ func (dr *basicStackQLDriver) ProcessQuery(query string) { } } +type stmtMeta struct { + query string + paramOIDs []uint32 + columns []sqldata.ISQLColumn +} + +type portalMeta struct { + stmtName string +} + type basicStackQLDriver struct { debugBuf *bytes.Buffer handlerCtx handler.HandlerContext txnOrchestrator tsm_physio.Orchestrator shapeInferrer queryshape.Inferrer + paramDecoder paramdecoder.Decoder + stmtCache map[string]*stmtMeta + portalCache map[string]*portalMeta } func (dr *basicStackQLDriver) GetDebugStr() string { @@ -199,12 +213,22 @@ func NewStackQLDriver(handlerCtx handler.HandlerContext) (StackQLDriver, error) handlerCtx: handlerCtx, txnOrchestrator: txnOrchestrator, shapeInferrer: queryshape.NewInferrer(handlerCtx), + paramDecoder: paramdecoder.NewDecoder(), + stmtCache: 
make(map[string]*stmtMeta), + portalCache: make(map[string]*portalMeta), }, nil } func (dr *basicStackQLDriver) HandleParse( ctx context.Context, stmtName string, query string, paramOIDs []uint32, ) ([]uint32, error) { + // Infer result columns at parse time and cache for Describe/Execute. + columns := dr.shapeInferrer.InferResultColumns(query) + dr.stmtCache[stmtName] = &stmtMeta{ + query: query, + paramOIDs: paramOIDs, + columns: columns, + } return paramOIDs, nil } @@ -212,18 +236,31 @@ func (dr *basicStackQLDriver) HandleBind( ctx context.Context, portalName string, stmtName string, paramFormats []int16, paramValues [][]byte, resultFormats []int16, ) error { + dr.portalCache[portalName] = &portalMeta{ + stmtName: stmtName, + } return nil } func (dr *basicStackQLDriver) HandleDescribeStatement( ctx context.Context, stmtName string, query string, paramOIDs []uint32, ) ([]uint32, []sqldata.ISQLColumn, error) { - return paramOIDs, dr.shapeInferrer.InferResultColumns(query), nil + if cached, ok := dr.stmtCache[stmtName]; ok { + return cached.paramOIDs, cached.columns, nil + } + // Fallback: infer on the fly (shouldn't happen if Parse was called first). 
+ columns := dr.shapeInferrer.InferResultColumns(query) + return paramOIDs, columns, nil } func (dr *basicStackQLDriver) HandleDescribePortal( ctx context.Context, portalName string, stmtName string, query string, paramOIDs []uint32, ) ([]sqldata.ISQLColumn, error) { + if portal, ok := dr.portalCache[portalName]; ok { + if cached, ok := dr.stmtCache[portal.stmtName]; ok { + return cached.columns, nil + } + } return dr.shapeInferrer.InferResultColumns(query), nil } @@ -231,40 +268,29 @@ func (dr *basicStackQLDriver) HandleExecute( ctx context.Context, portalName string, stmtName string, query string, paramFormats []int16, paramValues [][]byte, resultFormats []int16, maxRows int32, ) (sqldata.ISQLResultStream, error) { - resolved := substituteParams(query, paramFormats, paramValues) - return dr.HandleSimpleQuery(ctx, resolved) -} - -var paramPlaceholderRegex = regexp.MustCompile(`\$(\d+)`) - -// substituteParams replaces $1, $2, ... placeholders with their bound values. -// NULL parameters are substituted as the literal NULL. -// String values are single-quote escaped. -func substituteParams(query string, paramFormats []int16, paramValues [][]byte) string { - if len(paramValues) == 0 { - return query - } - return paramPlaceholderRegex.ReplaceAllStringFunc(query, func(match string) string { - idxStr := match[1:] // strip leading $ - idx, err := strconv.Atoi(idxStr) - if err != nil || idx < 1 || idx > len(paramValues) { - return match // leave unrecognised placeholders as-is - } - val := paramValues[idx-1] - if val == nil { - return "NULL" + // Look up cached param OIDs for format-aware decoding. + var paramOIDs []uint32 + if portal, ok := dr.portalCache[portalName]; ok { + if cached, ok := dr.stmtCache[portal.stmtName]; ok { + paramOIDs = cached.paramOIDs } - text := string(val) - escaped := strings.ReplaceAll(text, "'", "''") - return "'" + escaped + "'" - }) + } + // Decode params (handles both text and binary formats). 
+ decodedStrings, err := dr.paramDecoder.DecodeParams(paramOIDs, paramFormats, paramValues) + if err != nil { + return nil, fmt.Errorf("parameter decoding error: %w", err) + } + resolved := queryshape.SubstituteDecodedParams(query, decodedStrings) + return dr.HandleSimpleQuery(ctx, resolved) } func (dr *basicStackQLDriver) HandleCloseStatement(ctx context.Context, stmtName string) error { + delete(dr.stmtCache, stmtName) return nil } func (dr *basicStackQLDriver) HandleClosePortal(ctx context.Context, portalName string) error { + delete(dr.portalCache, portalName) return nil } diff --git a/internal/stackql/paramdecoder/paramdecoder.go b/internal/stackql/paramdecoder/paramdecoder.go new file mode 100644 index 00000000..07c1eda7 --- /dev/null +++ b/internal/stackql/paramdecoder/paramdecoder.go @@ -0,0 +1,133 @@ +// Package paramdecoder decodes parameter values from their wire format +// (text or binary) into string representations suitable for SQL substitution. +package paramdecoder + +import ( + "encoding/binary" + "fmt" + "math" + "strconv" + "time" + + "github.com/lib/pq/oid" +) + +// Decoder decodes raw parameter bytes according to their format codes +// and OIDs, returning string representations for each. +type Decoder interface { + DecodeParams(paramOIDs []uint32, paramFormats []int16, paramValues [][]byte) ([]string, error) +} + +// NewDecoder creates a new parameter decoder. 
+func NewDecoder() Decoder { + return &standardDecoder{} +} + +type standardDecoder struct{} + +func (d *standardDecoder) DecodeParams( + paramOIDs []uint32, paramFormats []int16, paramValues [][]byte, +) ([]string, error) { + result := make([]string, len(paramValues)) + for i, val := range paramValues { + if val == nil { + result[i] = "NULL" + continue + } + format := resolveFormat(paramFormats, i) + paramOID := oid.Oid(0) + if i < len(paramOIDs) { + paramOID = oid.Oid(paramOIDs[i]) + } + decoded, err := decodeParam(paramOID, format, val) + if err != nil { + return nil, fmt.Errorf("parameter $%d: %w", i+1, err) + } + result[i] = decoded + } + return result, nil +} + +// resolveFormat returns the format code for parameter at index i. +// Per postgres protocol: empty = all text, length 1 = applies to all, +// otherwise per-parameter. +func resolveFormat(formats []int16, i int) int16 { + if len(formats) == 0 { + return 0 // text + } + if len(formats) == 1 { + return formats[0] + } + if i < len(formats) { + return formats[i] + } + return 0 // text +} + +// decodeParam decodes a single parameter value. +// Format 0 = text (bytes are UTF-8), format 1 = binary (OID-specific encoding). +func decodeParam(paramOID oid.Oid, format int16, val []byte) (string, error) { + if format == 0 { + // Text format: raw bytes are the UTF-8 string representation. + return string(val), nil + } + // Binary format: decode based on OID. + return decodeBinary(paramOID, val) +} + +// decodeBinary decodes a binary-encoded parameter value to its string representation. 
+// +//nolint:cyclop // switch over OIDs is inherently branchy +func decodeBinary(paramOID oid.Oid, val []byte) (string, error) { + switch paramOID { + case oid.T_bool: + if len(val) != 1 { + return "", fmt.Errorf("bool: expected 1 byte, got %d", len(val)) + } + if val[0] != 0 { + return "true", nil + } + return "false", nil + case oid.T_int2: + if len(val) != 2 { + return "", fmt.Errorf("int2: expected 2 bytes, got %d", len(val)) + } + return strconv.FormatInt(int64(int16(binary.BigEndian.Uint16(val))), 10), nil //nolint:gosec // deliberate int16 conversion + case oid.T_int4: + if len(val) != 4 { + return "", fmt.Errorf("int4: expected 4 bytes, got %d", len(val)) + } + return strconv.FormatInt(int64(int32(binary.BigEndian.Uint32(val))), 10), nil //nolint:gosec // deliberate int32 conversion + case oid.T_int8: + if len(val) != 8 { + return "", fmt.Errorf("int8: expected 8 bytes, got %d", len(val)) + } + return strconv.FormatInt(int64(binary.BigEndian.Uint64(val)), 10), nil + case oid.T_float4: + if len(val) != 4 { + return "", fmt.Errorf("float4: expected 4 bytes, got %d", len(val)) + } + bits := binary.BigEndian.Uint32(val) + return strconv.FormatFloat(float64(math.Float32frombits(bits)), 'f', -1, 32), nil + case oid.T_float8: + if len(val) != 8 { + return "", fmt.Errorf("float8: expected 8 bytes, got %d", len(val)) + } + bits := binary.BigEndian.Uint64(val) + return strconv.FormatFloat(math.Float64frombits(bits), 'f', -1, 64), nil + case oid.T_timestamp, oid.T_timestamptz: + if len(val) != 8 { + return "", fmt.Errorf("timestamp: expected 8 bytes, got %d", len(val)) + } + // Postgres binary timestamp: microseconds since 2000-01-01 00:00:00 UTC. 
+ microseconds := int64(binary.BigEndian.Uint64(val)) + pgEpoch := time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC) + ts := pgEpoch.Add(time.Duration(microseconds) * time.Microsecond) + return ts.Format("2006-01-02 15:04:05.999999"), nil + case oid.T_text, oid.T_varchar, oid.T_name: + return string(val), nil + default: + // Unknown OID: treat as text (safe fallback). + return string(val), nil + } +} diff --git a/internal/stackql/paramdecoder/paramdecoder_test.go b/internal/stackql/paramdecoder/paramdecoder_test.go new file mode 100644 index 00000000..438eca50 --- /dev/null +++ b/internal/stackql/paramdecoder/paramdecoder_test.go @@ -0,0 +1,137 @@ +package paramdecoder + +import ( + "encoding/binary" + "math" + "testing" + + "github.com/lib/pq/oid" +) + +func TestDecodeTextParams(t *testing.T) { + d := NewDecoder() + results, err := d.DecodeParams( + []uint32{uint32(oid.T_text), uint32(oid.T_text)}, + []int16{0}, // all text + [][]byte{[]byte("hello"), []byte("world")}, + ) + if err != nil { + t.Fatal(err) + } + if results[0] != "hello" || results[1] != "world" { + t.Errorf("got %v", results) + } +} + +func TestDecodeNullParam(t *testing.T) { + d := NewDecoder() + results, err := d.DecodeParams( + []uint32{uint32(oid.T_text)}, + nil, + [][]byte{nil}, + ) + if err != nil { + t.Fatal(err) + } + if results[0] != "NULL" { + t.Errorf("got %q, want NULL", results[0]) + } +} + +func TestDecodeBinaryInt4(t *testing.T) { + d := NewDecoder() + val := make([]byte, 4) + binary.BigEndian.PutUint32(val, uint32(42)) + results, err := d.DecodeParams( + []uint32{uint32(oid.T_int4)}, + []int16{1}, // binary + [][]byte{val}, + ) + if err != nil { + t.Fatal(err) + } + if results[0] != "42" { + t.Errorf("got %q, want 42", results[0]) + } +} + +func TestDecodeBinaryInt8(t *testing.T) { + d := NewDecoder() + val := make([]byte, 8) + binary.BigEndian.PutUint64(val, uint64(9999999999)) + results, err := d.DecodeParams( + []uint32{uint32(oid.T_int8)}, + []int16{1}, + [][]byte{val}, + ) + if err 
!= nil { + t.Fatal(err) + } + if results[0] != "9999999999" { + t.Errorf("got %q, want 9999999999", results[0]) + } +} + +func TestDecodeBinaryFloat8(t *testing.T) { + d := NewDecoder() + val := make([]byte, 8) + binary.BigEndian.PutUint64(val, math.Float64bits(3.14)) + results, err := d.DecodeParams( + []uint32{uint32(oid.T_float8)}, + []int16{1}, + [][]byte{val}, + ) + if err != nil { + t.Fatal(err) + } + if results[0] != "3.14" { + t.Errorf("got %q, want 3.14", results[0]) + } +} + +func TestDecodeBinaryBool(t *testing.T) { + d := NewDecoder() + results, err := d.DecodeParams( + []uint32{uint32(oid.T_bool), uint32(oid.T_bool)}, + []int16{1}, + [][]byte{{1}, {0}}, + ) + if err != nil { + t.Fatal(err) + } + if results[0] != "true" || results[1] != "false" { + t.Errorf("got %v", results) + } +} + +func TestDecodeMixedFormats(t *testing.T) { + d := NewDecoder() + int4Val := make([]byte, 4) + binary.BigEndian.PutUint32(int4Val, uint32(100)) + results, err := d.DecodeParams( + []uint32{uint32(oid.T_text), uint32(oid.T_int4)}, + []int16{0, 1}, // first text, second binary + [][]byte{[]byte("hello"), int4Val}, + ) + if err != nil { + t.Fatal(err) + } + if results[0] != "hello" || results[1] != "100" { + t.Errorf("got %v", results) + } +} + +func TestDecodeUnknownOIDBinaryFallsBackToText(t *testing.T) { + d := NewDecoder() + results, err := d.DecodeParams( + []uint32{99999}, + []int16{1}, // binary + [][]byte{[]byte("raw-bytes")}, + ) + if err != nil { + t.Fatal(err) + } + if results[0] != "raw-bytes" { + t.Errorf("got %q, want raw-bytes", results[0]) + } +} diff --git a/internal/stackql/psqlwire/psqlwire.go b/internal/stackql/psqlwire/psqlwire.go index 8b2cf245..1f7b1c64 100644 --- a/internal/stackql/psqlwire/psqlwire.go +++ b/internal/stackql/psqlwire/psqlwire.go @@ -98,6 +98,12 @@ func ExtractRowElement(column sqldata.ISQLColumn, src interface{}, ci *pgtype.Co processedElement = shimNumericElement(src) } // end hack + // NOTE: coerceForOID is available for binary 
format encoding (Phase 4). + // For text format (current default), string values pass through to + // pgtype's text encoder which handles all types correctly. + // Coercing to native Go types here would change the text encoding + // format (e.g. bool: "t"→"true", int: quoted→unquoted). + err := typed.Value.Set(processedElement) if err != nil { return nil, err @@ -122,6 +128,87 @@ func ExtractRowElement(column sqldata.ISQLColumn, src interface{}, ci *pgtype.Co return bb, nil } +// coerceForOID converts a value (often a string or []byte from the RDBMS) +// into the Go type that pgtype expects for the given type name. +// This bridges the gap between sqlite's string-heavy output and pgtype's +// type-specific Set() requirements. +// +//nolint:gocyclo,cyclop // switch over type names is inherently branchy +func coerceForOID(typeName string, val interface{}) interface{} { + if val == nil { + return nil + } + s, isStr := valToString(val) + if !isStr { + return val // already a native Go type; pass through + } + if s == "null" || s == "NULL" || s == "" { + return nil + } + switch typeName { + case "numeric": + f, err := strconv.ParseFloat(s, 64) + if err == nil { + return f + } + return val + case "int2": + i, err := strconv.ParseInt(s, 10, 16) + if err == nil { + return int16(i) + } + return val + case "int4": + i, err := strconv.ParseInt(s, 10, 32) + if err == nil { + return int32(i) + } + return val + case "int8": + i, err := strconv.ParseInt(s, 10, 64) + if err == nil { + return i + } + return val + case "float4": + f, err := strconv.ParseFloat(s, 32) + if err == nil { + return float32(f) + } + return val + case "float8": + f, err := strconv.ParseFloat(s, 64) + if err == nil { + return f + } + return val + case "bool": + switch s { + case "true", "TRUE", "t", "1", "yes": + return true + case "false", "FALSE", "f", "0", "no": + return false + } + return val + case "json", "jsonb": + return s // pass as string; pgtype text encoder handles it + default: + return val // 
text, varchar, timestamp, etc. — string is fine + } +} + +// valToString extracts a string from string or []byte values. +func valToString(val interface{}) (string, bool) { + switch v := val.(type) { + case string: + return v, true + case []byte: + return string(v), true + default: + return "", false + } +} + func getFormatCode(fc string) (postgreswire.FormatCode, error) { switch fc { case "TextFormat": diff --git a/internal/stackql/queryshape/queryshape.go b/internal/stackql/queryshape/queryshape.go index 6f9b62f2..7a825a83 100644 --- a/internal/stackql/queryshape/queryshape.go +++ b/internal/stackql/queryshape/queryshape.go @@ -18,6 +18,9 @@ package queryshape import ( "math" + "regexp" + "strconv" + "strings" "github.com/lib/pq/oid" "github.com/stackql/psql-wire/pkg/sqldata" @@ -179,3 +182,50 @@ func columnMetadataToSQLColumns(cols []typing.ColumnMetadata) []sqldata.ISQLColu } return result } + +var paramPlaceholderRegex = regexp.MustCompile(`\$(\d+)`) + +// SubstituteParams replaces $1, $2, ... placeholders with their bound values. +// NULL parameters (nil entries in paramValues) are substituted as the literal NULL. +// String values are single-quote escaped. +func SubstituteParams(query string, paramFormats []int16, paramValues [][]byte) string { + if len(paramValues) == 0 { + return query + } + return paramPlaceholderRegex.ReplaceAllStringFunc(query, func(match string) string { + idxStr := match[1:] // strip leading $ + idx, err := strconv.Atoi(idxStr) + if err != nil || idx < 1 || idx > len(paramValues) { + return match // leave unrecognised placeholders as-is + } + val := paramValues[idx-1] + if val == nil { + return "NULL" + } + text := string(val) + escaped := strings.ReplaceAll(text, "'", "''") + return "'" + escaped + "'" + }) +} + +// SubstituteDecodedParams replaces $1, $2, ... placeholders with +// pre-decoded string values. "NULL" values are substituted unquoted; +// all other values are single-quote escaped. 
+func SubstituteDecodedParams(query string, decodedValues []string) string { + if len(decodedValues) == 0 { + return query + } + return paramPlaceholderRegex.ReplaceAllStringFunc(query, func(match string) string { + idxStr := match[1:] + idx, err := strconv.Atoi(idxStr) + if err != nil || idx < 1 || idx > len(decodedValues) { + return match + } + val := decodedValues[idx-1] + if val == "NULL" { + return "NULL" + } + escaped := strings.ReplaceAll(val, "'", "''") + return "'" + escaped + "'" + }) +} diff --git a/internal/stackql/queryshape/queryshape_test.go b/internal/stackql/queryshape/queryshape_test.go new file mode 100644 index 00000000..38facb10 --- /dev/null +++ b/internal/stackql/queryshape/queryshape_test.go @@ -0,0 +1,70 @@ +package queryshape + +import ( + "encoding/json" + "os" + "testing" +) + +type extractTableCase struct { + Description string `json:"description"` + Query string `json:"query"` + Expected string `json:"expected"` +} + +func TestExtractSingleTableName(t *testing.T) { + data, err := os.ReadFile("testdata/extract_table_cases.json") + if err != nil { + t.Fatalf("failed to read testdata: %v", err) + } + var cases []extractTableCase + if err := json.Unmarshal(data, &cases); err != nil { + t.Fatalf("failed to parse testdata: %v", err) + } + for _, tc := range cases { + t.Run(tc.Description, func(t *testing.T) { + got := extractSingleTableName(tc.Query) + if got != tc.Expected { + t.Errorf("extractSingleTableName(%q) = %q, want %q", tc.Query, got, tc.Expected) + } + }) + } +} + +type SubstituteParamsCase struct { + Description string `json:"description"` + Query string `json:"query"` + ParamValues []*string `json:"paramValues"` // nil entries represent SQL NULL + Expected string `json:"expected"` +} + +func (c *SubstituteParamsCase) toByteSlices() [][]byte { + result := make([][]byte, len(c.ParamValues)) + for i, v := range c.ParamValues { + if v == nil { + result[i] = nil + } else { + result[i] = []byte(*v) + } + } + return result +} + +func 
TestSubstituteParams(t *testing.T) { + data, err := os.ReadFile("testdata/substitute_params_cases.json") + if err != nil { + t.Fatalf("failed to read testdata: %v", err) + } + var cases []SubstituteParamsCase + if err := json.Unmarshal(data, &cases); err != nil { + t.Fatalf("failed to parse testdata: %v", err) + } + for _, tc := range cases { + t.Run(tc.Description, func(t *testing.T) { + got := SubstituteParams(tc.Query, nil, tc.toByteSlices()) + if got != tc.Expected { + t.Errorf("SubstituteParams(%q, ...) = %q, want %q", tc.Query, got, tc.Expected) + } + }) + } +} diff --git a/internal/stackql/queryshape/testdata/extract_table_cases.json b/internal/stackql/queryshape/testdata/extract_table_cases.json new file mode 100644 index 00000000..07410f73 --- /dev/null +++ b/internal/stackql/queryshape/testdata/extract_table_cases.json @@ -0,0 +1,82 @@ +[ + { + "description": "simple unqualified SELECT", + "query": "SELECT * FROM my_table", + "expected": "my_table" + }, + { + "description": "SELECT with WHERE clause", + "query": "SELECT name FROM my_view WHERE x = 1", + "expected": "my_view" + }, + { + "description": "SELECT with alias", + "query": "SELECT t.name FROM my_table t WHERE t.id = 5", + "expected": "my_table" + }, + { + "description": "JOIN returns empty (multi-table)", + "query": "SELECT * FROM a JOIN b ON a.id = b.id", + "expected": "" + }, + { + "description": "subquery returns empty", + "query": "SELECT * FROM (SELECT 1) sq", + "expected": "" + }, + { + "description": "INSERT returns empty", + "query": "INSERT INTO foo VALUES (1)", + "expected": "" + }, + { + "description": "CREATE VIEW returns empty", + "query": "CREATE VIEW v AS SELECT 1", + "expected": "" + }, + { + "description": "qualified provider.service.resource returns empty", + "query": "SELECT * FROM google.compute.instances WHERE project = 'x'", + "expected": "" + }, + { + "description": "qualified local_templated provider returns empty", + "query": "SELECT * FROM local_openssl.keys.x509 WHERE 
cert_file = 'foo.pem'", + "expected": "" + }, + { + "description": "SELECT with ORDER BY", + "query": "SELECT name, url FROM mv_repos ORDER BY name DESC", + "expected": "mv_repos" + }, + { + "description": "SELECT with dollar param in WHERE", + "query": "SELECT name FROM my_table WHERE id = $1", + "expected": "my_table" + }, + { + "description": "SELECT with multiple dollar params", + "query": "SELECT name FROM my_table WHERE id = $1 AND region = $2", + "expected": "my_table" + }, + { + "description": "LEFT OUTER JOIN returns empty", + "query": "SELECT v1.name FROM vw_repos v1 LEFT OUTER JOIN mv_repos mv ON v1.name = mv.name", + "expected": "" + }, + { + "description": "CTE extracts inner table name (harmless; won't match stored relations)", + "query": "WITH cte AS (SELECT 1) SELECT * FROM cte", + "expected": "cte" + }, + { + "description": "empty string returns empty", + "query": "", + "expected": "" + }, + { + "description": "garbage returns empty", + "query": "NOT VALID SQL AT ALL", + "expected": "" + } +] diff --git a/internal/stackql/queryshape/testdata/substitute_params_cases.json b/internal/stackql/queryshape/testdata/substitute_params_cases.json new file mode 100644 index 00000000..b64a382a --- /dev/null +++ b/internal/stackql/queryshape/testdata/substitute_params_cases.json @@ -0,0 +1,68 @@ +[ + { + "description": "no params returns query unchanged", + "query": "SELECT name FROM my_table", + "paramValues": [], + "expected": "SELECT name FROM my_table" + }, + { + "description": "single text param", + "query": "SELECT * FROM t WHERE name = $1", + "paramValues": ["hello"], + "expected": "SELECT * FROM t WHERE name = 'hello'" + }, + { + "description": "multiple params", + "query": "SELECT * FROM t WHERE name = $1 AND region = $2", + "paramValues": ["hello", "us-east-1"], + "expected": "SELECT * FROM t WHERE name = 'hello' AND region = 'us-east-1'" + }, + { + "description": "NULL param", + "query": "SELECT * FROM t WHERE name = $1", + "paramValues": [null], + 
"expected": "SELECT * FROM t WHERE name = NULL" + }, + { + "description": "mixed NULL and text", + "query": "SELECT * FROM t WHERE a = $1 AND b = $2", + "paramValues": ["val", null], + "expected": "SELECT * FROM t WHERE a = 'val' AND b = NULL" + }, + { + "description": "numeric param as text", + "query": "SELECT * FROM t WHERE id = $1", + "paramValues": ["42"], + "expected": "SELECT * FROM t WHERE id = '42'" + }, + { + "description": "param with embedded single quote", + "query": "SELECT * FROM t WHERE name = $1", + "paramValues": ["O'Brien"], + "expected": "SELECT * FROM t WHERE name = 'O''Brien'" + }, + { + "description": "param index out of range left as-is", + "query": "SELECT * FROM t WHERE a = $1 AND b = $3", + "paramValues": ["val"], + "expected": "SELECT * FROM t WHERE a = 'val' AND b = $3" + }, + { + "description": "SELECT literal expression with params", + "query": "SELECT $1::text, $2::int", + "paramValues": ["hello", "42"], + "expected": "SELECT 'hello'::text, '42'::int" + }, + { + "description": "no placeholders returns query unchanged even with params", + "query": "SELECT 1", + "paramValues": ["unused"], + "expected": "SELECT 1" + }, + { + "description": "empty string param", + "query": "SELECT * FROM t WHERE name = $1", + "paramValues": [""], + "expected": "SELECT * FROM t WHERE name = ''" + } +] diff --git a/internal/stackql/typing/oid_mapping_test.go b/internal/stackql/typing/oid_mapping_test.go new file mode 100644 index 00000000..d0222003 --- /dev/null +++ b/internal/stackql/typing/oid_mapping_test.go @@ -0,0 +1,50 @@ +package typing + +import ( + "testing" + + "github.com/lib/pq/oid" + "github.com/stackql/stackql-parser/go/vt/sqlparser" +) + +func TestGetOidForParserColType(t *testing.T) { + tests := []struct { + colType string + expected oid.Oid + }{ + {"int", oid.T_numeric}, + {"integer", oid.T_numeric}, + {"int4", oid.T_numeric}, + {"int8", oid.T_numeric}, + {"bigint", oid.T_numeric}, + {"numeric", oid.T_numeric}, + {"decimal", 
oid.T_numeric}, + {"float", oid.T_numeric}, + {"float8", oid.T_numeric}, + {"double precision", oid.T_numeric}, + {"bool", oid.T_bool}, + {"boolean", oid.T_bool}, + {"text", oid.T_text}, + {"varchar", oid.T_text}, + {"string", oid.T_text}, + {"timestamp", oid.T_timestamp}, + {"json", oid.T_text}, + {"jsonb", oid.T_text}, + {"uuid", oid.T_text}, + } + for _, tt := range tests { + t.Run(tt.colType, func(t *testing.T) { + got := GetOidForParserColType(sqlparser.ColumnType{Type: tt.colType}) + if got != tt.expected { + t.Errorf("GetOidForParserColType(%q) = %d, want %d", tt.colType, got, tt.expected) + } + }) + } +} + +func TestGetOidForSchemaNil(t *testing.T) { + got := GetOidForSchema(nil) + if got != oid.T_text { + t.Errorf("GetOidForSchema(nil) = %d, want %d (T_text)", got, oid.T_text) + } +} diff --git a/internal/stackql/typing/relayed_column_metadata.go b/internal/stackql/typing/relayed_column_metadata.go index 4a542bd8..67b1b55a 100644 --- a/internal/stackql/typing/relayed_column_metadata.go +++ b/internal/stackql/typing/relayed_column_metadata.go @@ -53,8 +53,6 @@ func (cd *relayedColumnMetadata) getOidForRelationalType(relType string) oid.Oid switch relType { case "object", "array", "text": return oid.T_text - // case "integer": - // return oid.T_numeric case "boolean", "bool": return oid.T_text case "number", "decimal", "numeric", "real": diff --git a/internal/stackql/typing/standard_column_metadata.go b/internal/stackql/typing/standard_column_metadata.go index 2c0a7285..f7b18dcf 100644 --- a/internal/stackql/typing/standard_column_metadata.go +++ b/internal/stackql/typing/standard_column_metadata.go @@ -81,8 +81,6 @@ func getOidForSchema(colSchema formulation.Schema) oid.Oid { switch colSchema.GetType() { case "object", "array": return oid.T_text - // case "integer": - // return oid.T_numeric case "boolean", "bool": return oid.T_text case "number": From 007ace5fef7e096a8f902854ce245986d2a94deb Mon Sep 17 00:00:00 2001 From: General Kroll Date: Sat, 4 Apr 
2026 00:39:53 +1100 Subject: [PATCH 08/11] stupid-linting --- internal/stackql/driver/driver.go | 8 +- internal/stackql/paramdecoder/paramdecoder.go | 52 +++++++----- .../stackql/paramdecoder/paramdecoder_test.go | 19 ++--- internal/stackql/psqlwire/psqlwire.go | 81 ------------------- internal/stackql/queryshape/queryshape.go | 4 +- .../stackql/queryshape/queryshape_test.go | 8 +- internal/stackql/typing/oid_mapping_test.go | 2 +- 7 files changed, 55 insertions(+), 119 deletions(-) diff --git a/internal/stackql/driver/driver.go b/internal/stackql/driver/driver.go index d1bea81a..f7fb090d 100644 --- a/internal/stackql/driver/driver.go +++ b/internal/stackql/driver/driver.go @@ -256,8 +256,8 @@ func (dr *basicStackQLDriver) HandleDescribeStatement( func (dr *basicStackQLDriver) HandleDescribePortal( ctx context.Context, portalName string, stmtName string, query string, paramOIDs []uint32, ) ([]sqldata.ISQLColumn, error) { - if portal, ok := dr.portalCache[portalName]; ok { - if cached, ok := dr.stmtCache[portal.stmtName]; ok { + if portal, portalFound := dr.portalCache[portalName]; portalFound { + if cached, stmtFound := dr.stmtCache[portal.stmtName]; stmtFound { return cached.columns, nil } } @@ -270,8 +270,8 @@ func (dr *basicStackQLDriver) HandleExecute( ) (sqldata.ISQLResultStream, error) { // Look up cached param OIDs for format-aware decoding. 
var paramOIDs []uint32 - if portal, ok := dr.portalCache[portalName]; ok { - if cached, ok := dr.stmtCache[portal.stmtName]; ok { + if portal, portalFound := dr.portalCache[portalName]; portalFound { + if cached, stmtFound := dr.stmtCache[portal.stmtName]; stmtFound { paramOIDs = cached.paramOIDs } } diff --git a/internal/stackql/paramdecoder/paramdecoder.go b/internal/stackql/paramdecoder/paramdecoder.go index 07c1eda7..fcebe7b3 100644 --- a/internal/stackql/paramdecoder/paramdecoder.go +++ b/internal/stackql/paramdecoder/paramdecoder.go @@ -75,52 +75,66 @@ func decodeParam(paramOID oid.Oid, format int16, val []byte) (string, error) { return decodeBinary(paramOID, val) } +// Binary wire sizes for fixed-width postgres types. +const ( + boolSize = 1 + int2Size = 2 + int4Size = 4 + int8Size = 8 + float4Size = 4 + float8Size = 8 + timestampSize = 8 +) + // decodeBinary decodes a binary-encoded parameter value to its string representation. // -//nolint:cyclop // switch over OIDs is inherently branchy +//nolint:cyclop,exhaustive // switch over OIDs is inherently branchy; only common types handled func decodeBinary(paramOID oid.Oid, val []byte) (string, error) { switch paramOID { case oid.T_bool: - if len(val) != 1 { - return "", fmt.Errorf("bool: expected 1 byte, got %d", len(val)) + if len(val) != boolSize { + return "", fmt.Errorf("bool: expected %d byte, got %d", boolSize, len(val)) } if val[0] != 0 { return "true", nil } return "false", nil case oid.T_int2: - if len(val) != 2 { - return "", fmt.Errorf("int2: expected 2 bytes, got %d", len(val)) + if len(val) != int2Size { + return "", fmt.Errorf("int2: expected %d bytes, got %d", int2Size, len(val)) } - return strconv.FormatInt(int64(int16(binary.BigEndian.Uint16(val))), 10), nil //nolint:gosec // deliberate int16 conversion + v := int16(binary.BigEndian.Uint16(val)) //nolint:gosec // deliberate narrowing + return strconv.FormatInt(int64(v), 10), nil case oid.T_int4: - if len(val) != 4 { - return "", 
fmt.Errorf("int4: expected 4 bytes, got %d", len(val)) + if len(val) != int4Size { + return "", fmt.Errorf("int4: expected %d bytes, got %d", int4Size, len(val)) } - return strconv.FormatInt(int64(int32(binary.BigEndian.Uint32(val))), 10), nil //nolint:gosec // deliberate int32 conversion + v := int32(binary.BigEndian.Uint32(val)) //nolint:gosec // deliberate narrowing + return strconv.FormatInt(int64(v), 10), nil case oid.T_int8: - if len(val) != 8 { - return "", fmt.Errorf("int8: expected 8 bytes, got %d", len(val)) + if len(val) != int8Size { + return "", fmt.Errorf("int8: expected %d bytes, got %d", int8Size, len(val)) } - return strconv.FormatInt(int64(binary.BigEndian.Uint64(val)), 10), nil + v := int64(binary.BigEndian.Uint64(val)) //nolint:gosec // deliberate conversion + return strconv.FormatInt(v, 10), nil case oid.T_float4: - if len(val) != 4 { - return "", fmt.Errorf("float4: expected 4 bytes, got %d", len(val)) + if len(val) != float4Size { + return "", fmt.Errorf("float4: expected %d bytes, got %d", float4Size, len(val)) } bits := binary.BigEndian.Uint32(val) return strconv.FormatFloat(float64(math.Float32frombits(bits)), 'f', -1, 32), nil case oid.T_float8: - if len(val) != 8 { - return "", fmt.Errorf("float8: expected 8 bytes, got %d", len(val)) + if len(val) != float8Size { + return "", fmt.Errorf("float8: expected %d bytes, got %d", float8Size, len(val)) } bits := binary.BigEndian.Uint64(val) return strconv.FormatFloat(math.Float64frombits(bits), 'f', -1, 64), nil case oid.T_timestamp, oid.T_timestamptz: - if len(val) != 8 { - return "", fmt.Errorf("timestamp: expected 8 bytes, got %d", len(val)) + if len(val) != timestampSize { + return "", fmt.Errorf("timestamp: expected %d bytes, got %d", timestampSize, len(val)) } // Postgres binary timestamp: microseconds since 2000-01-01 00:00:00 UTC. 
- microseconds := int64(binary.BigEndian.Uint64(val)) + microseconds := int64(binary.BigEndian.Uint64(val)) //nolint:gosec // deliberate conversion pgEpoch := time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC) ts := pgEpoch.Add(time.Duration(microseconds) * time.Microsecond) return ts.Format("2006-01-02 15:04:05.999999"), nil diff --git a/internal/stackql/paramdecoder/paramdecoder_test.go b/internal/stackql/paramdecoder/paramdecoder_test.go index 438eca50..edec1381 100644 --- a/internal/stackql/paramdecoder/paramdecoder_test.go +++ b/internal/stackql/paramdecoder/paramdecoder_test.go @@ -1,4 +1,4 @@ -package paramdecoder +package paramdecoder_test import ( "encoding/binary" @@ -6,10 +6,11 @@ import ( "testing" "github.com/lib/pq/oid" + "github.com/stackql/stackql/internal/stackql/paramdecoder" ) func TestDecodeTextParams(t *testing.T) { - d := NewDecoder() + d := paramdecoder.NewDecoder() results, err := d.DecodeParams( []uint32{uint32(oid.T_text), uint32(oid.T_text)}, []int16{0}, // all text @@ -24,7 +25,7 @@ func TestDecodeTextParams(t *testing.T) { } func TestDecodeNullParam(t *testing.T) { - d := NewDecoder() + d := paramdecoder.NewDecoder() results, err := d.DecodeParams( []uint32{uint32(oid.T_text)}, nil, @@ -39,7 +40,7 @@ func TestDecodeNullParam(t *testing.T) { } func TestDecodeBinaryInt4(t *testing.T) { - d := NewDecoder() + d := paramdecoder.NewDecoder() val := make([]byte, 4) binary.BigEndian.PutUint32(val, uint32(42)) results, err := d.DecodeParams( @@ -56,7 +57,7 @@ func TestDecodeBinaryInt4(t *testing.T) { } func TestDecodeBinaryInt8(t *testing.T) { - d := NewDecoder() + d := paramdecoder.NewDecoder() val := make([]byte, 8) binary.BigEndian.PutUint64(val, uint64(9999999999)) results, err := d.DecodeParams( @@ -73,7 +74,7 @@ func TestDecodeBinaryInt8(t *testing.T) { } func TestDecodeBinaryFloat8(t *testing.T) { - d := NewDecoder() + d := paramdecoder.NewDecoder() val := make([]byte, 8) binary.BigEndian.PutUint64(val, math.Float64bits(3.14)) results, err := 
d.DecodeParams( @@ -90,7 +91,7 @@ func TestDecodeBinaryFloat8(t *testing.T) { } func TestDecodeBinaryBool(t *testing.T) { - d := NewDecoder() + d := paramdecoder.NewDecoder() results, err := d.DecodeParams( []uint32{uint32(oid.T_bool), uint32(oid.T_bool)}, []int16{1}, @@ -105,7 +106,7 @@ func TestDecodeBinaryBool(t *testing.T) { } func TestDecodeMixedFormats(t *testing.T) { - d := NewDecoder() + d := paramdecoder.NewDecoder() int4Val := make([]byte, 4) binary.BigEndian.PutUint32(int4Val, uint32(100)) results, err := d.DecodeParams( @@ -122,7 +123,7 @@ func TestDecodeMixedFormats(t *testing.T) { } func TestDecodeUnknownOIDBinaryFallsBackToText(t *testing.T) { - d := NewDecoder() + d := paramdecoder.NewDecoder() results, err := d.DecodeParams( []uint32{99999}, []int16{1}, // binary diff --git a/internal/stackql/psqlwire/psqlwire.go b/internal/stackql/psqlwire/psqlwire.go index 1f7b1c64..693251d9 100644 --- a/internal/stackql/psqlwire/psqlwire.go +++ b/internal/stackql/psqlwire/psqlwire.go @@ -128,87 +128,6 @@ func ExtractRowElement(column sqldata.ISQLColumn, src interface{}, ci *pgtype.Co return bb, nil } -// coerceForOID converts a value (often a string or []byte from the RDBMS) -// into the Go type that pgtype expects for the given type name. -// This bridges the gap between sqlite's string-heavy output and pgtype's -// type-specific Set() requirements. 
-// -//nolint:gocyclo,cyclop // switch over type names is inherently branchy -func coerceForOID(typeName string, val interface{}) interface{} { - if val == nil { - return nil - } - s, isStr := valToString(val) - if !isStr { - return val // already a native Go type; pass through - } - if s == "null" || s == "NULL" || s == "" { - return nil - } - switch typeName { - case "numeric": - f, err := strconv.ParseFloat(s, 64) - if err == nil { - return f - } - return val - case "int2": - i, err := strconv.ParseInt(s, 10, 16) - if err == nil { - return int16(i) - } - return val - case "int4": - i, err := strconv.ParseInt(s, 10, 32) - if err == nil { - return int32(i) - } - return val - case "int8": - i, err := strconv.ParseInt(s, 10, 64) - if err == nil { - return i - } - return val - case "float4": - f, err := strconv.ParseFloat(s, 32) - if err == nil { - return float32(f) - } - return val - case "float8": - f, err := strconv.ParseFloat(s, 64) - if err == nil { - return f - } - return val - case "bool": - switch s { - case "true", "TRUE", "t", "1", "yes": - return true - case "false", "FALSE", "f", "0", "no": - return false - } - return val - case "json", "jsonb": - return s // pass as string; pgtype text encoder handles it - default: - return val // text, varchar, timestamp, etc. — string is fine - } -} - -// valToString extracts a string from string or []byte values. 
-func valToString(val interface{}) (string, bool) { - switch v := val.(type) { - case string: - return v, true - case []byte: - return string(v), true - default: - return "", false - } -} - func getFormatCode(fc string) (postgreswire.FormatCode, error) { switch fc { case "TextFormat": diff --git a/internal/stackql/queryshape/queryshape.go b/internal/stackql/queryshape/queryshape.go index 7a825a83..52ecf84e 100644 --- a/internal/stackql/queryshape/queryshape.go +++ b/internal/stackql/queryshape/queryshape.go @@ -188,7 +188,9 @@ var paramPlaceholderRegex = regexp.MustCompile(`\$(\d+)`) // SubstituteParams replaces $1, $2, ... placeholders with their bound values. // NULL parameters (nil entries in paramValues) are substituted as the literal NULL. // String values are single-quote escaped. -func SubstituteParams(query string, paramFormats []int16, paramValues [][]byte) string { +// +//nolint:revive // paramFormats retained for future binary format support +func SubstituteParams(query string, paramFormats []int16, paramValues [][]byte) string { //nolint:revive // future use if len(paramValues) == 0 { return query } diff --git a/internal/stackql/queryshape/queryshape_test.go b/internal/stackql/queryshape/queryshape_test.go index 38facb10..ffd9aa76 100644 --- a/internal/stackql/queryshape/queryshape_test.go +++ b/internal/stackql/queryshape/queryshape_test.go @@ -1,4 +1,4 @@ -package queryshape +package queryshape //nolint:testpackage // tests unexported extractSingleTableName import ( "encoding/json" @@ -32,10 +32,10 @@ func TestExtractSingleTableName(t *testing.T) { } type SubstituteParamsCase struct { - Description string `json:"description"` - Query string `json:"query"` + Description string `json:"description"` + Query string `json:"query"` ParamValues []*string `json:"paramValues"` // nil entries represent SQL NULL - Expected string `json:"expected"` + Expected string `json:"expected"` } func (c *SubstituteParamsCase) toByteSlices() [][]byte { diff --git 
a/internal/stackql/typing/oid_mapping_test.go b/internal/stackql/typing/oid_mapping_test.go index d0222003..cf03b538 100644 --- a/internal/stackql/typing/oid_mapping_test.go +++ b/internal/stackql/typing/oid_mapping_test.go @@ -1,4 +1,4 @@ -package typing +package typing //nolint:testpackage // tests exported functions in same package for simplicity import ( "testing" From db7bb0c48274b41a9f9e3d975770f1dbd848500c Mon Sep 17 00:00:00 2001 From: General Kroll Date: Sat, 4 Apr 2026 00:46:21 +1100 Subject: [PATCH 09/11] needed-lib-changes --- docs/psql_wire_changes.md | 131 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 131 insertions(+) create mode 100644 docs/psql_wire_changes.md diff --git a/docs/psql_wire_changes.md b/docs/psql_wire_changes.md new file mode 100644 index 00000000..e17f9720 --- /dev/null +++ b/docs/psql_wire_changes.md @@ -0,0 +1,131 @@ +# Required psql-wire library changes for extended query fidelity + +These changes are needed in `github.com/stackql/psql-wire` to complete postgres-fidelity extended query support in stackql. + +## 1. Text format encoding must not alter value representation + +**File:** `row.go` — `Column.Write()` method (line ~82) + +**Problem:** When a column has a non-text OID (e.g. `T_int8`, `T_bool`), the current path does: +```go +typed.Value.Set(src) // pgtype.Int8.Set("100000001") — works +encoder := fc.Encoder(typed) // TextEncoder +bb, _ := encoder(ci, nil) // outputs "100000001" (no quotes, different width) +``` + +For `T_bool`, `pgtype.Bool` text-encodes as `"true"/"false"` instead of sqlite's `"t"/"f"`. For `T_int8`, formatting may differ from the string that came in. + +**Fix:** For `TextFormat`, bypass `pgtype.Set()` + encoder when the source value is already a string or `[]byte`. 
Write the raw bytes directly with a length prefix: + +```go +func (column Column) Write(ctx context.Context, writer buffer.Writer, src interface{}) error { + if column.Format == TextFormat { + if b, ok := asTextBytes(src); ok { + if b == nil { + writer.AddInt32(-1) // NULL + return nil + } + writer.AddInt32(int32(len(b))) + writer.AddBytes(b) + return nil + } + } + // existing pgtype path for binary format or non-string sources + ... +} + +func asTextBytes(src interface{}) ([]byte, bool) { + switch v := src.(type) { + case string: + if strings.ToLower(v) == "null" { + return nil, true + } + return []byte(v), true + case []byte: + return v, true + default: + return nil, false + } +} +``` + +This preserves the exact string representation from the RDBMS while still sending the correct OID in `RowDescription`. Clients see `T_int8` in the column metadata but receive text-encoded values — which is valid postgres behaviour when `FormatCode=0` (text). + +## 2. Respect `resultFormats` from Bind in RowDescription + +**File:** `extended_query.go` — `writeRowDescriptionFromSQLColumns()` (line ~419) + +**Problem:** Format is hardcoded to `TextFormat`: +```go +colz = append(colz, Column{ + ... 
+ Format: TextFormat, // always text +}) +``` + +**Fix:** Accept `resultFormats []int16` (from the portal's Bind message) and apply per the postgres protocol rules: + +```go +func writeRowDescriptionFromSQLColumns( + ctx context.Context, writer buffer.Writer, + columns []sqldata.ISQLColumn, resultFormats []int16, +) error { + var colz Columns + for i, c := range columns { + colz = append(colz, Column{ + Table: c.GetTableId(), + Name: c.GetName(), + AttrNo: c.GetAttrNum(), + Oid: oid.Oid(c.GetObjectID()), + Width: c.GetWidth(), + Format: resolveResultFormat(resultFormats, i), + }) + } + return colz.Define(ctx, writer) +} + +func resolveResultFormat(formats []int16, i int) FormatCode { + if len(formats) == 0 { + return TextFormat + } + if len(formats) == 1 { + return FormatCode(formats[0]) + } + if i < len(formats) { + return FormatCode(formats[i]) + } + return TextFormat +} +``` + +This requires threading `resultFormats` from the portal through to the row description writer. The portal already stores `ResultFormats` — it just needs to be passed through `handleExecute` → `writeSQLResultHeader`. + +## 3. Thread resultFormats through handleExecute + +**File:** `extended_query.go` — `handleExecute()` (line ~248) + +The portal's `ResultFormats` need to reach the data writer so `Column.Write` uses the correct encoder (text vs binary). Currently the `dataWriter` and `writeSQLResultHeader` don't receive format information. + +**Fix:** Add `resultFormats []int16` to the `dataWriter` struct or pass it through the header writing path: + +```go +dw := &dataWriter{ + ctx: ctx, + client: conn, + resultFormats: portal.ResultFormats, +} +``` + +Then in `writeSQLResultHeader`, use `dw.resultFormats` when calling `writeRowDescriptionFromSQLColumns`. 
+ +## Summary + +| Change | File | Risk | Effect | +|--------|------|------|--------| +| Text bypass for string values | `row.go` | Low — only affects text format path, preserves exact string bytes | Enables OID fidelity without changing output format | +| resultFormats in RowDescription | `extended_query.go` | Low — defaults to TextFormat when formats not specified | Clients can request binary results | +| Thread formats through execute | `extended_query.go` | Low — adds field to existing struct | Connects Bind formats to result encoding | + +Change 1 is the critical blocker. Changes 2-3 are enhancements for binary result support (most clients default to text results anyway). + +Once change 1 lands, stackql can enable finer OID mapping (`integer`→`T_int8`, `boolean`→`T_bool`, etc.) without breaking any existing tests. From c38f586154a6a00af49c8b2c98b28ec40db81ddd Mon Sep 17 00:00:00 2001 From: General Kroll Date: Sat, 4 Apr 2026 08:33:03 +1100 Subject: [PATCH 10/11] - Added robot test `PG Extended Query Column Descriptions Available`. 
--- .../stackql_test_tooling/StackQLInterfaces.py | 9 +++++++++ .../stackql_test_tooling/psycopg2_client.py | 16 ++++++++++++++++ .../functional/stackql_sessions_postgres.robot | 11 +++++++++++ 3 files changed, 36 insertions(+) diff --git a/test/python/stackql_test_tooling/StackQLInterfaces.py b/test/python/stackql_test_tooling/StackQLInterfaces.py index db102cfa..c8662626 100644 --- a/test/python/stackql_test_tooling/StackQLInterfaces.py +++ b/test/python/stackql_test_tooling/StackQLInterfaces.py @@ -725,6 +725,15 @@ def should_sqlalchemy_raw_session_inline_have_length_greater_than_or_equal_to(se return self.should_be_true(len(result) >= expected_length) + @keyword + def should_PG_client_column_descriptions_equal(self, conn_str :str, query :str, expected_descriptions :typing.List[typing.Dict], **kwargs): + """Execute a query via psycopg2 and verify column descriptions (name + type_code OID).""" + client = PsycoPG2Client(conn_str) + result = client.get_column_descriptions(query) + self.log(f"Column descriptions: {result}") + return self.lists_should_be_equal(result, expected_descriptions) + + @keyword def should_PG_client_V2_session_inline_equal(self, conn_str :str, queries :typing.List[str], expected_output :typing.List[typing.Dict], **kwargs): client = PsycoPG2Client(conn_str) diff --git a/test/python/stackql_test_tooling/psycopg2_client.py b/test/python/stackql_test_tooling/psycopg2_client.py index 82a25a65..7dc8730c 100644 --- a/test/python/stackql_test_tooling/psycopg2_client.py +++ b/test/python/stackql_test_tooling/psycopg2_client.py @@ -37,3 +37,19 @@ def _run_queries(self, queries :typing.List[str]) -> typing.List[typing.Dict]: def run_queries(self, queries :typing.List[str]) -> typing.List[typing.Dict]: return self._run_queries(queries) + + def get_column_descriptions(self, query :str) -> typing.List[typing.Dict]: + """Execute a query and return column metadata from cursor.description. + + Each entry is a dict with keys: name, type_code. 
+ type_code is the PostgreSQL OID for the column type. + """ + with self._connection.cursor() as cur: + cur.execute(query) + if cur.description is None: + return [] + return [ + {'name': col.name, 'type_code': col.type_code} + for col in cur.description + ] + diff --git a/test/robot/functional/stackql_sessions_postgres.robot b/test/robot/functional/stackql_sessions_postgres.robot index 0be1a29b..f17f7f95 100644 --- a/test/robot/functional/stackql_sessions_postgres.robot +++ b/test/robot/functional/stackql_sessions_postgres.robot @@ -40,6 +40,17 @@ SQLAlchemy Session Postgres Intel Views Exist ... stdout=${CURDIR}/tmp/SQLAlchemy-Session-Postgres-Intel-Views-Exist.tmp [Teardown] NONE +PG Extended Query Column Descriptions Available + Pass Execution If "${SQL_BACKEND}" != "postgres_tcp" This is a postgres only test + ${expectedDescriptions} = Evaluate + ... [{'name': 'name', 'type_code': 25}, {'name': 'url', 'type_code': 25}] + Should PG Client Column Descriptions Equal + ... ${POSTGRES_URL_UNENCRYPTED_CONN} + ... select name, url from stackql_repositories order by name + ... ${expectedDescriptions} + ... stdout=${CURDIR}/tmp/PG-Extended-Query-Column-Descriptions-Available.tmp + [Teardown] NONE + SQLAlchemy Session Materialized View Lifecycle Pass Execution If "${SQL_BACKEND}" != "postgres_tcp" This is a postgres only test ${inputStr} = Catenate From 922331a7a915ed54a95958290bd80d5f1a0f5e64 Mon Sep 17 00:00:00 2001 From: General Kroll Date: Sat, 4 Apr 2026 09:38:09 +1100 Subject: [PATCH 11/11] - Added robot test `PG Extended Query Column Descriptions Available`. - Added robot test `PG Extended Query Prepared Statement Returns Rows`. - Added robot test `PG Extended Query Prepared Statement NULL Param Returns Zero Rows`. 
--- .../stackql_test_tooling/StackQLInterfaces.py | 20 +++++++++++++++++ .../stackql_test_tooling/psycopg2_client.py | 13 +++++++++++ .../stackql_sessions_postgres.robot | 22 +++++++++++++++++++ 3 files changed, 55 insertions(+) diff --git a/test/python/stackql_test_tooling/StackQLInterfaces.py b/test/python/stackql_test_tooling/StackQLInterfaces.py index c8662626..1e18fd60 100644 --- a/test/python/stackql_test_tooling/StackQLInterfaces.py +++ b/test/python/stackql_test_tooling/StackQLInterfaces.py @@ -734,6 +734,26 @@ def should_PG_client_column_descriptions_equal(self, conn_str :str, query :str, return self.lists_should_be_equal(result, expected_descriptions) + @keyword + def should_PG_client_prepared_query_results_contain(self, conn_str :str, query :str, params :typing.Tuple, expected_value :str, **kwargs): + """Execute a parameterised query via psycopg2 (extended query protocol) and verify results contain a value.""" + client = PsycoPG2Client(conn_str) + result = client.exec_prepared_query(query, params) + self.log(f"Prepared query results: {result}") + result_str = str(result) + if expected_value not in result_str: + raise AssertionError(f"Expected '{expected_value}' in results but got: {result_str}") + + + @keyword + def should_PG_client_prepared_query_results_have_length(self, conn_str :str, query :str, params :typing.Tuple, expected_length :int, **kwargs): + """Execute a parameterised query via psycopg2 (extended query protocol) and verify result count.""" + client = PsycoPG2Client(conn_str) + result = client.exec_prepared_query(query, params) + self.log(f"Prepared query results ({len(result)} rows): {result}") + return self.should_be_equal(len(result), expected_length) + + @keyword def should_PG_client_V2_session_inline_equal(self, conn_str :str, queries :typing.List[str], expected_output :typing.List[typing.Dict], **kwargs): client = PsycoPG2Client(conn_str) diff --git a/test/python/stackql_test_tooling/psycopg2_client.py 
def exec_prepared_query(self, query :str, params :tuple) -> typing.List[typing.Dict]:
    """Execute ``query`` with ``params`` and return every row as a plain dict.

    Args:
        query: SQL text with placeholders understood by the cursor's ``execute``.
        params: Values passed to ``execute`` alongside ``query``.

    Returns:
        One ``dict`` per row, in cursor order.  An empty list when the
        statement produced no result set (``cursor.description`` is ``None``),
        mirroring the guard used by ``get_column_descriptions``.

    Raises:
        Whatever ``execute`` or row iteration raises.  Failures while reading
        rows are no longer swallowed, so a broken result stream fails the
        calling test instead of silently yielding a truncated list.

    NOTE(review): the import behind ``RealDictCursor`` is not visible here.  If
    this is psycopg2, parameters are normally merged into the SQL text on the
    client and sent as a simple query, so this call may not exercise
    Parse/Bind/Execute on the server -- confirm before treating it as
    extended-query-protocol coverage.
    """
    with self._connection.cursor(cursor_factory=RealDictCursor) as cur:
        cur.execute(query, params)
        # No description means no result set; iterating would raise, so
        # report "no rows" explicitly rather than catching every exception.
        if cur.description is None:
            return []
        return [dict(row) for row in cur]
stdout=${CURDIR}/tmp/PG-Extended-Query-Prepared-Statement-NULL-Param.tmp + [Teardown] NONE + SQLAlchemy Session Materialized View Lifecycle Pass Execution If "${SQL_BACKEND}" != "postgres_tcp" This is a postgres only test ${inputStr} = Catenate