Skip to content

Commit c98f37d

Browse files
committed
refactor: refactor merge batch mode for avg function
Signed-off-by: Alan Tang <[email protected]>
1 parent 59c6836 commit c98f37d

File tree

1 file changed: +36 -14 lines changed
  • src/common/function/src/aggrs/vector

1 file changed: +36 -14 lines changed

src/common/function/src/aggrs/vector/avg.rs

Lines changed: 36 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,8 @@ use std::borrow::Cow;
1616
use std::sync::Arc;
1717

1818
use arrow::array::{Array, ArrayRef, AsArray, BinaryArray, LargeStringArray, StringArray};
19+
use arrow::compute::sum;
20+
use arrow::datatypes::UInt64Type;
1921
use arrow_schema::{DataType, Field};
2022
use datafusion_common::{Result, ScalarValue};
2123
use datafusion_expr::{
@@ -33,7 +35,6 @@ use crate::scalars::vector::impl_conv::{
3335
pub struct VectorAvg {
3436
sum: Option<OVector<f32, Dyn>>,
3537
count: u64,
36-
has_null: bool,
3738
}
3839

3940
impl VectorAvg {
@@ -52,7 +53,10 @@ impl VectorAvg {
5253
signature,
5354
DataType::Binary,
5455
Arc::new(Self::accumulator),
55-
vec![Arc::new(Field::new("x", DataType::Binary, true))],
56+
vec![
57+
Arc::new(Field::new("sum", DataType::Binary, true)),
58+
Arc::new(Field::new("count", DataType::UInt64, true)),
59+
],
5660
);
5761
AggregateUDF::from(udaf)
5862
}
@@ -81,7 +85,7 @@ impl VectorAvg {
8185
}
8286

8387
fn update(&mut self, values: &[ArrayRef], is_update: bool) -> Result<()> {
84-
if values.is_empty() || self.has_null {
88+
if values.is_empty() {
8589
return Ok(());
8690
};
8791

@@ -114,16 +118,16 @@ impl VectorAvg {
114118
}
115119
};
116120

117-
if vectors.len() != values[0].len() {
118-
if is_update {
119-
self.has_null = true;
120-
self.sum = None;
121-
self.count = 0;
122-
}
121+
if vectors.is_empty() {
123122
return Ok(());
124123
}
125124

126-
let len = vectors.len();
125+
let len = if is_update {
126+
vectors.len() as u64
127+
} else {
128+
sum(values[1].as_primitive::<UInt64Type>()).unwrap_or_default()
129+
};
130+
127131
let dims = vectors[0].len();
128132
let mut sum = DVector::zeros(dims);
129133
for v in vectors {
@@ -132,15 +136,19 @@ impl VectorAvg {
132136
}
133137

134138
*self.inner(dims) += sum;
135-
self.count += len as u64;
139+
self.count += len;
136140

137141
Ok(())
138142
}
139143
}
140144

141145
impl Accumulator for VectorAvg {
142146
fn state(&mut self) -> Result<Vec<ScalarValue>> {
143-
self.evaluate().map(|v| vec![v])
147+
let vector = match &self.sum {
148+
None => ScalarValue::Binary(None),
149+
Some(sum) => ScalarValue::Binary(Some(veclit_to_binlit(sum.as_slice()))),
150+
};
151+
Ok(vec![vector, ScalarValue::from(self.count)])
144152
}
145153

146154
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
@@ -181,7 +189,6 @@ mod tests {
181189
let mut vec_avg = VectorAvg::default();
182190
vec_avg.update_batch(&[]).unwrap();
183191
assert!(vec_avg.sum.is_none());
184-
assert!(!vec_avg.has_null);
185192
assert_eq!(ScalarValue::Binary(None), vec_avg.evaluate().unwrap());
186193

187194
// test update one not-null value
@@ -223,7 +230,22 @@ mod tests {
223230
Some("[7.0,8.0,9.0]".to_string()),
224231
]))];
225232
vec_avg.update_batch(&v).unwrap();
226-
assert_eq!(ScalarValue::Binary(None), vec_avg.evaluate().unwrap());
233+
assert_eq!(
234+
ScalarValue::Binary(Some(veclit_to_binlit(&[4.0, 5.0, 6.0]))),
235+
vec_avg.evaluate().unwrap()
236+
);
237+
238+
let mut vec_avg = VectorAvg::default();
239+
let v: Vec<ArrayRef> = vec![Arc::new(StringArray::from(vec![
240+
None,
241+
Some("[4.0,5.0,6.0]".to_string()),
242+
Some("[7.0,8.0,9.0]".to_string()),
243+
]))];
244+
vec_avg.update_batch(&v).unwrap();
245+
assert_eq!(
246+
ScalarValue::Binary(Some(veclit_to_binlit(&[5.5, 6.5, 7.5]))),
247+
vec_avg.evaluate().unwrap()
248+
);
227249

228250
// test update with constant vector
229251
let mut vec_avg = VectorAvg::default();

0 commit comments

Comments
 (0)