@@ -27,8 +27,13 @@ use databend_common_expression::types::ValueType;
2727use databend_common_expression:: AggrStateRegistry ;
2828use databend_common_expression:: AggrStateType ;
2929use databend_common_expression:: ColumnBuilder ;
30+ use databend_common_expression:: DataBlock ;
31+ use databend_common_expression:: EvaluateOptions ;
32+ use databend_common_expression:: Evaluator ;
33+ use databend_common_expression:: FunctionContext ;
3034use databend_common_expression:: InputColumns ;
3135use databend_common_expression:: Scalar ;
36+ use databend_common_expression:: Value ;
3237
3338use super :: aggregate_function_factory:: AggregateFunctionDescription ;
3439use super :: borsh_deserialize_state;
@@ -39,6 +44,7 @@ use crate::aggregates::assert_variadic_arguments;
3944use crate :: aggregates:: AggrState ;
4045use crate :: aggregates:: AggrStateLoc ;
4146use crate :: aggregates:: AggregateFunction ;
47+ use crate :: BUILTIN_FUNCTIONS ;
4248
4349#[ derive( BorshSerialize , BorshDeserialize , Debug ) ]
4450pub struct StringAggState {
@@ -49,6 +55,7 @@ pub struct StringAggState {
4955pub struct AggregateStringAggFunction {
5056 display_name : String ,
5157 delimiter : String ,
58+ value_type : DataType ,
5259}
5360
5461impl AggregateFunction for AggregateStringAggFunction {
@@ -77,7 +84,22 @@ impl AggregateFunction for AggregateStringAggFunction {
7784 validity : Option < & Bitmap > ,
7885 _input_rows : usize ,
7986 ) -> Result < ( ) > {
80- let column = StringType :: try_downcast_column ( & columns[ 0 ] ) . unwrap ( ) ;
87+ let column = if self . value_type != DataType :: String {
88+ let block = DataBlock :: new_from_columns ( vec ! [ columns[ 0 ] . clone( ) ] ) ;
89+ let func_ctx = & FunctionContext :: default ( ) ;
90+ let evaluator = Evaluator :: new ( & block, func_ctx, & BUILTIN_FUNCTIONS ) ;
91+ let value = evaluator. run_cast (
92+ None ,
93+ & self . value_type ,
94+ & DataType :: String ,
95+ Value :: Column ( columns[ 0 ] . clone ( ) ) ,
96+ None ,
97+ & mut EvaluateOptions :: default ( ) ,
98+ ) ?;
99+ StringType :: try_downcast_column ( value. as_column ( ) . unwrap ( ) ) . unwrap ( )
100+ } else {
101+ StringType :: try_downcast_column ( & columns[ 0 ] ) . unwrap ( )
102+ } ;
81103 let state = place. get :: < StringAggState > ( ) ;
82104 match validity {
83105 Some ( validity) => {
@@ -175,10 +197,15 @@ impl fmt::Display for AggregateStringAggFunction {
175197}
176198
177199impl AggregateStringAggFunction {
178- fn try_create ( display_name : & str , delimiter : String ) -> Result < Arc < dyn AggregateFunction > > {
200+ fn try_create (
201+ display_name : & str ,
202+ delimiter : String ,
203+ value_type : DataType ,
204+ ) -> Result < Arc < dyn AggregateFunction > > {
179205 let func = AggregateStringAggFunction {
180206 display_name : display_name. to_string ( ) ,
181207 delimiter,
208+ value_type,
182209 } ;
183210 Ok ( Arc :: new ( func) )
184211 }
@@ -191,19 +218,29 @@ pub fn try_create_aggregate_string_agg_function(
191218 _sort_descs : Vec < AggregateFunctionSortDesc > ,
192219) -> Result < Arc < dyn AggregateFunction > > {
193220 assert_variadic_arguments ( display_name, argument_types. len ( ) , ( 1 , 2 ) ) ?;
194- // TODO:(b41sh) support other data types
195- if argument_types[ 0 ] . remove_nullable ( ) != DataType :: String {
221+ let value_type = argument_types[ 0 ] . remove_nullable ( ) ;
222+ if !matches ! (
223+ value_type,
224+ DataType :: Boolean
225+ | DataType :: String
226+ | DataType :: Number ( _)
227+ | DataType :: Decimal ( _)
228+ | DataType :: Timestamp
229+ | DataType :: Date
230+ | DataType :: Variant
231+ | DataType :: Interval
232+ ) {
196233 return Err ( ErrorCode :: BadDataValueType ( format ! (
197- "The argument of aggregate function {} must be string " ,
198- display_name
234+ "{} does not support type '{:?}' " ,
235+ display_name, value_type
199236 ) ) ) ;
200237 }
201238 let delimiter = if params. len ( ) == 1 {
202239 params[ 0 ] . as_string ( ) . unwrap ( ) . clone ( )
203240 } else {
204241 String :: new ( )
205242 } ;
206- AggregateStringAggFunction :: try_create ( display_name, delimiter)
243+ AggregateStringAggFunction :: try_create ( display_name, delimiter, value_type )
207244}
208245
209246pub fn aggregate_string_agg_function_desc ( ) -> AggregateFunctionDescription {
0 commit comments