@@ -16,6 +16,8 @@ use std::borrow::Cow;
1616use std:: sync:: Arc ;
1717
1818use arrow:: array:: { Array , ArrayRef , AsArray , BinaryArray , LargeStringArray , StringArray } ;
19+ use arrow:: compute:: sum;
20+ use arrow:: datatypes:: UInt64Type ;
1921use arrow_schema:: { DataType , Field } ;
2022use datafusion_common:: { Result , ScalarValue } ;
2123use datafusion_expr:: {
@@ -33,7 +35,6 @@ use crate::scalars::vector::impl_conv::{
// Accumulator state for the vector average aggregate.
// `sum` holds the running vector sum (None until a vector has been accumulated);
// `count` is the number of rows that the running sum represents.
3335pub struct VectorAvg {
3436 sum : Option < OVector < f32 , Dyn > > ,
3537 count : u64 ,
// Removed by this diff: the accumulator no longer records that a null input was seen.
36- has_null : bool ,
3738}
3839
3940impl VectorAvg {
@@ -52,7 +53,10 @@ impl VectorAvg {
5253 signature,
5354 DataType :: Binary ,
5455 Arc :: new ( Self :: accumulator) ,
55- vec ! [ Arc :: new( Field :: new( "x" , DataType :: Binary , true ) ) ] ,
56+ vec ! [
57+ Arc :: new( Field :: new( "sum" , DataType :: Binary , true ) ) ,
58+ Arc :: new( Field :: new( "count" , DataType :: UInt64 , true ) ) ,
59+ ] ,
5660 ) ;
5761 AggregateUDF :: from ( udaf)
5862 }
@@ -81,7 +85,7 @@ impl VectorAvg {
8185 }
8286
8387 fn update ( & mut self , values : & [ ArrayRef ] , is_update : bool ) -> Result < ( ) > {
84- if values. is_empty ( ) || self . has_null {
88+ if values. is_empty ( ) {
8589 return Ok ( ( ) ) ;
8690 } ;
8791
@@ -114,16 +118,16 @@ impl VectorAvg {
114118 }
115119 } ;
116120
117- if vectors. len ( ) != values[ 0 ] . len ( ) {
118- if is_update {
119- self . has_null = true ;
120- self . sum = None ;
121- self . count = 0 ;
122- }
121+ if vectors. is_empty ( ) {
123122 return Ok ( ( ) ) ;
124123 }
125124
126- let len = vectors. len ( ) ;
125+ let len = if is_update {
126+ vectors. len ( ) as u64
127+ } else {
128+ sum ( values[ 1 ] . as_primitive :: < UInt64Type > ( ) ) . unwrap_or_default ( )
129+ } ;
130+
127131 let dims = vectors[ 0 ] . len ( ) ;
128132 let mut sum = DVector :: zeros ( dims) ;
129133 for v in vectors {
@@ -132,15 +136,19 @@ impl VectorAvg {
132136 }
133137
134138 * self . inner ( dims) += sum;
135- self . count += len as u64 ;
139+ self . count += len;
136140
137141 Ok ( ( ) )
138142 }
139143}
140144
141145impl Accumulator for VectorAvg {
// Returns the intermediate aggregation state as scalar values.
// After this diff the state has two entries, matching the "sum" (Binary) and
// "count" (UInt64) state fields declared for the UDAF.
142146 fn state ( & mut self ) -> Result < Vec < ScalarValue > > {
// Old body (removed): the state was the single value produced by `evaluate()`.
143- self . evaluate ( ) . map ( |v| vec ! [ v] )
// New body (added): emit the raw running sum rather than the evaluated result.
// A missing sum is encoded as a null Binary; otherwise the sum's slice is
// converted with `veclit_to_binlit`.
147+ let vector = match & self . sum {
148+ None => ScalarValue :: Binary ( None ) ,
149+ Some ( sum) => ScalarValue :: Binary ( Some ( veclit_to_binlit ( sum. as_slice ( ) ) ) ) ,
150+ } ;
// The count travels with the sum; the non-update path of `update` sums the
// second state column as UInt64 to recover it.
151+ Ok ( vec ! [ vector, ScalarValue :: from( self . count) ] )
144152 }
145153
146154 fn update_batch ( & mut self , values : & [ ArrayRef ] ) -> Result < ( ) > {
@@ -181,7 +189,6 @@ mod tests {
181189 let mut vec_avg = VectorAvg :: default ( ) ;
182190 vec_avg. update_batch ( & [ ] ) . unwrap ( ) ;
183191 assert ! ( vec_avg. sum. is_none( ) ) ;
184- assert ! ( !vec_avg. has_null) ;
185192 assert_eq ! ( ScalarValue :: Binary ( None ) , vec_avg. evaluate( ) . unwrap( ) ) ;
186193
187194 // test update one not-null value
@@ -223,7 +230,22 @@ mod tests {
223230 Some ( "[7.0,8.0,9.0]" . to_string( ) ) ,
224231 ] ) ) ] ;
225232 vec_avg. update_batch ( & v) . unwrap ( ) ;
226- assert_eq ! ( ScalarValue :: Binary ( None ) , vec_avg. evaluate( ) . unwrap( ) ) ;
233+ assert_eq ! (
234+ ScalarValue :: Binary ( Some ( veclit_to_binlit( & [ 4.0 , 5.0 , 6.0 ] ) ) ) ,
235+ vec_avg. evaluate( ) . unwrap( )
236+ ) ;
237+
238+ let mut vec_avg = VectorAvg :: default ( ) ;
239+ let v: Vec < ArrayRef > = vec ! [ Arc :: new( StringArray :: from( vec![
240+ None ,
241+ Some ( "[4.0,5.0,6.0]" . to_string( ) ) ,
242+ Some ( "[7.0,8.0,9.0]" . to_string( ) ) ,
243+ ] ) ) ] ;
244+ vec_avg. update_batch ( & v) . unwrap ( ) ;
245+ assert_eq ! (
246+ ScalarValue :: Binary ( Some ( veclit_to_binlit( & [ 5.5 , 6.5 , 7.5 ] ) ) ) ,
247+ vec_avg. evaluate( ) . unwrap( )
248+ ) ;
227249
228250 // test update with constant vector
229251 let mut vec_avg = VectorAvg :: default ( ) ;
0 commit comments