Skip to content

Commit 6da4456

Browse files
authored
feat: string_agg is compatible with multiple types as arguments instead of just String (#17570)
* feat: `string_agg` is compatible with multiple types as arguments instead of just `String` Signed-off-by: Kould <[email protected]> * chore: add type check for `string_agg` Signed-off-by: Kould <[email protected]> --------- Signed-off-by: Kould <[email protected]>
1 parent 195adc3 commit 6da4456

File tree

3 files changed

+63
-12
lines changed

3 files changed

+63
-12
lines changed

src/query/functions/src/aggregates/aggregate_string_agg.rs

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,13 @@ use databend_common_expression::types::ValueType;
2727
use databend_common_expression::AggrStateRegistry;
2828
use databend_common_expression::AggrStateType;
2929
use databend_common_expression::ColumnBuilder;
30+
use databend_common_expression::DataBlock;
31+
use databend_common_expression::EvaluateOptions;
32+
use databend_common_expression::Evaluator;
33+
use databend_common_expression::FunctionContext;
3034
use databend_common_expression::InputColumns;
3135
use databend_common_expression::Scalar;
36+
use databend_common_expression::Value;
3237

3338
use super::aggregate_function_factory::AggregateFunctionDescription;
3439
use super::borsh_deserialize_state;
@@ -39,6 +44,7 @@ use crate::aggregates::assert_variadic_arguments;
3944
use crate::aggregates::AggrState;
4045
use crate::aggregates::AggrStateLoc;
4146
use crate::aggregates::AggregateFunction;
47+
use crate::BUILTIN_FUNCTIONS;
4248

4349
#[derive(BorshSerialize, BorshDeserialize, Debug)]
4450
pub struct StringAggState {
@@ -49,6 +55,7 @@ pub struct StringAggState {
4955
pub struct AggregateStringAggFunction {
5056
display_name: String,
5157
delimiter: String,
58+
value_type: DataType,
5259
}
5360

5461
impl AggregateFunction for AggregateStringAggFunction {
@@ -77,7 +84,22 @@ impl AggregateFunction for AggregateStringAggFunction {
7784
validity: Option<&Bitmap>,
7885
_input_rows: usize,
7986
) -> Result<()> {
80-
let column = StringType::try_downcast_column(&columns[0]).unwrap();
87+
let column = if self.value_type != DataType::String {
88+
let block = DataBlock::new_from_columns(vec![columns[0].clone()]);
89+
let func_ctx = &FunctionContext::default();
90+
let evaluator = Evaluator::new(&block, func_ctx, &BUILTIN_FUNCTIONS);
91+
let value = evaluator.run_cast(
92+
None,
93+
&self.value_type,
94+
&DataType::String,
95+
Value::Column(columns[0].clone()),
96+
None,
97+
&mut EvaluateOptions::default(),
98+
)?;
99+
StringType::try_downcast_column(value.as_column().unwrap()).unwrap()
100+
} else {
101+
StringType::try_downcast_column(&columns[0]).unwrap()
102+
};
81103
let state = place.get::<StringAggState>();
82104
match validity {
83105
Some(validity) => {
@@ -175,10 +197,15 @@ impl fmt::Display for AggregateStringAggFunction {
175197
}
176198

177199
impl AggregateStringAggFunction {
178-
fn try_create(display_name: &str, delimiter: String) -> Result<Arc<dyn AggregateFunction>> {
200+
fn try_create(
201+
display_name: &str,
202+
delimiter: String,
203+
value_type: DataType,
204+
) -> Result<Arc<dyn AggregateFunction>> {
179205
let func = AggregateStringAggFunction {
180206
display_name: display_name.to_string(),
181207
delimiter,
208+
value_type,
182209
};
183210
Ok(Arc::new(func))
184211
}
@@ -191,19 +218,29 @@ pub fn try_create_aggregate_string_agg_function(
191218
_sort_descs: Vec<AggregateFunctionSortDesc>,
192219
) -> Result<Arc<dyn AggregateFunction>> {
193220
assert_variadic_arguments(display_name, argument_types.len(), (1, 2))?;
194-
// TODO:(b41sh) support other data types
195-
if argument_types[0].remove_nullable() != DataType::String {
221+
let value_type = argument_types[0].remove_nullable();
222+
if !matches!(
223+
value_type,
224+
DataType::Boolean
225+
| DataType::String
226+
| DataType::Number(_)
227+
| DataType::Decimal(_)
228+
| DataType::Timestamp
229+
| DataType::Date
230+
| DataType::Variant
231+
| DataType::Interval
232+
) {
196233
return Err(ErrorCode::BadDataValueType(format!(
197-
"The argument of aggregate function {} must be string",
198-
display_name
234+
"{} does not support type '{:?}'",
235+
display_name, value_type
199236
)));
200237
}
201238
let delimiter = if params.len() == 1 {
202239
params[0].as_string().unwrap().clone()
203240
} else {
204241
String::new()
205242
};
206-
AggregateStringAggFunction::try_create(display_name, delimiter)
243+
AggregateStringAggFunction::try_create(display_name, delimiter, value_type)
207244
}
208245

209246
pub fn aggregate_string_agg_function_desc() -> AggregateFunctionDescription {

src/query/sql/src/planner/semantic/type_check.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1785,6 +1785,7 @@ impl<'a> TypeChecker<'a> {
17851785
)));
17861786
}
17871787
let _ = arguments.pop();
1788+
let _ = arg_types.pop();
17881789
let delimiter = delimiter_value.unwrap();
17891790
vec![delimiter.value]
17901791
} else {

tests/sqllogictests/suites/query/functions/02_0000_function_aggregate_mix.test

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -354,14 +354,14 @@ select group_array_moving_sum(k), group_array_moving_sum(2)(v) from aggr;
354354
[1,3,5,7,9,11,13,15,17,19,21] [10,20,20,20,30,40,45,55,60,60,60]
355355

356356
statement ok
357-
create table t3(s string null, x int null, y int null, z int null);
357+
create table t3(s string null, x int null, y int null, z int null, b boolean null, arr Array(Int32));
358358

359359
statement ok
360360
insert into t3 values
361-
('abc', 3, 1, 2),
362-
('def', 2, 1, null),
363-
(null, 1, 2, 1),
364-
('xyz', 0, 2, 0);
361+
('abc', 3, 1, 2, true, [1,2,3]),
362+
('def', 2, 1, null, false, [4,5,6]),
363+
(null, 1, 2, 1, null, [7,8,9]),
364+
('xyz', 0, 2, 0, true, [10,11,12]);
365365

366366
query T
367367
select string_agg(s) from t3;
@@ -378,6 +378,19 @@ select string_agg(s, '|') from t3;
378378
----
379379
abc|def|xyz
380380

381+
query T
382+
select string_agg(x + 1, '|') from t3;
383+
----
384+
4|3|2|1
385+
386+
query T
387+
select string_agg(b, '|') from t3;
388+
----
389+
true|false|true
390+
391+
statement error
392+
select listagg(arr) from t3;
393+
381394
query T
382395
select listagg(s) from t3;
383396
----

0 commit comments

Comments
 (0)