Skip to content

Commit 855ecc8

Browse files
authored
fix(query): fix lambda function bind column failed (#17402)
* fix(query): fix lambda function bind column failed * fix tests
1 parent ac71ae1 commit 855ecc8

File tree

3 files changed

+76
-27
lines changed

3 files changed

+76
-27
lines changed

src/query/sql/src/planner/expression_parser.rs

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -356,20 +356,29 @@ pub fn parse_computed_expr_to_string(
356356

357357
pub fn parse_lambda_expr(
358358
ctx: Arc<dyn TableContext>,
359-
mut bind_context: BindContext,
360-
columns: &[(String, DataType)],
359+
lambda_context: &mut BindContext,
360+
lambda_columns: &[(String, DataType)],
361361
ast: &AExpr,
362362
) -> Result<Box<(ScalarExpr, DataType)>> {
363363
let metadata = Metadata::default();
364-
bind_context.set_expr_context(ExprContext::InLambdaFunction);
364+
lambda_context.set_expr_context(ExprContext::InLambdaFunction);
365365

366-
let column_len = bind_context.all_column_bindings().len();
367-
for (idx, column) in columns.iter().enumerate() {
368-
bind_context.add_column_binding(
366+
// The column index may not be consecutive, and the length of columns
367+
// cannot be used to calculate the column index of the lambda argument.
368+
// We need to start from the current largest column index.
369+
let mut column_index = lambda_context
370+
.all_column_bindings()
371+
.iter()
372+
.map(|c| c.index)
373+
.max()
374+
.unwrap_or_default();
375+
for (lambda_column, lambda_column_type) in lambda_columns.iter() {
376+
column_index += 1;
377+
lambda_context.add_column_binding(
369378
ColumnBindingBuilder::new(
370-
column.0.clone(),
371-
column_len + idx,
372-
Box::new(column.1.clone()),
379+
lambda_column.clone(),
380+
column_index,
381+
Box::new(lambda_column_type.clone()),
373382
Visibility::Visible,
374383
)
375384
.build(),
@@ -379,7 +388,7 @@ pub fn parse_lambda_expr(
379388
let settings = ctx.get_settings();
380389
let name_resolution_ctx = NameResolutionContext::try_from(settings.as_ref())?;
381390
let mut type_checker = TypeChecker::try_create(
382-
&mut bind_context,
391+
lambda_context,
383392
ctx.clone(),
384393
&name_resolution_ctx,
385394
Arc::new(RwLock::new(metadata)),

src/query/sql/src/planner/semantic/type_check.rs

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414

1515
use std::collections::BTreeMap;
1616
use std::collections::HashMap;
17+
use std::collections::HashSet;
1718
use std::collections::VecDeque;
19+
use std::mem;
1820
use std::str::FromStr;
1921
use std::sync::Arc;
2022
use std::vec;
@@ -166,6 +168,7 @@ use crate::BindContext;
166168
use crate::ColumnBinding;
167169
use crate::ColumnBindingBuilder;
168170
use crate::ColumnEntry;
171+
use crate::IndexType;
169172
use crate::MetadataRef;
170173
use crate::Visibility;
171174

@@ -1932,16 +1935,17 @@ impl<'a> TypeChecker<'a> {
19321935
vec![inner_ty.clone()]
19331936
};
19341937

1935-
let columns = params
1938+
let lambda_columns = params
19361939
.iter()
19371940
.zip(inner_tys.iter())
19381941
.map(|(col, ty)| (col.clone(), ty.clone()))
19391942
.collect::<Vec<_>>();
19401943

1944+
let mut lambda_context = self.bind_context.clone();
19411945
let box (lambda_expr, lambda_type) = parse_lambda_expr(
19421946
self.ctx.clone(),
1943-
self.bind_context.clone(),
1944-
&columns,
1947+
&mut lambda_context,
1948+
&lambda_columns,
19451949
&lambda.expr,
19461950
)?;
19471951

@@ -2035,20 +2039,24 @@ impl<'a> TypeChecker<'a> {
20352039
_ => {
20362040
struct LambdaVisitor<'a> {
20372041
bind_context: &'a BindContext,
2042+
arg_index: HashSet<IndexType>,
20382043
args: Vec<ScalarExpr>,
20392044
fields: Vec<DataField>,
20402045
}
20412046

20422047
impl<'a> ScalarVisitor<'a> for LambdaVisitor<'a> {
20432048
fn visit_bound_column_ref(&mut self, col: &'a BoundColumnRef) -> Result<()> {
2044-
let contains = self
2049+
if self.arg_index.contains(&col.column.index) {
2050+
return Ok(());
2051+
}
2052+
self.arg_index.insert(col.column.index);
2053+
let is_outer_column = self
20452054
.bind_context
20462055
.all_column_bindings()
20472056
.iter()
20482057
.map(|c| c.index)
20492058
.contains(&col.column.index);
2050-
// add outer scope columns first
2051-
if contains {
2059+
if is_outer_column {
20522060
let arg = ScalarExpr::BoundColumnRef(col.clone());
20532061
self.args.push(arg);
20542062
let field = DataField::new(
@@ -2061,24 +2069,30 @@ impl<'a> TypeChecker<'a> {
20612069
}
20622070
}
20632071

2072+
// Collect outer scope columns as arguments first.
20642073
let mut lambda_visitor = LambdaVisitor {
20652074
bind_context: self.bind_context,
2075+
arg_index: HashSet::new(),
20662076
args: Vec::new(),
20672077
fields: Vec::new(),
20682078
};
20692079
lambda_visitor.visit(&lambda_expr)?;
20702080

2071-
// add lambda columns at end
2072-
let mut fields = lambda_visitor.fields.clone();
2073-
let column_len = self.bind_context.all_column_bindings().len();
2074-
for (i, inner_ty) in inner_tys.into_iter().enumerate() {
2075-
let lambda_field = DataField::new(&format!("{}", column_len + i), inner_ty);
2076-
fields.push(lambda_field);
2081+
let mut lambda_args = mem::take(&mut lambda_visitor.args);
2082+
lambda_args.push(arg);
2083+
let mut lambda_fields = mem::take(&mut lambda_visitor.fields);
2084+
// Add lambda columns as arguments at end.
2085+
for (lambda_column_name, lambda_column_type) in lambda_columns.into_iter() {
2086+
for column in lambda_context.all_column_bindings().iter().rev() {
2087+
if column.column_name == lambda_column_name {
2088+
let lambda_field =
2089+
DataField::new(&format!("{}", column.index), lambda_column_type);
2090+
lambda_fields.push(lambda_field);
2091+
break;
2092+
}
2093+
}
20772094
}
2078-
let lambda_schema = DataSchema::new(fields);
2079-
let mut args = lambda_visitor.args.clone();
2080-
args.push(arg);
2081-
2095+
let lambda_schema = DataSchema::new(lambda_fields);
20822096
let expr = lambda_expr
20832097
.type_check(&lambda_schema)?
20842098
.project_column_ref(|index| {
@@ -2092,7 +2106,7 @@ impl<'a> TypeChecker<'a> {
20922106
LambdaFunc {
20932107
span,
20942108
func_name: func_name.to_string(),
2095-
args,
2109+
args: lambda_args,
20962110
lambda_expr: Box::new(remote_lambda_expr),
20972111
lambda_display,
20982112
return_type: Box::new(return_type.clone()),

tests/sqllogictests/suites/query/functions/02_0061_function_array.test

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,32 @@ SELECT arrays_zip(col1, col2) FROM t3;
419419
[(NULL,4)]
420420
[(7,5),(8,5)]
421421

422+
#issue 16794
423+
424+
statement ok
425+
CREATE OR REPLACE TABLE u (id VARCHAR NULL);
426+
427+
statement ok
428+
INSERT INTO u VALUES(1),(2);
429+
430+
statement ok
431+
CREATE OR REPLACE TABLE c (
432+
id VARCHAR NULL,
433+
what_fuck BOOLEAN NOT NULL,
434+
payload VARIANT NULL
435+
);
436+
437+
statement ok
438+
INSERT INTO c VALUES(1, true, '[1,2]'),(1, false, '[3,4]'),(2, true, '123');
439+
440+
query IT
441+
SELECT ids.id, array_filter(array_agg(px.payload), x -> x is not null) AS px_payload
442+
FROM u ids LEFT JOIN c px ON px.id = ids.id
443+
GROUP BY ids.id ORDER BY ids.id;
444+
----
445+
1 ['[1,2]','[3,4]']
446+
2 ['123']
447+
422448
statement ok
423449
USE default
424450

0 commit comments

Comments
 (0)