@@ -1857,43 +1857,6 @@ impl<'a> TypeChecker<'a> {
18571857 args : & [ & Expr ] ,
18581858 lambda : & Lambda ,
18591859 ) -> Result < Box < ( ScalarExpr , DataType ) > > {
1860- if func_name. starts_with ( "json_" ) && !args. is_empty ( ) {
1861- let target_type = if func_name. starts_with ( "json_array" ) {
1862- TypeName :: Array ( Box :: new ( TypeName :: Nullable ( Box :: new ( TypeName :: Variant ) ) ) )
1863- } else {
1864- TypeName :: Map {
1865- key_type : Box :: new ( TypeName :: String ) ,
1866- val_type : Box :: new ( TypeName :: Nullable ( Box :: new ( TypeName :: Variant ) ) ) ,
1867- }
1868- } ;
1869- let func_name = & func_name[ 5 ..] ;
1870- let mut new_args: Vec < Expr > = args. iter ( ) . map ( |v| ( * v) . to_owned ( ) ) . collect ( ) ;
1871- new_args[ 0 ] = Expr :: Cast {
1872- span : new_args[ 0 ] . span ( ) ,
1873- expr : Box :: new ( new_args[ 0 ] . clone ( ) ) ,
1874- target_type,
1875- pg_style : false ,
1876- } ;
1877-
1878- let args: Vec < & Expr > = new_args. iter ( ) . collect ( ) ;
1879- let result = self . resolve_lambda_function ( span, func_name, & args, lambda) ?;
1880-
1881- let target_type = if result. 1 . is_nullable ( ) {
1882- DataType :: Variant . wrap_nullable ( )
1883- } else {
1884- DataType :: Variant
1885- } ;
1886-
1887- let result_expr = ScalarExpr :: CastExpr ( CastExpr {
1888- span : new_args[ 0 ] . span ( ) ,
1889- is_try : false ,
1890- argument : Box :: new ( result. 0 . clone ( ) ) ,
1891- target_type : Box :: new ( target_type. clone ( ) ) ,
1892- } ) ;
1893-
1894- return Ok ( Box :: new ( ( result_expr, target_type) ) ) ;
1895- }
1896-
18971860 if matches ! (
18981861 self . bind_context. expr_context,
18991862 ExprContext :: InLambdaFunction
@@ -1903,13 +1866,6 @@ impl<'a> TypeChecker<'a> {
19031866 )
19041867 . set_span ( span) ) ;
19051868 }
1906- let params = lambda
1907- . params
1908- . iter ( )
1909- . map ( |param| param. name . to_lowercase ( ) )
1910- . collect :: < Vec < _ > > ( ) ;
1911-
1912- self . check_lambda_param_count ( func_name, params. len ( ) , span) ?;
19131869
19141870 if args. len ( ) != 1 {
19151871 return Err ( ErrorCode :: SemanticError ( format ! (
@@ -1919,7 +1875,46 @@ impl<'a> TypeChecker<'a> {
19191875 ) )
19201876 . set_span ( span) ) ;
19211877 }
1922- let box ( mut arg, arg_type) = self . resolve ( args[ 0 ] ) ?;
1878+ let box ( mut arg, mut arg_type) = self . resolve ( args[ 0 ] ) ?;
1879+
1880+ let mut func_name = func_name;
1881+ let mut is_cast_variant = false ;
1882+ if arg_type. remove_nullable ( ) == DataType :: Variant {
1883+ if func_name. starts_with ( "json_" ) {
1884+ func_name = & func_name[ 5 ..] ;
1885+ }
1886+ // Try auto cast the Variant type to Array(Variant) or Map(String, Variant),
1887+ // so that the lambda functions support variant type as argument.
1888+ let mut target_type = if func_name. starts_with ( "array" ) {
1889+ DataType :: Array ( Box :: new ( DataType :: Nullable ( Box :: new ( DataType :: Variant ) ) ) )
1890+ } else {
1891+ DataType :: Map ( Box :: new ( DataType :: Tuple ( vec ! [
1892+ DataType :: String ,
1893+ DataType :: Nullable ( Box :: new( DataType :: Variant ) ) ,
1894+ ] ) ) )
1895+ } ;
1896+ if arg_type. is_nullable ( ) {
1897+ target_type = target_type. wrap_nullable ( ) ;
1898+ }
1899+
1900+ arg = ScalarExpr :: CastExpr ( CastExpr {
1901+ span : None ,
1902+ is_try : false ,
1903+ argument : Box :: new ( arg. clone ( ) ) ,
1904+ target_type : Box :: new ( target_type. clone ( ) ) ,
1905+ } ) ;
1906+ arg_type = target_type;
1907+
1908+ is_cast_variant = true ;
1909+ }
1910+
1911+ let params = lambda
1912+ . params
1913+ . iter ( )
1914+ . map ( |param| param. name . to_lowercase ( ) )
1915+ . collect :: < Vec < _ > > ( ) ;
1916+
1917+ self . check_lambda_param_count ( func_name, params. len ( ) , span) ?;
19231918
19241919 let inner_ty = match arg_type. remove_nullable ( ) {
19251920 DataType :: Array ( box inner_ty) => inner_ty. clone ( ) ,
@@ -2134,7 +2129,22 @@ impl<'a> TypeChecker<'a> {
21342129 }
21352130 } ;
21362131
2137- Ok ( Box :: new ( ( lambda_func, data_type) ) )
2132+ if is_cast_variant {
2133+ let result_target_type = if data_type. is_nullable ( ) {
2134+ DataType :: Nullable ( Box :: new ( DataType :: Variant ) )
2135+ } else {
2136+ DataType :: Variant
2137+ } ;
2138+ let result_target_scalar = ScalarExpr :: CastExpr ( CastExpr {
2139+ span : None ,
2140+ is_try : false ,
2141+ argument : Box :: new ( lambda_func) ,
2142+ target_type : Box :: new ( result_target_type. clone ( ) ) ,
2143+ } ) ;
2144+ Ok ( Box :: new ( ( result_target_scalar, result_target_type) ) )
2145+ } else {
2146+ Ok ( Box :: new ( ( lambda_func, data_type) ) )
2147+ }
21382148 }
21392149
21402150 fn check_lambda_param_count (
@@ -2768,6 +2778,12 @@ impl<'a> TypeChecker<'a> {
27682778 ) ) ) ;
27692779 }
27702780
2781+ if let Some ( rewritten_func_func) =
2782+ self . try_rewrite_array_function ( span, func_name, & params, & mut args, & mut arg_types)
2783+ {
2784+ return rewritten_func_func;
2785+ }
2786+
27712787 self . resolve_scalar_function_call ( span, func_name, params, args)
27722788 }
27732789
@@ -3641,6 +3657,91 @@ impl<'a> TypeChecker<'a> {
36413657 }
36423658 }
36433659
3660+ fn array_functions ( ) -> & ' static [ Ascii < & ' static str > ] {
3661+ static ARRAY_FUNCTIONS : & [ Ascii < & ' static str > ] = & [
3662+ Ascii :: new ( "array_count" ) ,
3663+ Ascii :: new ( "array_max" ) ,
3664+ Ascii :: new ( "array_min" ) ,
3665+ Ascii :: new ( "array_any" ) ,
3666+ Ascii :: new ( "array_approx_count_distinct" ) ,
3667+ Ascii :: new ( "array_unique" ) ,
3668+ Ascii :: new ( "array_sort_asc_null_first" ) ,
3669+ Ascii :: new ( "array_sort_desc_null_first" ) ,
3670+ Ascii :: new ( "array_sort_asc_null_last" ) ,
3671+ Ascii :: new ( "array_sort_desc_null_last" ) ,
3672+ Ascii :: new ( "array_remove_first" ) ,
3673+ Ascii :: new ( "array_remove_last" ) ,
3674+ Ascii :: new ( "array_distinct" ) ,
3675+ ] ;
3676+ ARRAY_FUNCTIONS
3677+ }
3678+
3679+ fn try_rewrite_array_function (
3680+ & mut self ,
3681+ span : Span ,
3682+ func_name : & str ,
3683+ params : & [ Scalar ] ,
3684+ args : & mut [ ScalarExpr ] ,
3685+ arg_types : & mut [ DataType ] ,
3686+ ) -> Option < Result < Box < ( ScalarExpr , DataType ) > > > {
3687+ // Try auto cast the Variant type to Array(Variant),
3688+ // so that the array functions support Variant type as argument.
3689+ let uni_case_func_name = Ascii :: new ( func_name) ;
3690+ if Self :: array_functions ( ) . contains ( & uni_case_func_name)
3691+ && !arg_types. is_empty ( )
3692+ && arg_types[ 0 ] . remove_nullable ( ) == DataType :: Variant
3693+ {
3694+ let target_type = if arg_types[ 0 ] . is_nullable ( ) {
3695+ DataType :: Nullable ( Box :: new ( DataType :: Array ( Box :: new ( DataType :: Nullable (
3696+ Box :: new ( DataType :: Variant ) ,
3697+ ) ) ) ) )
3698+ } else {
3699+ DataType :: Array ( Box :: new ( DataType :: Nullable ( Box :: new ( DataType :: Variant ) ) ) )
3700+ } ;
3701+ let arg = args[ 0 ] . clone ( ) ;
3702+ args[ 0 ] = ScalarExpr :: CastExpr ( CastExpr {
3703+ span : None ,
3704+ is_try : false ,
3705+ argument : Box :: new ( arg) ,
3706+ target_type : Box :: new ( target_type. clone ( ) ) ,
3707+ } ) ;
3708+ arg_types[ 0 ] = target_type;
3709+
3710+ let result =
3711+ self . resolve_scalar_function_call ( span, func_name, params. to_vec ( ) , args. to_vec ( ) ) ;
3712+ if func_name == "array_remove_first"
3713+ || func_name == "array_remove_last"
3714+ || func_name == "array_distinct"
3715+ || func_name == "array_sort_asc_null_first"
3716+ || func_name == "array_sort_desc_null_first"
3717+ || func_name == "array_sort_asc_null_last"
3718+ || func_name == "array_sort_desc_null_last"
3719+ {
3720+ if result. is_err ( ) {
3721+ return Some ( result) ;
3722+ }
3723+ let box ( result_scalar, result_type) = result. unwrap ( ) ;
3724+
3725+ let result_target_type = if result_type. is_nullable ( ) {
3726+ DataType :: Nullable ( Box :: new ( DataType :: Variant ) )
3727+ } else {
3728+ DataType :: Variant
3729+ } ;
3730+ let result_target_scalar = ScalarExpr :: CastExpr ( CastExpr {
3731+ span : None ,
3732+ is_try : false ,
3733+ argument : Box :: new ( result_scalar) ,
3734+ target_type : Box :: new ( result_target_type. clone ( ) ) ,
3735+ } ) ;
3736+ Some ( Ok ( Box :: new ( ( result_target_scalar, result_target_type) ) ) )
3737+ } else {
3738+ Some ( result)
3739+ }
3740+ } else {
3741+ None
3742+ }
3743+ }
3744+
36443745 fn resolve_trim_function (
36453746 & mut self ,
36463747 span : Span ,
0 commit comments