From b8885a6cbdd26bd3f8a3b6cd676770b9fcb8954f Mon Sep 17 00:00:00 2001 From: Alan Tang Date: Mon, 27 Oct 2025 14:47:01 +0800 Subject: [PATCH 1/7] feat(expr): support vec_elem_avg function Signed-off-by: Alan Tang --- src/common/function/src/scalars/vector.rs | 2 + .../function/src/scalars/vector/elem_avg.rs | 128 ++++++++++++++++++ .../common/function/vector/vector.result | 32 +++++ .../common/function/vector/vector.sql | 8 ++ 4 files changed, 170 insertions(+) create mode 100644 src/common/function/src/scalars/vector/elem_avg.rs diff --git a/src/common/function/src/scalars/vector.rs b/src/common/function/src/scalars/vector.rs index 75d66f03c53b..f265cfe53ab8 100644 --- a/src/common/function/src/scalars/vector.rs +++ b/src/common/function/src/scalars/vector.rs @@ -14,6 +14,7 @@ mod convert; mod distance; +mod elem_avg; mod elem_product; mod elem_sum; pub mod impl_conv; @@ -64,6 +65,7 @@ impl VectorFunction { registry.register_scalar(vector_subvector::VectorSubvectorFunction::default()); registry.register_scalar(elem_sum::ElemSumFunction::default()); registry.register_scalar(elem_product::ElemProductFunction::default()); + registry.register_scalar(elem_avg::ElemAvgFunction::default()); } } diff --git a/src/common/function/src/scalars/vector/elem_avg.rs b/src/common/function/src/scalars/vector/elem_avg.rs new file mode 100644 index 000000000000..7ebee3ad4128 --- /dev/null +++ b/src/common/function/src/scalars/vector/elem_avg.rs @@ -0,0 +1,128 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::fmt::Display; + +use datafusion::arrow::datatypes::DataType; +use datafusion::logical_expr::ColumnarValue; +use datafusion_common::ScalarValue; +use datafusion_expr::type_coercion::aggregates::{BINARYS, STRINGS}; +use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility}; +use nalgebra::DVectorView; + +use crate::function::Function; +use crate::scalars::vector::{VectorCalculator, impl_conv}; + +const NAME: &str = "vec_elem_avg"; + +#[derive(Debug, Clone)] +pub(crate) struct ElemAvgFunction { + signature: Signature, +} + +impl Default for ElemAvgFunction { + fn default() -> Self { + Self { + signature: Signature::one_of( + vec![ + TypeSignature::Uniform(1, STRINGS.to_vec()), + TypeSignature::Uniform(1, BINARYS.to_vec()), + TypeSignature::Uniform(1, vec![DataType::BinaryView]), + ], + Volatility::Immutable, + ), + } + } +} + +impl Function for ElemAvgFunction { + fn name(&self) -> &str { + NAME + } + + fn return_type(&self, _: &[DataType]) -> datafusion_common::Result { + Ok(DataType::Float32) + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn invoke_with_args( + &self, + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + let body = |v0: &ScalarValue| -> datafusion_common::Result { + let v0 = + impl_conv::as_veclit(v0)?.map(|v0| DVectorView::from_slice(&v0, v0.len()).mean()); + Ok(ScalarValue::Float32(v0)) + }; + + let calculator = VectorCalculator { + name: self.name(), + func: body, + }; + calculator.invoke_with_single_argument(args) + } +} + +impl Display for ElemAvgFunction { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", NAME.to_ascii_uppercase()) + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use arrow::array::StringViewArray; + use arrow_schema::Field; + use datafusion::arrow::array::{Array, AsArray}; + use datafusion::arrow::datatypes::Float32Type; + use datafusion_common::config::ConfigOptions; + + use super::*; + + #[test] + fn test_elem_avg() { + let func = ElemAvgFunction::default(); + + let input = Arc::new(StringViewArray::from(vec![ + Some("[1.0,2.0,3.0]".to_string()), + Some("[4.0,5.0,6.0]".to_string()), + Some("[7.0,8.0,9.0]".to_string()), + None, + ])); + + let result = func + .invoke_with_args(ScalarFunctionArgs { + args: vec![ColumnarValue::Array(input.clone())], + arg_fields: vec![], + number_rows: input.len(), + return_field: Arc::new(Field::new("x", DataType::Float32, true)), + config_options: Arc::new(ConfigOptions::new()), + }) + .and_then(|v| ColumnarValue::values_to_arrays(&[v])) + .map(|mut a| a.remove(0)) + .unwrap(); + let result = result.as_primitive::(); + + assert_eq!(result.len(), 4); + assert_eq!(result.value(0), 2.0); + assert_eq!(result.value(1), 5.0); + assert_eq!(result.value(2), 8.0); + assert!(result.is_null(3)); + } +} diff --git a/tests/cases/standalone/common/function/vector/vector.result b/tests/cases/standalone/common/function/vector/vector.result index 57a37d638d08..006467cab1f5 100644 --- a/tests/cases/standalone/common/function/vector/vector.result +++ b/tests/cases/standalone/common/function/vector/vector.result @@ -150,6 +150,38 @@ SELECT vec_elem_sum(parse_vec('[-1.0, -2.0, -3.0]')); | -6.0 | +-----------------------------------------------------+ +SELECT vec_elem_avg('[1.0, 2.0, 3.0]'); + ++---------------------------------------+ +| vec_elem_avg(Utf8("[1.0, 2.0, 3.0]")) | ++---------------------------------------+ +| 2.0 | ++---------------------------------------+ + +SELECT vec_elem_avg('[-1.0, -2.0, -3.0]'); + ++------------------------------------------+ +| vec_elem_avg(Utf8("[-1.0, -2.0, -3.0]")) | ++------------------------------------------+ +| -2.0 | ++------------------------------------------+ + +SELECT vec_elem_avg(parse_vec('[1.0, 2.0, 3.0]')); + ++--------------------------------------------------+ +| vec_elem_avg(parse_vec(Utf8("[1.0, 2.0, 3.0]"))) | ++--------------------------------------------------+ +| 2.0 | ++--------------------------------------------------+ + +SELECT vec_elem_avg(parse_vec('[-1.0, -2.0, -3.0]')); + ++-----------------------------------------------------+ +| vec_elem_avg(parse_vec(Utf8("[-1.0, -2.0, -3.0]"))) | ++-----------------------------------------------------+ +| -2.0 | ++-----------------------------------------------------+ + SELECT vec_to_string(vec_div('[1.0, 2.0]', '[3.0, 4.0]')); +---------------------------------------------------------------+ diff --git a/tests/cases/standalone/common/function/vector/vector.sql b/tests/cases/standalone/common/function/vector/vector.sql index c441fc14807c..917665be0636 100644 --- a/tests/cases/standalone/common/function/vector/vector.sql +++ b/tests/cases/standalone/common/function/vector/vector.sql @@ -36,6 +36,14 @@ SELECT vec_elem_sum(parse_vec('[1.0, 2.0, 3.0]')); SELECT vec_elem_sum(parse_vec('[-1.0, -2.0, -3.0]')); +SELECT vec_elem_avg('[1.0, 2.0, 3.0]'); + +SELECT vec_elem_avg('[-1.0, -2.0, -3.0]'); + +SELECT vec_elem_avg(parse_vec('[1.0, 2.0, 3.0]')); + +SELECT vec_elem_avg(parse_vec('[-1.0, -2.0, -3.0]')); + SELECT vec_to_string(vec_div('[1.0, 2.0]', '[3.0, 4.0]')); SELECT vec_to_string(vec_div(parse_vec('[1.0, 2.0]'), '[3.0, 4.0]')); From 8ef0532f64e847cd78c6b4ed4601a4c288f06ff6 Mon Sep 17 00:00:00 2001 From: Alan Tang Date: Mon, 27 Oct 2025 17:24:06 +0800 Subject: [PATCH 2/7] feat: support vec_avg function Signed-off-by: Alan Tang --- src/common/function/src/aggrs/vector.rs | 3 + src/common/function/src/aggrs/vector/avg.rs | 238 ++++++++++++++++++ .../common/function/vector/vector.result | 15 ++ .../common/function/vector/vector.sql | 9 + 4 files changed, 265 insertions(+) create mode 100644 src/common/function/src/aggrs/vector/avg.rs diff --git a/src/common/function/src/aggrs/vector.rs b/src/common/function/src/aggrs/vector.rs index 5af064d002fe..03489a51d446 100644 --- a/src/common/function/src/aggrs/vector.rs +++ b/src/common/function/src/aggrs/vector.rs @@ -12,10 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. +use crate::aggrs::vector::avg::VectorAvg; use crate::aggrs::vector::product::VectorProduct; use crate::aggrs::vector::sum::VectorSum; use crate::function_registry::FunctionRegistry; +mod avg; mod product; mod sum; @@ -25,5 +27,6 @@ impl VectorFunction { pub fn register(registry: &FunctionRegistry) { registry.register_aggr(VectorSum::uadf_impl()); registry.register_aggr(VectorProduct::uadf_impl()); + registry.register_aggr(VectorAvg::uadf_impl()); } } diff --git a/src/common/function/src/aggrs/vector/avg.rs b/src/common/function/src/aggrs/vector/avg.rs new file mode 100644 index 000000000000..2c87c8c5d4b3 --- /dev/null +++ b/src/common/function/src/aggrs/vector/avg.rs @@ -0,0 +1,238 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::borrow::Cow; +use std::sync::Arc; + +use arrow::array::{Array, ArrayRef, AsArray, BinaryArray, LargeStringArray, StringArray}; +use arrow_schema::{DataType, Field}; +use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::{ + Accumulator, AggregateUDF, Signature, SimpleAggregateUDF, TypeSignature, Volatility, +}; +use datafusion_functions_aggregate_common::accumulator::AccumulatorArgs; +use nalgebra::{Const, DVector, DVectorView, Dyn, OVector}; + +use crate::scalars::vector::impl_conv::{ + binlit_as_veclit, parse_veclit_from_strlit, veclit_to_binlit, +}; + +/// The accumulator for the `vec_avg` aggregate function. +#[derive(Debug, Default)] +pub struct VectorAvg { + avg: Option>, + has_null: bool, +} + +impl VectorAvg { + /// Create a new `AggregateUDF` for the `vec_avg` aggregate function. + pub fn uadf_impl() -> AggregateUDF { + let signature = Signature::one_of( + vec![ + TypeSignature::Exact(vec![DataType::Utf8]), + TypeSignature::Exact(vec![DataType::Binary]), + ], + Volatility::Immutable, + ); + let udaf = SimpleAggregateUDF::new_with_signature( + "vec_avg", + signature, + DataType::Binary, + Arc::new(Self::accumulator), + vec![Arc::new(Field::new("x", DataType::Binary, true))], + ); + AggregateUDF::from(udaf) + } + + fn accumulator(args: AccumulatorArgs) -> Result> { + if args.schema.fields().len() != 1 { + return Err(datafusion_common::DataFusionError::Internal(format!( + "expect creating `VEC_AVG` with only one input field, actual {}", + args.schema.fields().len() + ))); + } + + let t = args.schema.field(0).data_type(); + if !matches!(t, DataType::Utf8 | DataType::LargeUtf8 | DataType::Binary) { + return Err(datafusion_common::DataFusionError::Internal(format!( + "unexpected input datatype {t} when creating `VEC_AVG`" + ))); + } + + Ok(Box::new(VectorAvg::default())) + } + + fn inner(&mut self, len: usize) -> &mut OVector { + self.avg + .get_or_insert_with(|| OVector::zeros_generic(Dyn(len), Const::<1>)) + } + + fn update(&mut self, values: &[ArrayRef], is_update: bool) -> Result<()> { + if values.is_empty() || self.has_null { + return Ok(()); + }; + + let vectors = match values[0].data_type() { + DataType::Utf8 => { + let arr: &StringArray = values[0].as_string(); + arr.iter() + .filter_map(|x| x.map(|s| parse_veclit_from_strlit(s).map_err(Into::into))) + .map(|x| x.map(Cow::Owned)) + .collect::>>()? + } + DataType::LargeUtf8 => { + let arr: &LargeStringArray = values[0].as_string(); + arr.iter() + .filter_map(|x| x.map(|s| parse_veclit_from_strlit(s).map_err(Into::into))) + .map(|x: Result>| x.map(Cow::Owned)) + .collect::>>()? + } + DataType::Binary => { + let arr: &BinaryArray = values[0].as_binary(); + arr.iter() + .filter_map(|x| x.map(|b| binlit_as_veclit(b).map_err(Into::into))) + .collect::>>()? + } + _ => { + return Err(datafusion_common::DataFusionError::NotImplemented(format!( + "unsupported data type {} for `VEC_AVG`", + values[0].data_type() + ))); + } + }; + + if vectors.len() != values[0].len() { + if is_update { + self.has_null = true; + self.avg = None; + } + return Ok(()); + } + + let len = vectors.len(); + let dims = vectors[0].len(); + let mut sum = DVector::zeros(dims); + for v in vectors { + let v_view = DVectorView::from_slice(&v, dims); + sum += &v_view; + } + *self.inner(dims) = sum / (len as f32); + + Ok(()) + } +} + +impl Accumulator for VectorAvg { + fn state(&mut self) -> Result> { + self.evaluate().map(|v| vec![v]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + self.update(values, true) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.update(states, false) + } + + fn evaluate(&mut self) -> Result { + match &self.avg { + None => Ok(ScalarValue::Binary(None)), + Some(vector) => Ok(ScalarValue::Binary(Some(veclit_to_binlit( + vector.as_slice(), + )))), + } + } + + fn size(&self) -> usize { + size_of_val(self) + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use arrow::array::StringArray; + use datatypes::scalars::ScalarVector; + use datatypes::vectors::{ConstantVector, StringVector, Vector}; + + use super::*; + + #[test] + fn test_update_batch() { + // test update empty batch, expect not updating anything + let mut vec_avg = VectorAvg::default(); + vec_avg.update_batch(&[]).unwrap(); + assert!(vec_avg.avg.is_none()); + assert!(!vec_avg.has_null); + assert_eq!(ScalarValue::Binary(None), vec_avg.evaluate().unwrap()); + + // test update one not-null value + let mut vec_avg = VectorAvg::default(); + let v: Vec = vec![Arc::new(StringArray::from(vec![ + Some("[1.0,2.0,3.0]".to_string()), + Some("[4.0,5.0,6.0]".to_string()), + ]))]; + vec_avg.update_batch(&v).unwrap(); + assert_eq!( + ScalarValue::Binary(Some(veclit_to_binlit(&[2.5, 3.5, 4.5]))), + vec_avg.evaluate().unwrap() + ); + + // test update one null value + let mut vec_avg = VectorAvg::default(); + let v: Vec = vec![Arc::new(StringArray::from(vec![Option::::None]))]; + vec_avg.update_batch(&v).unwrap(); + assert_eq!(ScalarValue::Binary(None), vec_avg.evaluate().unwrap()); + + // test update no null-value batch + let mut vec_avg = VectorAvg::default(); + let v: Vec = vec![Arc::new(StringArray::from(vec![ + Some("[1.0,2.0,3.0]".to_string()), + Some("[4.0,5.0,6.0]".to_string()), + Some("[7.0,8.0,9.0]".to_string()), + ]))]; + vec_avg.update_batch(&v).unwrap(); + assert_eq!( + ScalarValue::Binary(Some(veclit_to_binlit(&[4.0, 5.0, 6.0]))), + vec_avg.evaluate().unwrap() + ); + + // test update null-value batch + let mut vec_avg = VectorAvg::default(); + let v: Vec = vec![Arc::new(StringArray::from(vec![ + Some("[1.0,2.0,3.0]".to_string()), + None, + Some("[7.0,8.0,9.0]".to_string()), + ]))]; + vec_avg.update_batch(&v).unwrap(); + assert_eq!(ScalarValue::Binary(None), vec_avg.evaluate().unwrap()); + + // test update with constant vector + let mut vec_avg = VectorAvg::default(); + let v: Vec = vec![ + Arc::new(ConstantVector::new( + Arc::new(StringVector::from_vec(vec!["[1.0,2.0,3.0]".to_string()])), + 4, + )) + .to_arrow_array(), + ]; + vec_avg.update_batch(&v).unwrap(); + assert_eq!( + ScalarValue::Binary(Some(veclit_to_binlit(&[1.0, 2.0, 3.0]))), + vec_avg.evaluate().unwrap() + ); + } +} diff --git a/tests/cases/standalone/common/function/vector/vector.result b/tests/cases/standalone/common/function/vector/vector.result index 006467cab1f5..c546f9ad25eb 100644 --- a/tests/cases/standalone/common/function/vector/vector.result +++ b/tests/cases/standalone/common/function/vector/vector.result @@ -301,6 +301,21 @@ FROM ( | [4,5,6] | +---------------------------+ +SELECT vec_to_string(vec_avg(v)) +FROM ( + SELECT '[1.0, 2.0, 3.0]' AS v + UNION ALL + SELECT '[10.0, 11.0, 12.0]' AS v + UNION ALL + SELECT '[4.0, 5.0, 6.0]' AS v +); + ++---------------------------+ +| vec_to_string(vec_avg(v)) | ++---------------------------+ +| [5,6,7] | ++---------------------------+ + SELECT vec_to_string(vec_product(v)) FROM ( SELECT '[1.0, 2.0, 3.0]' AS v diff --git a/tests/cases/standalone/common/function/vector/vector.sql b/tests/cases/standalone/common/function/vector/vector.sql index 917665be0636..9bbf1583f535 100644 --- a/tests/cases/standalone/common/function/vector/vector.sql +++ b/tests/cases/standalone/common/function/vector/vector.sql @@ -79,6 +79,15 @@ FROM ( SELECT '[4.0, 5.0, 6.0]' AS v ); +SELECT vec_to_string(vec_avg(v)) +FROM ( + SELECT '[1.0, 2.0, 3.0]' AS v + UNION ALL + SELECT '[10.0, 11.0, 12.0]' AS v + UNION ALL + SELECT '[4.0, 5.0, 6.0]' AS v +); + SELECT vec_to_string(vec_product(v)) FROM ( SELECT '[1.0, 2.0, 3.0]' AS v From bb77db0aef30f3c84acd75757298f5cbb050aef7 Mon Sep 17 00:00:00 2001 From: Alan Tang Date: Tue, 28 Oct 2025 15:21:24 +0800 Subject: [PATCH 3/7] test: add more query test for avg aggregator Signed-off-by: Alan Tang --- src/query/src/tests.rs | 1 + src/query/src/tests/vec_avg_test.rs | 60 +++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+) create mode 100644 src/query/src/tests/vec_avg_test.rs diff --git a/src/query/src/tests.rs b/src/query/src/tests.rs index c70381d32fc5..4b12464b7307 100644 --- a/src/query/src/tests.rs +++ b/src/query/src/tests.rs @@ -26,6 +26,7 @@ mod query_engine_test; mod time_range_filter_test; mod function; +mod vec_avg_test; mod vec_product_test; mod vec_sum_test; diff --git a/src/query/src/tests/vec_avg_test.rs b/src/query/src/tests/vec_avg_test.rs new file mode 100644 index 000000000000..46bb3528a9bf --- /dev/null +++ b/src/query/src/tests/vec_avg_test.rs @@ -0,0 +1,60 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::ops::AddAssign; + +use common_function::scalars::vector::impl_conv::{as_veclit, veclit_to_binlit}; +use datafusion_common::ScalarValue; +use datatypes::prelude::Value; +use nalgebra::{Const, DVectorView, Dyn, OVector}; + +use crate::tests::{exec_selection, function}; + +#[tokio::test] +async fn test_vec_avg_aggregator() -> Result<(), common_query::error::Error> { + common_telemetry::init_default_ut_logging(); + let engine = function::create_query_engine_for_vector10x3(); + let sql = "select VEC_AVG(vector) as vec_avg from vectors"; + let result = exec_selection(engine.clone(), sql).await; + let value = function::get_value_from_batches("vec_avg", result); + + let mut expected_value = None; + + let sql = "SELECT vector FROM vectors"; + let vectors = exec_selection(engine, sql).await; + + let column = vectors[0].column(0).to_arrow_array(); + let len = column.len(); + for i in 0..column.len() { + let v = ScalarValue::try_from_array(&column, i)?; + let vector = as_veclit(&v)?; + let Some(vector) = vector else { + expected_value = None; + break; + }; + expected_value + .get_or_insert_with(|| OVector::zeros_generic(Dyn(3), Const::<1>)) + .add_assign(&DVectorView::from_slice(&vector, vector.len())); + } + let expected_value = match expected_value.map(|mut v| { + v /= len as f32; + veclit_to_binlit(v.as_slice()) + }) { + None => Value::Null, + Some(bytes) => Value::from(bytes), + }; + assert_eq!(value, expected_value); + + Ok(()) +} From 0ce6c8c3943df5ab841c610c2cee81c12612da24 Mon Sep 17 00:00:00 2001 From: Alan Tang Date: Wed, 29 Oct 2025 15:10:04 +0800 Subject: [PATCH 4/7] fix: fix the merge batch mode Signed-off-by: Alan Tang --- src/common/function/src/aggrs/vector/avg.rs | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/common/function/src/aggrs/vector/avg.rs b/src/common/function/src/aggrs/vector/avg.rs index 2c87c8c5d4b3..691253d3a90d 100644 --- a/src/common/function/src/aggrs/vector/avg.rs +++ b/src/common/function/src/aggrs/vector/avg.rs @@ -32,6 +32,7 @@ use crate::scalars::vector::impl_conv::{ #[derive(Debug, Default)] pub struct VectorAvg { avg: Option>, + count: usize, has_null: bool, } @@ -116,6 +117,7 @@ impl VectorAvg { if is_update { self.has_null = true; self.avg = None; + self.count = 0; } return Ok(()); } @@ -127,7 +129,13 @@ impl VectorAvg { let v_view = DVectorView::from_slice(&v, dims); sum += &v_view; } - *self.inner(dims) = sum / (len as f32); + if is_update { + *self.inner(dims) = sum / (len as f32); + } else { + let avg = self.inner(dims).clone(); + *self.inner(dims) = (avg * self.count as f32 + sum) / ((self.count + len) as f32) + } + self.count += len; Ok(()) } From 01f4b5bc891deb84ededda67df1bb174db0fc890 Mon Sep 17 00:00:00 2001 From: Alan Tang Date: Fri, 31 Oct 2025 15:37:45 +0800 Subject: [PATCH 5/7] refactor: use sum and count as state for avg function Signed-off-by: Alan Tang --- src/common/function/src/aggrs/vector/avg.rs | 27 +++++++++------------ 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/src/common/function/src/aggrs/vector/avg.rs b/src/common/function/src/aggrs/vector/avg.rs index 691253d3a90d..021af8cf2342 100644 --- a/src/common/function/src/aggrs/vector/avg.rs +++ b/src/common/function/src/aggrs/vector/avg.rs @@ -31,8 +31,8 @@ use crate::scalars::vector::impl_conv::{ /// The accumulator for the `vec_avg` aggregate function. #[derive(Debug, Default)] pub struct VectorAvg { - avg: Option>, - count: usize, + sum: Option>, + count: u64, has_null: bool, } @@ -42,6 +42,7 @@ impl VectorAvg { let signature = Signature::one_of( vec![ TypeSignature::Exact(vec![DataType::Utf8]), + TypeSignature::Exact(vec![DataType::LargeUtf8]), TypeSignature::Exact(vec![DataType::Binary]), ], Volatility::Immutable, @@ -75,7 +76,7 @@ impl VectorAvg { } fn inner(&mut self, len: usize) -> &mut OVector { - self.avg + self.sum .get_or_insert_with(|| OVector::zeros_generic(Dyn(len), Const::<1>)) } @@ -116,7 +117,7 @@ impl VectorAvg { if vectors.len() != values[0].len() { if is_update { self.has_null = true; - self.avg = None; + self.sum = None; self.count = 0; } return Ok(()); @@ -129,13 +130,9 @@ impl VectorAvg { let v_view = DVectorView::from_slice(&v, dims); sum += &v_view; } - if is_update { - *self.inner(dims) = sum / (len as f32); - } else { - let avg = self.inner(dims).clone(); - *self.inner(dims) = (avg * self.count as f32 + sum) / ((self.count + len) as f32) - } - self.count += len; + + *self.inner(dims) += sum; + self.count += len as u64; Ok(()) } @@ -155,10 +152,10 @@ impl Accumulator for VectorAvg { } fn evaluate(&mut self) -> Result { - match &self.avg { + match &self.sum { None => Ok(ScalarValue::Binary(None)), - Some(vector) => Ok(ScalarValue::Binary(Some(veclit_to_binlit( - vector.as_slice(), + Some(sum) => Ok(ScalarValue::Binary(Some(veclit_to_binlit( + (sum / self.count as f32).as_slice(), )))), } } @@ -183,7 +180,7 @@ mod tests { // test update empty batch, expect not updating anything let mut vec_avg = VectorAvg::default(); vec_avg.update_batch(&[]).unwrap(); - assert!(vec_avg.avg.is_none()); + assert!(vec_avg.sum.is_none()); assert!(!vec_avg.has_null); assert_eq!(ScalarValue::Binary(None), vec_avg.evaluate().unwrap()); From 02ba91d0eea35314e5d09dcf23a513569fcd0646 Mon Sep 17 00:00:00 2001 From: Alan Tang Date: Sat, 1 Nov 2025 10:02:02 +0800 Subject: [PATCH 6/7] refactor: refactor merge batch mode for avg function Signed-off-by: Alan Tang --- src/common/function/src/aggrs/vector/avg.rs | 50 +++++++++++++++------ 1 file changed, 36 insertions(+), 14 deletions(-) diff --git a/src/common/function/src/aggrs/vector/avg.rs b/src/common/function/src/aggrs/vector/avg.rs index 021af8cf2342..c7573a3d8617 100644 --- a/src/common/function/src/aggrs/vector/avg.rs +++ b/src/common/function/src/aggrs/vector/avg.rs @@ -16,6 +16,8 @@ use std::borrow::Cow; use std::sync::Arc; use arrow::array::{Array, ArrayRef, AsArray, BinaryArray, LargeStringArray, StringArray}; +use arrow::compute::sum; +use arrow::datatypes::UInt64Type; use arrow_schema::{DataType, Field}; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ @@ -33,7 +35,6 @@ use crate::scalars::vector::impl_conv::{ pub struct VectorAvg { sum: Option>, count: u64, - has_null: bool, } impl VectorAvg { @@ -52,7 +53,10 @@ impl VectorAvg { signature, DataType::Binary, Arc::new(Self::accumulator), - vec![Arc::new(Field::new("x", DataType::Binary, true))], + vec![ + Arc::new(Field::new("sum", DataType::Binary, true)), + Arc::new(Field::new("count", DataType::UInt64, true)), + ], ); AggregateUDF::from(udaf) } @@ -81,7 +85,7 @@ impl VectorAvg { } fn update(&mut self, values: &[ArrayRef], is_update: bool) -> Result<()> { - if values.is_empty() || self.has_null { + if values.is_empty() { return Ok(()); }; @@ -114,16 +118,16 @@ impl VectorAvg { } }; - if vectors.len() != values[0].len() { - if is_update { - self.has_null = true; - self.sum = None; - self.count = 0; - } + if vectors.is_empty() { return Ok(()); } - let len = vectors.len(); + let len = if is_update { + vectors.len() as u64 + } else { + sum(values[1].as_primitive::()).unwrap_or_default() + }; + let dims = vectors[0].len(); let mut sum = DVector::zeros(dims); for v in vectors { @@ -132,7 +136,7 @@ impl VectorAvg { } *self.inner(dims) += sum; - self.count += len as u64; + self.count += len; Ok(()) } @@ -140,7 +144,11 @@ impl VectorAvg { impl Accumulator for VectorAvg { fn state(&mut self) -> Result> { - self.evaluate().map(|v| vec![v]) + let vector = match &self.sum { + None => ScalarValue::Binary(None), + Some(sum) => ScalarValue::Binary(Some(veclit_to_binlit(sum.as_slice()))), + }; + Ok(vec![vector, ScalarValue::from(self.count)]) } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { @@ -181,7 +189,6 @@ mod tests { let mut vec_avg = VectorAvg::default(); vec_avg.update_batch(&[]).unwrap(); assert!(vec_avg.sum.is_none()); - assert!(!vec_avg.has_null); assert_eq!(ScalarValue::Binary(None), vec_avg.evaluate().unwrap()); // test update one not-null value @@ -223,7 +230,22 @@ mod tests { Some("[7.0,8.0,9.0]".to_string()), ]))]; vec_avg.update_batch(&v).unwrap(); - assert_eq!(ScalarValue::Binary(None), vec_avg.evaluate().unwrap()); + assert_eq!( + ScalarValue::Binary(Some(veclit_to_binlit(&[4.0, 5.0, 6.0]))), + vec_avg.evaluate().unwrap() + ); + + let mut vec_avg = VectorAvg::default(); + let v: Vec = vec![Arc::new(StringArray::from(vec![ + None, + Some("[4.0,5.0,6.0]".to_string()), + Some("[7.0,8.0,9.0]".to_string()), + ]))]; + vec_avg.update_batch(&v).unwrap(); + assert_eq!( + ScalarValue::Binary(Some(veclit_to_binlit(&[5.5, 6.5, 7.5]))), + vec_avg.evaluate().unwrap() + ); // test update with constant vector let mut vec_avg = VectorAvg::default(); From d1b70804911584227fccb98cf522a39a87ff2c07 Mon Sep 17 00:00:00 2001 From: Alan Tang Date: Thu, 6 Nov 2025 09:56:45 +0800 Subject: [PATCH 7/7] feat: add additional vector restrictions for validation Signed-off-by: Alan Tang --- src/common/function/src/aggrs/vector/avg.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/common/function/src/aggrs/vector/avg.rs b/src/common/function/src/aggrs/vector/avg.rs index c7573a3d8617..ddf1823d2832 100644 --- a/src/common/function/src/aggrs/vector/avg.rs +++ b/src/common/function/src/aggrs/vector/avg.rs @@ -131,6 +131,11 @@ impl VectorAvg { let dims = vectors[0].len(); let mut sum = DVector::zeros(dims); for v in vectors { + if v.len() != dims { + return Err(datafusion_common::DataFusionError::Execution( + "vectors length not match: VEC_AVG".to_string(), + )); + } let v_view = DVectorView::from_slice(&v, dims); sum += &v_view; }