Skip to content

Commit 910a383

Browse files
feat(expr): support avg functions on vector (#7146)
* feat(expr): support vec_elem_avg function Signed-off-by: Alan Tang <[email protected]> * feat: support vec_avg function Signed-off-by: Alan Tang <[email protected]> * test: add more query test for avg aggregator Signed-off-by: Alan Tang <[email protected]> * fix: fix the merge batch mode Signed-off-by: Alan Tang <[email protected]> * refactor: use sum and count as state for avg function Signed-off-by: Alan Tang <[email protected]> * refactor: refactor merge batch mode for avg function Signed-off-by: Alan Tang <[email protected]> * feat: add additional vector restrictions for validation Signed-off-by: Alan Tang <[email protected]> --------- Signed-off-by: Alan Tang <[email protected]> Co-authored-by: Yingwen <[email protected]>
1 parent af6bbac commit 910a383

File tree

8 files changed

+528
-0
lines changed

8 files changed

+528
-0
lines changed

src/common/function/src/aggrs/vector.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,12 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
use crate::aggrs::vector::avg::VectorAvg;
1516
use crate::aggrs::vector::product::VectorProduct;
1617
use crate::aggrs::vector::sum::VectorSum;
1718
use crate::function_registry::FunctionRegistry;
1819

20+
mod avg;
1921
mod product;
2022
mod sum;
2123

@@ -25,5 +27,6 @@ impl VectorFunction {
2527
pub fn register(registry: &FunctionRegistry) {
2628
registry.register_aggr(VectorSum::uadf_impl());
2729
registry.register_aggr(VectorProduct::uadf_impl());
30+
registry.register_aggr(VectorAvg::uadf_impl());
2831
}
2932
}
Lines changed: 270 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,270 @@
1+
// Copyright 2023 Greptime Team
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
use std::borrow::Cow;
16+
use std::sync::Arc;
17+
18+
use arrow::array::{Array, ArrayRef, AsArray, BinaryArray, LargeStringArray, StringArray};
19+
use arrow::compute::sum;
20+
use arrow::datatypes::UInt64Type;
21+
use arrow_schema::{DataType, Field};
22+
use datafusion_common::{Result, ScalarValue};
23+
use datafusion_expr::{
24+
Accumulator, AggregateUDF, Signature, SimpleAggregateUDF, TypeSignature, Volatility,
25+
};
26+
use datafusion_functions_aggregate_common::accumulator::AccumulatorArgs;
27+
use nalgebra::{Const, DVector, DVectorView, Dyn, OVector};
28+
29+
use crate::scalars::vector::impl_conv::{
30+
binlit_as_veclit, parse_veclit_from_strlit, veclit_to_binlit,
31+
};
32+
33+
/// The accumulator for the `vec_avg` aggregate function.
34+
#[derive(Debug, Default)]
35+
pub struct VectorAvg {
36+
sum: Option<OVector<f32, Dyn>>,
37+
count: u64,
38+
}
39+
40+
impl VectorAvg {
41+
/// Create a new `AggregateUDF` for the `vec_avg` aggregate function.
42+
pub fn uadf_impl() -> AggregateUDF {
43+
let signature = Signature::one_of(
44+
vec![
45+
TypeSignature::Exact(vec![DataType::Utf8]),
46+
TypeSignature::Exact(vec![DataType::LargeUtf8]),
47+
TypeSignature::Exact(vec![DataType::Binary]),
48+
],
49+
Volatility::Immutable,
50+
);
51+
let udaf = SimpleAggregateUDF::new_with_signature(
52+
"vec_avg",
53+
signature,
54+
DataType::Binary,
55+
Arc::new(Self::accumulator),
56+
vec![
57+
Arc::new(Field::new("sum", DataType::Binary, true)),
58+
Arc::new(Field::new("count", DataType::UInt64, true)),
59+
],
60+
);
61+
AggregateUDF::from(udaf)
62+
}
63+
64+
fn accumulator(args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
65+
if args.schema.fields().len() != 1 {
66+
return Err(datafusion_common::DataFusionError::Internal(format!(
67+
"expect creating `VEC_AVG` with only one input field, actual {}",
68+
args.schema.fields().len()
69+
)));
70+
}
71+
72+
let t = args.schema.field(0).data_type();
73+
if !matches!(t, DataType::Utf8 | DataType::LargeUtf8 | DataType::Binary) {
74+
return Err(datafusion_common::DataFusionError::Internal(format!(
75+
"unexpected input datatype {t} when creating `VEC_AVG`"
76+
)));
77+
}
78+
79+
Ok(Box::new(VectorAvg::default()))
80+
}
81+
82+
fn inner(&mut self, len: usize) -> &mut OVector<f32, Dyn> {
83+
self.sum
84+
.get_or_insert_with(|| OVector::zeros_generic(Dyn(len), Const::<1>))
85+
}
86+
87+
fn update(&mut self, values: &[ArrayRef], is_update: bool) -> Result<()> {
88+
if values.is_empty() {
89+
return Ok(());
90+
};
91+
92+
let vectors = match values[0].data_type() {
93+
DataType::Utf8 => {
94+
let arr: &StringArray = values[0].as_string();
95+
arr.iter()
96+
.filter_map(|x| x.map(|s| parse_veclit_from_strlit(s).map_err(Into::into)))
97+
.map(|x| x.map(Cow::Owned))
98+
.collect::<Result<Vec<_>>>()?
99+
}
100+
DataType::LargeUtf8 => {
101+
let arr: &LargeStringArray = values[0].as_string();
102+
arr.iter()
103+
.filter_map(|x| x.map(|s| parse_veclit_from_strlit(s).map_err(Into::into)))
104+
.map(|x: Result<Vec<f32>>| x.map(Cow::Owned))
105+
.collect::<Result<Vec<_>>>()?
106+
}
107+
DataType::Binary => {
108+
let arr: &BinaryArray = values[0].as_binary();
109+
arr.iter()
110+
.filter_map(|x| x.map(|b| binlit_as_veclit(b).map_err(Into::into)))
111+
.collect::<Result<Vec<_>>>()?
112+
}
113+
_ => {
114+
return Err(datafusion_common::DataFusionError::NotImplemented(format!(
115+
"unsupported data type {} for `VEC_AVG`",
116+
values[0].data_type()
117+
)));
118+
}
119+
};
120+
121+
if vectors.is_empty() {
122+
return Ok(());
123+
}
124+
125+
let len = if is_update {
126+
vectors.len() as u64
127+
} else {
128+
sum(values[1].as_primitive::<UInt64Type>()).unwrap_or_default()
129+
};
130+
131+
let dims = vectors[0].len();
132+
let mut sum = DVector::zeros(dims);
133+
for v in vectors {
134+
if v.len() != dims {
135+
return Err(datafusion_common::DataFusionError::Execution(
136+
"vectors length not match: VEC_AVG".to_string(),
137+
));
138+
}
139+
let v_view = DVectorView::from_slice(&v, dims);
140+
sum += &v_view;
141+
}
142+
143+
*self.inner(dims) += sum;
144+
self.count += len;
145+
146+
Ok(())
147+
}
148+
}
149+
150+
impl Accumulator for VectorAvg {
151+
fn state(&mut self) -> Result<Vec<ScalarValue>> {
152+
let vector = match &self.sum {
153+
None => ScalarValue::Binary(None),
154+
Some(sum) => ScalarValue::Binary(Some(veclit_to_binlit(sum.as_slice()))),
155+
};
156+
Ok(vec![vector, ScalarValue::from(self.count)])
157+
}
158+
159+
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
160+
self.update(values, true)
161+
}
162+
163+
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
164+
self.update(states, false)
165+
}
166+
167+
fn evaluate(&mut self) -> Result<ScalarValue> {
168+
match &self.sum {
169+
None => Ok(ScalarValue::Binary(None)),
170+
Some(sum) => Ok(ScalarValue::Binary(Some(veclit_to_binlit(
171+
(sum / self.count as f32).as_slice(),
172+
)))),
173+
}
174+
}
175+
176+
fn size(&self) -> usize {
177+
size_of_val(self)
178+
}
179+
}
180+
181+
#[cfg(test)]
182+
mod tests {
183+
use std::sync::Arc;
184+
185+
use arrow::array::StringArray;
186+
use datatypes::scalars::ScalarVector;
187+
use datatypes::vectors::{ConstantVector, StringVector, Vector};
188+
189+
use super::*;
190+
191+
#[test]
192+
fn test_update_batch() {
193+
// test update empty batch, expect not updating anything
194+
let mut vec_avg = VectorAvg::default();
195+
vec_avg.update_batch(&[]).unwrap();
196+
assert!(vec_avg.sum.is_none());
197+
assert_eq!(ScalarValue::Binary(None), vec_avg.evaluate().unwrap());
198+
199+
// test update one not-null value
200+
let mut vec_avg = VectorAvg::default();
201+
let v: Vec<ArrayRef> = vec![Arc::new(StringArray::from(vec![
202+
Some("[1.0,2.0,3.0]".to_string()),
203+
Some("[4.0,5.0,6.0]".to_string()),
204+
]))];
205+
vec_avg.update_batch(&v).unwrap();
206+
assert_eq!(
207+
ScalarValue::Binary(Some(veclit_to_binlit(&[2.5, 3.5, 4.5]))),
208+
vec_avg.evaluate().unwrap()
209+
);
210+
211+
// test update one null value
212+
let mut vec_avg = VectorAvg::default();
213+
let v: Vec<ArrayRef> = vec![Arc::new(StringArray::from(vec![Option::<String>::None]))];
214+
vec_avg.update_batch(&v).unwrap();
215+
assert_eq!(ScalarValue::Binary(None), vec_avg.evaluate().unwrap());
216+
217+
// test update no null-value batch
218+
let mut vec_avg = VectorAvg::default();
219+
let v: Vec<ArrayRef> = vec![Arc::new(StringArray::from(vec![
220+
Some("[1.0,2.0,3.0]".to_string()),
221+
Some("[4.0,5.0,6.0]".to_string()),
222+
Some("[7.0,8.0,9.0]".to_string()),
223+
]))];
224+
vec_avg.update_batch(&v).unwrap();
225+
assert_eq!(
226+
ScalarValue::Binary(Some(veclit_to_binlit(&[4.0, 5.0, 6.0]))),
227+
vec_avg.evaluate().unwrap()
228+
);
229+
230+
// test update null-value batch
231+
let mut vec_avg = VectorAvg::default();
232+
let v: Vec<ArrayRef> = vec![Arc::new(StringArray::from(vec![
233+
Some("[1.0,2.0,3.0]".to_string()),
234+
None,
235+
Some("[7.0,8.0,9.0]".to_string()),
236+
]))];
237+
vec_avg.update_batch(&v).unwrap();
238+
assert_eq!(
239+
ScalarValue::Binary(Some(veclit_to_binlit(&[4.0, 5.0, 6.0]))),
240+
vec_avg.evaluate().unwrap()
241+
);
242+
243+
let mut vec_avg = VectorAvg::default();
244+
let v: Vec<ArrayRef> = vec![Arc::new(StringArray::from(vec![
245+
None,
246+
Some("[4.0,5.0,6.0]".to_string()),
247+
Some("[7.0,8.0,9.0]".to_string()),
248+
]))];
249+
vec_avg.update_batch(&v).unwrap();
250+
assert_eq!(
251+
ScalarValue::Binary(Some(veclit_to_binlit(&[5.5, 6.5, 7.5]))),
252+
vec_avg.evaluate().unwrap()
253+
);
254+
255+
// test update with constant vector
256+
let mut vec_avg = VectorAvg::default();
257+
let v: Vec<ArrayRef> = vec![
258+
Arc::new(ConstantVector::new(
259+
Arc::new(StringVector::from_vec(vec!["[1.0,2.0,3.0]".to_string()])),
260+
4,
261+
))
262+
.to_arrow_array(),
263+
];
264+
vec_avg.update_batch(&v).unwrap();
265+
assert_eq!(
266+
ScalarValue::Binary(Some(veclit_to_binlit(&[1.0, 2.0, 3.0]))),
267+
vec_avg.evaluate().unwrap()
268+
);
269+
}
270+
}

src/common/function/src/scalars/vector.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
mod convert;
1616
mod distance;
17+
mod elem_avg;
1718
mod elem_product;
1819
mod elem_sum;
1920
pub mod impl_conv;
@@ -64,6 +65,7 @@ impl VectorFunction {
6465
registry.register_scalar(vector_subvector::VectorSubvectorFunction::default());
6566
registry.register_scalar(elem_sum::ElemSumFunction::default());
6667
registry.register_scalar(elem_product::ElemProductFunction::default());
68+
registry.register_scalar(elem_avg::ElemAvgFunction::default());
6769
}
6870
}
6971

0 commit comments

Comments
 (0)