diff --git a/Cargo.lock b/Cargo.lock index f74eab595aa12..2efa3164395d9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3274,6 +3274,7 @@ dependencies = [ "num-bigint", "num-traits", "pretty_assertions", + "proptest", "rand", "recursive", "roaring", diff --git a/src/query/expression/Cargo.toml b/src/query/expression/Cargo.toml index 8446592befff4..40597e22e7805 100644 --- a/src/query/expression/Cargo.toml +++ b/src/query/expression/Cargo.toml @@ -69,6 +69,7 @@ arrow-ord = { workspace = true } criterion = { workspace = true } goldenfile = { workspace = true } pretty_assertions = { workspace = true } +proptest = { workspace = true } rand = { workspace = true } [[bench]] diff --git a/src/query/expression/src/aggregate/aggregate_function.rs b/src/query/expression/src/aggregate/aggregate_function.rs index 95da2a33f2fd0..98e7dd2d864d0 100755 --- a/src/query/expression/src/aggregate/aggregate_function.rs +++ b/src/query/expression/src/aggregate/aggregate_function.rs @@ -12,15 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::alloc::Layout; use std::fmt; use std::sync::Arc; use databend_common_column::bitmap::Bitmap; use databend_common_exception::Result; +use super::AggrState; +use super::AggrStateLoc; +use super::AggrStateRegistry; use super::StateAddr; -use crate::types::binary::BinaryColumnBuilder; +use crate::types::BinaryColumn; use crate::types::DataType; use crate::Column; use crate::ColumnBuilder; @@ -35,76 +37,61 @@ pub trait AggregateFunction: fmt::Display + Sync + Send { fn name(&self) -> &str; fn return_type(&self) -> Result; - fn init_state(&self, place: StateAddr); + fn init_state(&self, place: AggrState); - fn is_state(&self) -> bool { - false - } - - fn state_layout(&self) -> Layout; + fn register_state(&self, registry: &mut AggrStateRegistry); // accumulate is to accumulate the arrays in batch mode // common used when there is no group by for aggregate function fn accumulate( &self, - _place: StateAddr, - _columns: InputColumns, - _validity: Option<&Bitmap>, - _input_rows: usize, + place: AggrState, + columns: InputColumns, + validity: Option<&Bitmap>, + input_rows: usize, ) -> Result<()>; // used when we need to calculate with group keys fn accumulate_keys( &self, - places: &[StateAddr], - offset: usize, + addrs: &[StateAddr], + loc: &[AggrStateLoc], columns: InputColumns, _input_rows: usize, ) -> Result<()> { - for (row, place) in places.iter().enumerate() { - self.accumulate_row(place.next(offset), columns, row)?; + for (row, addr) in addrs.iter().enumerate() { + self.accumulate_row(AggrState::new(*addr, loc), columns, row)?; } Ok(()) } // Used in aggregate_null_adaptor - fn accumulate_row(&self, _place: StateAddr, _columns: InputColumns, _row: usize) -> Result<()>; + fn accumulate_row(&self, place: AggrState, columns: InputColumns, row: usize) -> Result<()>; - // serialize the state into binary array - fn batch_serialize( - &self, - places: &[StateAddr], - offset: usize, - builder: &mut BinaryColumnBuilder, - ) -> Result<()> { - for place in places { - self.serialize(place.next(offset), &mut builder.data)?; - builder.commit_row(); - } - Ok(()) - } - - fn serialize(&self, _place: StateAddr, _writer: &mut Vec) -> Result<()>; + fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()>; fn serialize_size_per_row(&self) -> Option { None } - fn merge(&self, _place: StateAddr, _reader: &mut &[u8]) -> Result<()>; + fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()>; /// Batch merge and deserialize the state from binary array - fn batch_merge(&self, places: &[StateAddr], offset: usize, column: &Column) -> Result<()> { - let c = column.as_binary().unwrap(); - for (place, mut data) in places.iter().zip(c.iter()) { - self.merge(place.next(offset), &mut data)?; + fn batch_merge( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + state: &BinaryColumn, + ) -> Result<()> { + for (place, mut data) in places.iter().zip(state.iter()) { + self.merge(AggrState::new(*place, loc), &mut data)?; } Ok(()) } - fn batch_merge_single(&self, place: StateAddr, column: &Column) -> Result<()> { - let c = column.as_binary().unwrap(); - + fn batch_merge_single(&self, place: AggrState, state: &Column) -> Result<()> { + let c = state.as_binary().unwrap(); for mut data in c.iter() { self.merge(place, &mut data)?; } @@ -115,29 +102,29 @@ pub trait AggregateFunction: fmt::Display + Sync + Send { &self, places: &[StateAddr], rhses: &[StateAddr], - offset: usize, + loc: &[AggrStateLoc], ) -> Result<()> { for (place, rhs) in places.iter().zip(rhses.iter()) { - self.merge_states(place.next(offset), rhs.next(offset))?; + self.merge_states(AggrState::new(*place, loc), AggrState::new(*rhs, loc))?; } Ok(()) } - fn merge_states(&self, _place: StateAddr, _rhs: StateAddr) -> Result<()>; + fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()>; fn batch_merge_result( &self, places: &[StateAddr], - offset: usize, + loc: Box<[AggrStateLoc]>, builder: &mut ColumnBuilder, ) -> Result<()> { for place in places { - self.merge_result(place.next(offset), builder)?; + self.merge_result(AggrState::new(*place, &loc), builder)?; } Ok(()) } - // TODO append the value into the column builder - fn merge_result(&self, _place: StateAddr, _builder: &mut ColumnBuilder) -> Result<()>; + + fn merge_result(&self, place: AggrState, builder: &mut ColumnBuilder) -> Result<()>; // std::mem::needs_drop:: // if true will call drop_state @@ -147,7 +134,7 @@ pub trait AggregateFunction: fmt::Display + Sync + Send { /// # Safety /// The caller must ensure that the [`_place`] has defined memory. - unsafe fn drop_state(&self, _place: StateAddr) {} + unsafe fn drop_state(&self, _place: AggrState) {} fn get_own_null_adaptor( &self, diff --git a/src/query/expression/src/aggregate/aggregate_function_state.rs b/src/query/expression/src/aggregate/aggregate_function_state.rs index 782ed6fd835c6..5d1fe21ac4379 100644 --- a/src/query/expression/src/aggregate/aggregate_function_state.rs +++ b/src/query/expression/src/aggregate/aggregate_function_state.rs @@ -15,10 +15,11 @@ use std::alloc::Layout; use std::ptr::NonNull; -use databend_common_exception::ErrorCode; use databend_common_exception::Result; +use enum_as_inner::EnumAsInner; use super::AggregateFunctionRef; +use crate::types::binary::BinaryColumnBuilder; #[derive(Clone, Copy, Debug)] pub struct StateAddr { @@ -110,22 +111,241 @@ impl From for usize { } } -pub fn get_layout_offsets( - funcs: &[AggregateFunctionRef], - offsets: &mut Vec, -) -> Result { +pub fn get_states_layout(funcs: &[AggregateFunctionRef]) -> Result { + let mut registry = AggrStateRegistry::default(); + let mut serialize_size = Vec::with_capacity(funcs.len()); + for func in funcs { + func.register_state(&mut registry); + registry.commit(); + serialize_size.push(func.serialize_size_per_row()); + } + + let AggrStateRegistry { states, offsets } = registry; + + let (layout, locs) = sort_states(states); + + let states_loc = offsets + .windows(2) + .map(|w| locs[w[0]..w[1]].to_vec().into_boxed_slice()) + .collect::>(); + + Ok(StatesLayout { + layout, + states_loc, + serialize_size, + }) +} + +fn sort_states(states: Vec) -> (Layout, Vec) { + let mut states = states + .iter() + .enumerate() + .map(|(idx, state)| { + let layout = match state { + AggrStateType::Bool => (1, 1), + AggrStateType::Custom(layout) => (layout.align(), layout.pad_to_align().size()), + }; + (idx, state, layout) + }) + .collect::>(); + + states.sort_by_key(|(_, _, (align, _))| std::cmp::Reverse(*align)); + + let mut locs = vec![AggrStateLoc::Bool(0, 0); states.len()]; + let mut acc = 0; let mut max_align = 0; - let mut total_size: usize = 0; + for (idx, state, (align, size)) in states { + max_align = max_align.max(align); + let offset = acc; + acc += size; + locs[idx] = match state { + AggrStateType::Bool => AggrStateLoc::Bool(idx, offset), + AggrStateType::Custom(_) => AggrStateLoc::Custom(idx, offset), + }; + } - for func in funcs { - let layout = func.state_layout(); + let layout = Layout::from_size_align(acc, max_align).unwrap(); + + (layout, locs) +} + +#[derive(Debug, Clone, Copy, EnumAsInner)] +pub enum AggrStateLoc { + Bool(usize, usize), // index, offset + Custom(usize, usize), // index, offset +} + +impl AggrStateLoc { + pub fn offset(&self) -> usize { + match self { + AggrStateLoc::Bool(_, offset) => *offset, + AggrStateLoc::Custom(_, offset) => *offset, + } + } + + pub fn index(&self) -> usize { + match self { + AggrStateLoc::Bool(idx, _) => *idx, + AggrStateLoc::Custom(idx, _) => *idx, + } + } +} + +#[derive(Debug, Clone)] +pub struct StatesLayout { + pub layout: Layout, + pub states_loc: Vec>, + serialize_size: Vec>, +} + +impl StatesLayout { + pub fn serialize_builders(&self, num_rows: usize) -> Vec { + self.serialize_size + .iter() + .map(|size| BinaryColumnBuilder::with_capacity(num_rows, num_rows * size.unwrap_or(0))) + .collect() + } +} + +#[derive(Debug, Clone, Copy)] +pub struct AggrState<'a> { + pub addr: StateAddr, + pub loc: &'a [AggrStateLoc], +} + +impl<'a> AggrState<'a> { + pub fn new(addr: StateAddr, loc: &'a [AggrStateLoc]) -> Self { + Self { addr, loc } + } + + pub fn get<'b, T>(&self) -> &'b mut T { + debug_assert_eq!(self.loc.len(), 1); + self.addr + .next(self.loc[0].into_custom().unwrap().1) + .get::() + } + + pub fn write(&self, f: F) + where F: FnOnce() -> T { + debug_assert_eq!(self.loc.len(), 1); + self.addr + .next(self.loc[0].into_custom().unwrap().1) + .write(f); + } + + pub fn remove_last_loc(&self) -> Self { + debug_assert!(self.loc.len() >= 2); + Self { + addr: self.addr, + loc: &self.loc[..self.loc.len() - 1], + } + } + + pub fn remove_first_loc(&self) -> Self { + debug_assert!(self.loc.len() >= 2); + Self { + addr: self.addr, + loc: &self.loc[1..], + } + } +} + +pub struct AggrStateRegistry { + states: Vec, + offsets: Vec, +} + +impl AggrStateRegistry { + pub fn new() -> Self { + Self { + states: vec![], + offsets: vec![0], + } + } + + pub fn register(&mut self, state: AggrStateType) { + self.states.push(state); + } + + pub fn commit(&mut self) { + self.offsets.push(self.states.len()); + } + + pub fn states(&self) -> &[AggrStateType] { + &self.states + } +} + +impl Default for AggrStateRegistry { + fn default() -> Self { + Self::new() + } +} + +#[derive(Debug, Clone, Copy)] +pub enum AggrStateType { + Bool, + Custom(Layout), +} + +#[cfg(test)] +mod tests { + use proptest::prelude::*; + use proptest::strategy::ValueTree; + use proptest::test_runner::TestRunner; + + use super::*; + + prop_compose! { + fn arb_state_type()(size in 1..100_usize, align in 0..5_u8) -> AggrStateType { + let layout = Layout::from_size_align(size, 1 << align).unwrap(); + AggrStateType::Custom(layout) + } + } + + #[test] + fn test_sort_states() { + let mut runner = TestRunner::default(); + let input_s = prop::collection::vec(arb_state_type(), 1..20); + + for _ in 0..100 { + let input = input_s.new_tree(&mut runner).unwrap().current(); + run_sort_states(input); + } + } + + fn check_offset(layout: &Layout, offset: usize) -> bool { let align = layout.align(); + offset & (align - 1) == 0 + } - total_size = total_size.div_ceil(align) * align; - offsets.push(total_size); - max_align = max_align.max(align); - total_size += layout.size(); + fn run_sort_states(input: Vec) { + let (layout, locs) = sort_states(input.clone()); + + let is_aligned = input + .iter() + .zip(locs.iter()) + .all(|(state, loc)| match state { + AggrStateType::Custom(layout) => check_offset(layout, loc.offset()), + _ => unreachable!(), + }); + + assert!(is_aligned, "states are not aligned, input: {input:?}"); + + let size = layout.size(); + let mut memory = vec![false; size]; + for (state, loc) in input.iter().zip(locs.iter()) { + match state { + AggrStateType::Custom(layout) => { + let start = loc.offset(); + let end = start + layout.size(); + for i in start..end { + assert!(!memory[i], "layout is overlap, input: {input:?}"); + memory[i] = true; + } + } + _ => unreachable!(), + } + } } - Layout::from_size_align(total_size, max_align) - .map_err(|e| ErrorCode::LayoutError(format!("Layout error: {}", e))) } diff --git a/src/query/expression/src/aggregate/aggregate_hashtable.rs b/src/query/expression/src/aggregate/aggregate_hashtable.rs index 07d403518e0b1..3fae88f9f636b 100644 --- a/src/query/expression/src/aggregate/aggregate_hashtable.rs +++ b/src/query/expression/src/aggregate/aggregate_hashtable.rs @@ -199,26 +199,26 @@ impl AggregateHashTable { } let state_places = &state.state_places.as_slice()[0..row_count]; - + let states_layout = self.payload.states_layout.as_ref().unwrap(); if agg_states.is_empty() { - for ((aggr, params), addr_offset) in self + for ((func, params), loc) in self .payload .aggrs .iter() .zip(params.iter()) - .zip(self.payload.state_addr_offsets.iter()) + .zip(states_layout.states_loc.iter()) { - aggr.accumulate_keys(state_places, *addr_offset, *params, row_count)?; + func.accumulate_keys(state_places, loc, *params, row_count)?; } } else { - for ((aggr, agg_state), addr_offset) in self + for ((func, state), loc) in self .payload .aggrs .iter() .zip(agg_states.iter()) - .zip(self.payload.state_addr_offsets.iter()) + .zip(states_layout.states_loc.iter()) { - aggr.batch_merge(state_places, *addr_offset, agg_state)?; + func.batch_merge(state_places, loc, state.as_binary().unwrap())?; } } } @@ -412,13 +412,10 @@ impl AggregateHashTable { let state = &mut flush_state.probe_state; let places = &state.state_places.as_slice()[0..row_count]; let rhses = &flush_state.state_places.as_slice()[0..row_count]; - for (aggr, addr_offset) in self - .payload - .aggrs - .iter() - .zip(self.payload.state_addr_offsets.iter()) - { - aggr.batch_merge_states(places, rhses, *addr_offset)?; + if let Some(layout) = self.payload.states_layout.as_ref() { + for (aggr, loc) in self.payload.aggrs.iter().zip(layout.states_loc.iter()) { + aggr.batch_merge_states(places, rhses, loc)?; + } } } @@ -426,29 +423,31 @@ impl AggregateHashTable { } pub fn merge_result(&mut self, flush_state: &mut PayloadFlushState) -> Result { - if self.payload.flush(flush_state) { - let row_count = flush_state.row_count; + if !self.payload.flush(flush_state) { + return Ok(false); + } - flush_state.aggregate_results.clear(); - for (aggr, addr_offset) in self + let row_count = flush_state.row_count; + flush_state.aggregate_results.clear(); + if let Some(states_layout) = self.payload.states_layout.as_ref() { + for (aggr, loc) in self .payload .aggrs .iter() - .zip(self.payload.state_addr_offsets.iter()) + .zip(states_layout.states_loc.iter().cloned()) { let return_type = aggr.return_type()?; let mut builder = ColumnBuilder::with_capacity(&return_type, row_count * 4); aggr.batch_merge_result( &flush_state.state_places.as_slice()[0..row_count], - *addr_offset, + loc, &mut builder, )?; flush_state.aggregate_results.push(builder.build()); } - return Ok(true); } - Ok(false) + Ok(true) } fn maybe_repartition(&mut self) -> bool { diff --git a/src/query/expression/src/aggregate/partitioned_payload.rs b/src/query/expression/src/aggregate/partitioned_payload.rs index 2eca5510d948e..5b27d6939f330 100644 --- a/src/query/expression/src/aggregate/partitioned_payload.rs +++ b/src/query/expression/src/aggregate/partitioned_payload.rs @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::alloc::Layout; use std::sync::Arc; use bumpalo::Bump; @@ -20,11 +19,13 @@ use itertools::Itertools; use super::payload::Payload; use super::probe_state::ProbeState; +use crate::get_states_layout; use crate::read; use crate::types::DataType; use crate::AggregateFunctionRef; use crate::InputColumns; use crate::PayloadFlushState; +use crate::StatesLayout; use crate::BATCH_SIZE; pub struct PartitionedPayload { @@ -37,8 +38,7 @@ pub struct PartitionedPayload { pub validity_offsets: Vec, pub hash_offset: usize, pub state_offset: usize, - pub state_addr_offsets: Vec, - pub state_layout: Option, + pub states_layout: Option, pub arenas: Vec>, @@ -60,8 +60,21 @@ impl PartitionedPayload { let radix_bits = partition_count.trailing_zeros() as u64; debug_assert_eq!(1 << radix_bits, partition_count); + let states_layout = if !aggrs.is_empty() { + Some(get_states_layout(&aggrs).unwrap()) + } else { + None + }; + let payloads = (0..partition_count) - .map(|_| Payload::new(arenas[0].clone(), group_types.clone(), aggrs.clone())) + .map(|_| { + Payload::new( + arenas[0].clone(), + group_types.clone(), + aggrs.clone(), + states_layout.clone(), + ) + }) .collect_vec(); let group_sizes = payloads[0].group_sizes.clone(); @@ -69,8 +82,6 @@ impl PartitionedPayload { let validity_offsets = payloads[0].validity_offsets.clone(); let hash_offset = payloads[0].hash_offset; let state_offset = payloads[0].state_offset; - let state_addr_offsets = payloads[0].state_addr_offsets.clone(); - let state_layout = payloads[0].state_layout; PartitionedPayload { payloads, @@ -81,8 +92,7 @@ impl PartitionedPayload { validity_offsets, hash_offset, state_offset, - state_addr_offsets, - state_layout, + states_layout, partition_count, arenas, diff --git a/src/query/expression/src/aggregate/payload.rs b/src/query/expression/src/aggregate/payload.rs index ad76f6dfe5927..788f187ed9699 100644 --- a/src/query/expression/src/aggregate/payload.rs +++ b/src/query/expression/src/aggregate/payload.rs @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::alloc::Layout; use std::mem::MaybeUninit; use std::sync::Arc; @@ -24,10 +23,10 @@ use strength_reduce::StrengthReducedU64; use super::payload_row::rowformat_size; use super::payload_row::serialize_column_to_rowformat; -use crate::get_layout_offsets; use crate::read; use crate::store; use crate::types::DataType; +use crate::AggrState; use crate::AggregateFunctionRef; use crate::Column; use crate::ColumnBuilder; @@ -36,6 +35,7 @@ use crate::InputColumns; use crate::PayloadFlushState; use crate::SelectVector; use crate::StateAddr; +use crate::StatesLayout; use crate::BATCH_SIZE; use crate::MAX_PAGE_SIZE; @@ -66,8 +66,7 @@ pub struct Payload { pub validity_offsets: Vec, pub hash_offset: usize, pub state_offset: usize, - pub state_addr_offsets: Vec, - pub state_layout: Option, + pub states_layout: Option, // if set, the payload contains at least duplicate rows pub min_cardinality: Option, @@ -93,20 +92,13 @@ impl Page { pub type Pages = Vec; -// TODO FIXME impl Payload { pub fn new( arena: Arc, group_types: Vec, aggrs: Vec, + states_layout: Option, ) -> Self { - let mut state_addr_offsets = Vec::new(); - let state_layout = if !aggrs.is_empty() { - Some(get_layout_offsets(&aggrs, &mut state_addr_offsets).unwrap()) - } else { - None - }; - let mut tuple_size = 0; let mut validity_offsets = Vec::with_capacity(group_types.len()); for x in group_types.iter() { @@ -156,8 +148,7 @@ impl Payload { validity_offsets, hash_offset, state_offset, - state_addr_offsets, - state_layout, + states_layout, } } @@ -299,10 +290,21 @@ impl Payload { write_offset += 8; debug_assert!(write_offset == self.state_offset); - if let Some(layout) = self.state_layout { + if let Some(StatesLayout { + layout, states_loc, .. + }) = &self.states_layout + { // write states - for idx in select_vector.iter().take(new_group_rows).copied() { - let place = self.arena.alloc_layout(layout); + let (array_layout, padded_size) = layout.repeat(new_group_rows).unwrap(); + // Bump only allocates but does not drop, so there is no use after free for any item. + let place = self.arena.alloc_layout(array_layout); + for (idx, place) in select_vector + .iter() + .take(new_group_rows) + .copied() + .enumerate() + .map(|(i, idx)| (idx, unsafe { place.add(padded_size * i) })) + { unsafe { let dst = address[idx].add(write_offset); store::(&(place.as_ptr() as u64), dst as *mut u8); @@ -310,8 +312,8 @@ impl Payload { let place = StateAddr::from(place); let page = &mut self.pages[page_index[idx]]; - for (aggr, offset) in self.aggrs.iter().zip(self.state_addr_offsets.iter()) { - aggr.init_state(place.next(*offset)); + for (aggr, loc) in self.aggrs.iter().zip(states_loc.iter()) { + aggr.init_state(AggrState::new(place, loc)); page.state_offsets += 1; } } @@ -438,35 +440,44 @@ impl Drop for Payload { fn drop(&mut self) { drop_guard(move || { // drop states - if !self.state_move_out { - 'FOR: for (idx, (aggr, addr_offset)) in self - .aggrs - .iter() - .zip(self.state_addr_offsets.iter()) - .enumerate() - { - if aggr.need_manual_drop_state() { - for page in self.pages.iter() { - let is_partial_state = page.is_partial_state(self.aggrs.len()); - - if is_partial_state && idx == 0 { - info!("Cleaning partial page, state_offsets: {}, row: {}, agg length: {}", page.state_offsets, page.rows, self.aggrs.len()); - } - for row in 0..page.state_offsets.div_ceil(self.aggrs.len()) { - // When OOM, some states are not initialized, we don't need to destroy them - if is_partial_state - && row * self.aggrs.len() + idx >= page.state_offsets - { - continue 'FOR; - } - let ptr = self.data_ptr(page, row); - unsafe { - let state_addr = - read::(ptr.add(self.state_offset) as _) as usize; - let state_place = StateAddr::new(state_addr); - aggr.drop_state(state_place.next(*addr_offset)); - } - } + if self.state_move_out { + return; + } + + let Some(states_layout) = self.states_layout.as_ref() else { + return; + }; + + 'FOR: for (idx, (aggr, loc)) in self + .aggrs + .iter() + .zip(states_layout.states_loc.iter()) + .enumerate() + { + if !aggr.need_manual_drop_state() { + continue; + } + + for page in self.pages.iter() { + let is_partial_state = page.is_partial_state(self.aggrs.len()); + + if is_partial_state && idx == 0 { + info!( + "Cleaning partial page, state_offsets: {}, row: {}, agg length: {}", + page.state_offsets, + page.rows, + self.aggrs.len() + ); + } + for row in 0..page.state_offsets.div_ceil(self.aggrs.len()) { + // When OOM, some states are not initialized, we don't need to destroy them + if is_partial_state && row * self.aggrs.len() + idx >= page.state_offsets { + continue 'FOR; + } + let ptr = self.data_ptr(page, row); + unsafe { + let state_addr = read::(ptr.add(self.state_offset) as _) as usize; + aggr.drop_state(AggrState::new(StateAddr::new(state_addr), loc)); } } } diff --git a/src/query/expression/src/aggregate/payload_flush.rs b/src/query/expression/src/aggregate/payload_flush.rs index e4b4735394c17..4fe9f35830227 100644 --- a/src/query/expression/src/aggregate/payload_flush.rs +++ b/src/query/expression/src/aggregate/payload_flush.rs @@ -19,6 +19,7 @@ use ethnum::i256; use super::partitioned_payload::PartitionedPayload; use super::payload::Payload; use super::probe_state::ProbeState; +use super::AggrState; use crate::read; use crate::types::binary::BinaryColumn; use crate::types::binary::BinaryColumnBuilder; @@ -127,40 +128,40 @@ impl Payload { } pub fn aggregate_flush(&self, state: &mut PayloadFlushState) -> Result> { - if self.flush(state) { - let row_count = state.row_count; + if !self.flush(state) { + return Ok(None); + } + + let row_count = state.row_count; - let mut state_builders: Vec = self - .aggrs - .iter() - .map(|_| BinaryColumnBuilder::with_capacity(row_count, row_count * 4)) - .collect(); + let mut cols = Vec::with_capacity(self.aggrs.len() + self.group_types.len()); + if let Some(state_layout) = self.states_layout.as_ref() { + let mut builders = state_layout.serialize_builders(row_count); for place in state.state_places.as_slice()[0..row_count].iter() { - for (idx, (addr_offset, aggr)) in self - .state_addr_offsets + for (idx, (loc, func)) in state_layout + .states_loc .iter() .zip(self.aggrs.iter()) .enumerate() { - let arg_place = place.next(*addr_offset); - aggr.serialize(arg_place, &mut state_builders[idx].data) - .unwrap(); - state_builders[idx].commit_row(); + { + let builder = &mut builders[idx]; + func.serialize(AggrState::new(*place, loc), &mut builder.data)?; + builder.commit_row(); + } } } - let mut cols = Vec::with_capacity(self.aggrs.len() + self.group_types.len()); - for builder in state_builders.into_iter() { - let col = Column::Binary(builder.build()); - cols.push(col); - } - - cols.extend_from_slice(&state.take_group_columns()); - return Ok(Some(DataBlock::new_from_columns(cols))); + cols.extend( + builders + .into_iter() + .map(|builder| Column::Binary(builder.build())), + ); } - Ok(None) + cols.extend_from_slice(&state.take_group_columns()); + Ok(Some(DataBlock::new_from_columns(cols))) } pub fn group_by_flush_all(&self) -> Result { diff --git a/src/query/expression/src/lib.rs b/src/query/expression/src/lib.rs index 2b561a42c9ee3..759b73536f43c 100755 --- a/src/query/expression/src/lib.rs +++ b/src/query/expression/src/lib.rs @@ -39,6 +39,7 @@ #![feature(try_blocks)] #![feature(let_chains)] #![feature(trait_upcasting)] +#![feature(alloc_layout_extra)] #[allow(dead_code)] mod block; diff --git a/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs b/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs index c065904c94e1b..af5a9c48b0c3b 100644 --- a/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs +++ b/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs @@ -12,16 +12,28 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::fmt; +use std::sync::Arc; + use databend_common_exception::Result; +use databend_common_expression::types::Bitmap; use databend_common_expression::types::DataType; use databend_common_expression::types::NumberDataType; +use databend_common_expression::utils::column_merge_validity; +use databend_common_expression::ColumnBuilder; +use databend_common_expression::InputColumns; use databend_common_expression::Scalar; +use databend_common_io::prelude::BinaryWrite; -use super::aggregate_null_variadic_adaptor::AggregateNullVariadicAdaptor; -use super::AggregateNullUnaryAdaptor; -use crate::aggregates::aggregate_function_factory::AggregateFunctionFeatures; -use crate::aggregates::aggregate_null_result::AggregateNullResultFunction; -use crate::aggregates::AggregateFunctionRef; +use super::AggrState; +use super::AggrStateLoc; +use super::AggrStateRegistry; +use super::AggrStateType; +use super::AggregateFunction; +use super::AggregateFunctionFeatures; +use super::AggregateFunctionRef; +use super::AggregateNullResultFunction; +use super::StateAddr; #[derive(Clone)] pub struct AggregateFunctionCombinatorNull {} @@ -55,7 +67,7 @@ impl AggregateFunctionCombinatorNull { properties: AggregateFunctionFeatures, ) -> Result { // has_null_types - if !arguments.is_empty() && arguments.iter().any(|f| f == &DataType::Null) { + if arguments.iter().any(|f| f == &DataType::Null) { if properties.returns_default_when_only_null { return AggregateNullResultFunction::try_create(DataType::Number( NumberDataType::UInt64, @@ -92,3 +104,483 @@ impl AggregateFunctionCombinatorNull { } } } + +#[derive(Clone)] +pub struct AggregateNullUnaryAdaptor( + CommonNullAdaptor, +); + +impl AggregateNullUnaryAdaptor { + pub fn create(nested: AggregateFunctionRef) -> AggregateFunctionRef { + Arc::new(Self(CommonNullAdaptor:: { nested })) + } +} + +impl AggregateFunction for AggregateNullUnaryAdaptor { + fn name(&self) -> &str { + "AggregateNullUnaryAdaptor" + } + + fn return_type(&self) -> Result { + self.0.return_type() + } + + #[inline] + fn init_state(&self, place: AggrState) { + self.0.init_state(place); + } + + fn serialize_size_per_row(&self) -> Option { + self.0.serialize_size_per_row() + } + + fn register_state(&self, registry: &mut AggrStateRegistry) { + self.0.register_state(registry); + } + + #[inline] + fn accumulate( + &self, + place: AggrState, + columns: InputColumns, + validity: Option<&Bitmap>, + input_rows: usize, + ) -> Result<()> { + let col = &columns[0]; + let validity = column_merge_validity(col, validity.cloned()); + let not_null_column = &[col.remove_nullable()]; + let not_null_column = not_null_column.into(); + let validity = Bitmap::map_all_sets_to_none(validity); + + self.0 + .accumulate(place, not_null_column, validity, input_rows) + } + + #[inline] + fn accumulate_keys( + &self, + addrs: &[StateAddr], + loc: &[AggrStateLoc], + columns: InputColumns, + input_rows: usize, + ) -> Result<()> { + let col = &columns[0]; + let validity = column_merge_validity(col, None); + let not_null_columns = &[col.remove_nullable()]; + let not_null_columns = not_null_columns.into(); + + self.0 + .accumulate_keys(addrs, loc, not_null_columns, validity, input_rows) + } + + fn accumulate_row(&self, place: AggrState, columns: InputColumns, row: usize) -> Result<()> { + let col = &columns[0]; + let validity = column_merge_validity(col, None); + let not_null_columns = &[col.remove_nullable()]; + let not_null_columns = not_null_columns.into(); + + self.0 + .accumulate_row(place, not_null_columns, validity, row) + } + + fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { + self.0.serialize(place, writer) + } + + fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { + self.0.merge(place, reader) + } + + fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { + self.0.merge_states(place, rhs) + } + + fn merge_result(&self, place: AggrState, builder: &mut ColumnBuilder) -> Result<()> { + self.0.merge_result(place, builder) + } + + fn need_manual_drop_state(&self) -> bool { + self.0.nested.need_manual_drop_state() + } + + unsafe fn drop_state(&self, place: AggrState) { + self.0.drop_state(place); + } + + fn convert_const_to_full(&self) -> bool { + self.0.nested.convert_const_to_full() + } + + fn get_if_condition(&self, columns: InputColumns) -> Option { + self.0.nested.get_if_condition(columns) + } +} + +impl fmt::Display for AggregateNullUnaryAdaptor { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "AggregateNullUnaryAdaptor") + } +} + +#[derive(Clone)] +pub struct AggregateNullVariadicAdaptor( + CommonNullAdaptor, +); + +impl AggregateNullVariadicAdaptor { + pub fn create(nested: AggregateFunctionRef) -> AggregateFunctionRef { + Arc::new(Self(CommonNullAdaptor:: { nested })) + } +} + +impl AggregateFunction + for AggregateNullVariadicAdaptor +{ + fn name(&self) -> &str { + "AggregateNullVariadicAdaptor" + } + + fn return_type(&self) -> Result { + self.0.return_type() + } + + fn init_state(&self, place: AggrState) { + self.0.init_state(place); + } + + fn serialize_size_per_row(&self) -> Option { + self.0.serialize_size_per_row() + } + + fn register_state(&self, registry: &mut AggrStateRegistry) { + self.0.register_state(registry); + } + + #[inline] + fn accumulate( + &self, + place: AggrState, + columns: InputColumns, + validity: Option<&Bitmap>, + input_rows: usize, + ) -> Result<()> { + let mut not_null_columns = Vec::with_capacity(columns.len()); + let mut validity = validity.cloned(); + for col in columns.iter() { + validity = column_merge_validity(col, validity); + not_null_columns.push(col.remove_nullable()); + } + let not_null_columns = (¬_null_columns).into(); + + self.0 + .accumulate(place, not_null_columns, validity, input_rows) + } + + fn accumulate_keys( + &self, + addrs: &[StateAddr], + loc: &[AggrStateLoc], + columns: InputColumns, + input_rows: usize, + ) -> Result<()> { + let mut not_null_columns = Vec::with_capacity(columns.len()); + let mut validity = None; + for col in columns.iter() { + validity = column_merge_validity(col, validity); + not_null_columns.push(col.remove_nullable()); + } + let not_null_columns = (¬_null_columns).into(); + + self.0 + .accumulate_keys(addrs, loc, not_null_columns, validity, input_rows) + } + + fn accumulate_row(&self, place: AggrState, columns: InputColumns, row: usize) -> Result<()> { + let mut not_null_columns = Vec::with_capacity(columns.len()); + let mut validity = None; + for col in columns.iter() { + validity = column_merge_validity(col, validity); + not_null_columns.push(col.remove_nullable()); + } + let not_null_columns = (¬_null_columns).into(); + + self.0 + .accumulate_row(place, not_null_columns, validity, row) + } + + fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { + self.0.serialize(place, writer) + } + + fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { + self.0.merge(place, reader) + } + + fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { + self.0.merge_states(place, rhs) + } + + fn merge_result(&self, place: AggrState, builder: &mut ColumnBuilder) -> Result<()> { + self.0.merge_result(place, builder) + } + + fn need_manual_drop_state(&self) -> bool { + self.0.nested.need_manual_drop_state() + } + + unsafe fn drop_state(&self, place: AggrState) { + self.0.drop_state(place); + } + + fn convert_const_to_full(&self) -> bool { + self.0.nested.convert_const_to_full() + } + + fn get_if_condition(&self, columns: InputColumns) -> Option { + self.0.nested.get_if_condition(columns) + } +} + +impl fmt::Display for AggregateNullVariadicAdaptor { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "AggregateNullVariadicAdaptor") + } +} + +#[derive(Clone)] +struct CommonNullAdaptor { + nested: AggregateFunctionRef, +} + +impl CommonNullAdaptor { + fn return_type(&self) -> Result { + if !NULLABLE_RESULT { + return self.nested.return_type(); + } + + let nested = self.nested.return_type()?; + Ok(nested.wrap_nullable()) + } + + fn init_state(&self, place: AggrState) { + if !NULLABLE_RESULT { + return self.nested.init_state(place); + } + + set_flag(place, false); + self.nested.init_state(place.remove_last_loc()); + } + + fn serialize_size_per_row(&self) -> Option { + self.nested + .serialize_size_per_row() + .map(|row| if NULLABLE_RESULT { row + 1 } else { row }) + } + + fn register_state(&self, registry: &mut AggrStateRegistry) { + self.nested.register_state(registry); + if NULLABLE_RESULT { + registry.register(AggrStateType::Bool); + } + } + + #[inline] + fn accumulate( + &self, + place: AggrState, + not_null_column: InputColumns, + validity: Option, + input_rows: usize, + ) -> Result<()> { + if !NULLABLE_RESULT { + return self + .nested + .accumulate(place, not_null_column, validity.as_ref(), input_rows); + } + + if validity + .as_ref() + .map(|c| c.null_count() != input_rows) + .unwrap_or(true) + { + set_flag(place, true); + } + self.nested.accumulate( + place.remove_last_loc(), + not_null_column, + validity.as_ref(), + input_rows, + ) + } + + fn accumulate_keys( + &self, + addrs: &[StateAddr], + loc: &[AggrStateLoc], + not_null_columns: InputColumns, + validity: Option, + input_rows: usize, + ) -> Result<()> { + match validity { + Some(v) if v.null_count() > 0 => { + // all nulls + if v.null_count() == v.len() { + return Ok(()); + } + + for (valid, (row, place)) in v.iter().zip(addrs.iter().enumerate()) { + if !valid { + continue; + } + let place = AggrState::new(*place, loc); + if NULLABLE_RESULT { + set_flag(place, true); + self.nested.accumulate_row( + place.remove_last_loc(), + not_null_columns, + row, + )?; + } else { + self.nested.accumulate_row(place, not_null_columns, row)?; + }; + } + Ok(()) + } + _ => { + if !NULLABLE_RESULT { + self.nested + .accumulate_keys(addrs, loc, not_null_columns, input_rows) + } else { + addrs + .iter() + .for_each(|addr| set_flag(AggrState::new(*addr, loc), true)); + self.nested.accumulate_keys( + addrs, + &loc[..loc.len() - 1], + not_null_columns, + input_rows, + ) + } + } + } + } + + fn accumulate_row( + &self, + place: AggrState, + not_null_columns: InputColumns, + validity: Option, + row: usize, + ) -> Result<()> { + let v = if let Some(v) = validity { + if v.null_count() == 0 { + true + } else if v.null_count() == v.len() { + false + } else { + unsafe { v.get_bit_unchecked(row) } + } + } else { + true + }; + if !v { + return Ok(()); + } + + if !NULLABLE_RESULT { + return self.nested.accumulate_row(place, not_null_columns, row); + } + + set_flag(place, true); + self.nested + .accumulate_row(place.remove_last_loc(), not_null_columns, row) + } + + fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { + if !NULLABLE_RESULT { + return self.nested.serialize(place, writer); + } + + self.nested.serialize(place.remove_last_loc(), writer)?; + let flag = get_flag(place); + writer.write_scalar(&flag) + } + + fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { + if !NULLABLE_RESULT { + return self.nested.merge(place, reader); + } + + let flag = reader[reader.len() - 1]; + if flag == 0 { + return Ok(()); + } + + if !get_flag(place) { + // initial the state to remove the dirty stats + self.init_state(place); + } + set_flag(place, true); + self.nested + .merge(place.remove_last_loc(), &mut &reader[..reader.len() - 1]) + } + + fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { + if !NULLABLE_RESULT { + return self.nested.merge_states(place, rhs); + } + + if !get_flag(rhs) { + return Ok(()); + } + + if !get_flag(place) { + // initial the state to remove the dirty stats + self.init_state(place); + } + set_flag(place, true); + self.nested + .merge_states(place.remove_last_loc(), rhs.remove_last_loc()) + } + + fn merge_result(&self, place: AggrState, builder: &mut ColumnBuilder) -> Result<()> { + if !NULLABLE_RESULT { + return self.nested.merge_result(place, builder); + } + + let ColumnBuilder::Nullable(ref mut inner) = builder else { + unreachable!() + }; + + if get_flag(place) { + inner.validity.push(true); + self.nested + .merge_result(place.remove_last_loc(), &mut inner.builder) + } else { + inner.push_null(); + Ok(()) + } + } + + unsafe fn drop_state(&self, place: AggrState) { + if !NULLABLE_RESULT { + self.nested.drop_state(place) + } else { + self.nested.drop_state(place.remove_last_loc()) + } + } +} + +fn set_flag(place: AggrState, flag: bool) { + let c = place.addr.next(flag_offset(place)).get::(); + *c = flag as u8; +} + +fn get_flag(place: AggrState) -> bool { + let c = place.addr.next(flag_offset(place)).get::(); + *c != 0 +} + +fn flag_offset(place: AggrState) -> usize { + *place.loc.last().unwrap().as_bool().unwrap().1 +} diff --git a/src/query/functions/src/aggregates/adaptors/aggregate_null_unary_adaptor.rs b/src/query/functions/src/aggregates/adaptors/aggregate_null_unary_adaptor.rs deleted file mode 100644 index 90977e2c1f466..0000000000000 --- a/src/query/functions/src/aggregates/adaptors/aggregate_null_unary_adaptor.rs +++ /dev/null @@ -1,284 +0,0 @@ -// Copyright 2021 Datafuse Labs -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::alloc::Layout; -use std::fmt; -use std::sync::Arc; - -use databend_common_exception::Result; -use databend_common_expression::types::Bitmap; -use databend_common_expression::types::DataType; -use databend_common_expression::utils::column_merge_validity; -use databend_common_expression::ColumnBuilder; -use databend_common_expression::InputColumns; -use databend_common_io::prelude::BinaryWrite; - -use crate::aggregates::AggregateFunction; -use crate::aggregates::AggregateFunctionRef; -use crate::aggregates::StateAddr; - -#[derive(Clone)] -pub struct AggregateNullUnaryAdaptor { - nested: AggregateFunctionRef, - size_of_data: usize, -} - -impl AggregateNullUnaryAdaptor { - pub fn create(nested: AggregateFunctionRef) -> AggregateFunctionRef { - let size_of_data = if NULLABLE_RESULT { - let layout = nested.state_layout(); - layout.size() - } else { - 0 - }; - Arc::new(Self { - nested, - size_of_data, - }) - } - - #[inline] - pub fn set_flag(&self, place: StateAddr, flag: u8) { - if NULLABLE_RESULT { - let c = place.next(self.size_of_data).get::(); - *c = flag; - } - } - - #[inline] - pub fn init_flag(&self, place: StateAddr) { - if NULLABLE_RESULT { - let c = place.next(self.size_of_data).get::(); - *c = 0; - } - } - - #[inline] - pub fn get_flag(&self, place: StateAddr) -> u8 { - if NULLABLE_RESULT { - let c = place.next(self.size_of_data).get::(); - *c - } else { - 1 - } - } -} - -impl AggregateFunction for AggregateNullUnaryAdaptor { - fn name(&self) -> &str { - "AggregateNullUnaryAdaptor" - } - - fn return_type(&self) -> Result { - let nested = self.nested.return_type()?; - match NULLABLE_RESULT { - true => Ok(nested.wrap_nullable()), - false => Ok(nested), - } - } - - #[inline] - fn init_state(&self, place: StateAddr) { - self.init_flag(place); - self.nested.init_state(place); - } - - fn serialize_size_per_row(&self) -> Option { - self.nested.serialize_size_per_row().map(|row| row + 1) - } - - #[inline] - fn state_layout(&self) -> Layout { - let layout = self.nested.state_layout(); - let add = if NULLABLE_RESULT { layout.align() } else { 0 }; - Layout::from_size_align(layout.size() + add, layout.align()).unwrap() - } - - #[inline] - fn accumulate( - &self, - place: StateAddr, - columns: InputColumns, - validity: Option<&Bitmap>, - input_rows: usize, - ) -> Result<()> { - let col = &columns[0]; - let validity = column_merge_validity(col, validity.cloned()); - let not_null_column = &[col.remove_nullable()]; - let not_null_column = not_null_column.into(); - let validity = Bitmap::map_all_sets_to_none(validity); - - self.nested - .accumulate(place, not_null_column, validity.as_ref(), input_rows)?; - - if validity - .as_ref() - .map(|c| c.null_count() != input_rows) - .unwrap_or(true) - { - self.set_flag(place, 1); - } - Ok(()) - } - - #[inline] - fn accumulate_keys( - &self, - places: &[StateAddr], - offset: usize, - columns: InputColumns, - input_rows: usize, - ) -> Result<()> { - let col = &columns[0]; - let validity = column_merge_validity(col, None); - let not_null_columns = &[col.remove_nullable()]; - let not_null_columns = not_null_columns.into(); - - match validity { - Some(v) if v.null_count() > 0 => { - // all nulls - if v.null_count() == v.len() { - return Ok(()); - } - - for (valid, (row, place)) in v.iter().zip(places.iter().enumerate()) { - if valid { - self.set_flag(place.next(offset), 1); - self.nested - .accumulate_row(place.next(offset), not_null_columns, row)?; - } - } - } - _ => { - self.nested - .accumulate_keys(places, offset, not_null_columns, input_rows)?; - places - .iter() - .for_each(|place| self.set_flag(place.next(offset), 1)); - } - } - - Ok(()) - } - - fn accumulate_row(&self, place: StateAddr, columns: InputColumns, row: usize) -> Result<()> { - let col = &columns[0]; - let validity = column_merge_validity(col, None); - let not_null_columns = &[col.remove_nullable()]; - let not_null_columns = not_null_columns.into(); - - match validity { - Some(v) if v.null_count() > 0 => { - // all nulls - if v.null_count() == v.len() { - return Ok(()); - } - - if unsafe { v.get_bit_unchecked(row) } { - self.set_flag(place, 1); - self.nested.accumulate_row(place, not_null_columns, row)?; - } - } - _ => { - self.nested.accumulate_row(place, not_null_columns, row)?; - self.set_flag(place, 1); - } - } - - Ok(()) - } - - fn serialize(&self, place: StateAddr, writer: &mut Vec) -> Result<()> { - self.nested.serialize(place, writer)?; - if NULLABLE_RESULT { - let flag = self.get_flag(place); - writer.write_scalar(&flag)?; - } - Ok(()) - } - - fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { - if self.get_flag(place) == 0 { - // initial the state to remove the dirty stats - self.init_state(place); - } - - if NULLABLE_RESULT { - let flag = reader[reader.len() - 1]; - if flag == 1 { - self.set_flag(place, 1); - self.nested.merge(place, &mut &reader[..reader.len() - 1])?; - } - } else { - self.nested.merge(place, reader)?; - } - - Ok(()) - } - - fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { - if self.get_flag(place) == 0 { - // initial the state to remove the dirty stats - self.init_state(place); - } - - if self.get_flag(rhs) == 1 { - self.set_flag(place, 1); - self.nested.merge_states(place, rhs)?; - } - - Ok(()) - } - - fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> { - if NULLABLE_RESULT { - if self.get_flag(place) == 1 { - match builder { - ColumnBuilder::Nullable(ref mut inner) => { - self.nested.merge_result(place, &mut inner.builder)?; - inner.validity.push(true); - } - _ => unreachable!(), - } - } else { - builder.push_default(); - } - Ok(()) - } else { - self.nested.merge_result(place, builder) - } - } - - fn need_manual_drop_state(&self) -> bool { - self.nested.need_manual_drop_state() - } - - unsafe fn drop_state(&self, place: StateAddr) { - self.nested.drop_state(place) - } - - fn convert_const_to_full(&self) -> bool { - self.nested.convert_const_to_full() - } - - fn get_if_condition(&self, columns: InputColumns) -> Option { - self.nested.get_if_condition(columns) - } -} - -impl fmt::Display for AggregateNullUnaryAdaptor { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "AggregateNullUnaryAdaptor") - } -} diff --git a/src/query/functions/src/aggregates/adaptors/aggregate_null_variadic_adaptor.rs b/src/query/functions/src/aggregates/adaptors/aggregate_null_variadic_adaptor.rs deleted file mode 100644 index 57209ec9c4b28..0000000000000 --- a/src/query/functions/src/aggregates/adaptors/aggregate_null_variadic_adaptor.rs +++ /dev/null @@ -1,288 +0,0 @@ -// Copyright 2021 Datafuse Labs -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::alloc::Layout; -use std::fmt; -use std::sync::Arc; - -use databend_common_exception::Result; -use databend_common_expression::types::Bitmap; -use databend_common_expression::types::DataType; -use databend_common_expression::utils::column_merge_validity; -use databend_common_expression::ColumnBuilder; -use databend_common_expression::InputColumns; -use databend_common_io::prelude::BinaryWrite; - -use crate::aggregates::AggregateFunction; -use crate::aggregates::AggregateFunctionRef; -use crate::aggregates::StateAddr; - -#[derive(Clone)] -pub struct AggregateNullVariadicAdaptor { - nested: AggregateFunctionRef, - size_of_data: usize, -} - -impl AggregateNullVariadicAdaptor { - pub fn create(nested: AggregateFunctionRef) -> AggregateFunctionRef { - let size_of_data = if NULLABLE_RESULT { - let layout = nested.state_layout(); - layout.size() - } else { - 0 - }; - Arc::new(Self { - nested, - size_of_data, - }) - } - - #[inline] - pub fn set_flag(&self, place: StateAddr, flag: u8) { - if NULLABLE_RESULT { - let c = place.next(self.size_of_data).get::(); - *c = flag; - } - } - - #[inline] - pub fn init_flag(&self, place: StateAddr) { - if NULLABLE_RESULT { - let c = place.next(self.size_of_data).get::(); - *c = 0; - } - } - - #[inline] - pub fn get_flag(&self, place: StateAddr) -> u8 { - if NULLABLE_RESULT { - let c = place.next(self.size_of_data).get::(); - *c - } else { - 1 - } - } -} - -impl AggregateFunction - for AggregateNullVariadicAdaptor -{ - fn name(&self) -> &str { - "AggregateNullVariadicAdaptor" - } - - fn return_type(&self) -> Result { - let nested = self.nested.return_type()?; - match NULLABLE_RESULT { - true => Ok(nested.wrap_nullable()), - false => Ok(nested), - } - } - - fn init_state(&self, place: StateAddr) { - self.init_flag(place); - self.nested.init_state(place); - } - - fn serialize_size_per_row(&self) -> Option { - self.nested.serialize_size_per_row().map(|row| row + 1) - } - - #[inline] - fn state_layout(&self) -> Layout { - let layout = self.nested.state_layout(); - let add = if NULLABLE_RESULT { layout.align() } else { 0 }; - Layout::from_size_align(layout.size() + add, layout.align()).unwrap() - } - - #[inline] - fn accumulate( - &self, - place: StateAddr, - columns: InputColumns, - validity: Option<&Bitmap>, - input_rows: usize, - ) -> Result<()> { - let mut not_null_columns = Vec::with_capacity(columns.len()); - let mut validity = validity.cloned(); - for col in columns.iter() { - validity = column_merge_validity(col, validity); - not_null_columns.push(col.remove_nullable()); - } - let not_null_columns = (¬_null_columns).into(); - - self.nested - .accumulate(place, not_null_columns, validity.as_ref(), input_rows)?; - - if validity - .as_ref() - .map(|c| c.null_count() != input_rows) - .unwrap_or(true) - { - self.set_flag(place, 1); - } - Ok(()) - } - - fn accumulate_keys( - &self, - places: &[StateAddr], - offset: usize, - columns: InputColumns, - input_rows: usize, - ) -> Result<()> { - let mut not_null_columns = Vec::with_capacity(columns.len()); - let mut validity = None; - for col in columns.iter() { - validity = column_merge_validity(col, validity); - not_null_columns.push(col.remove_nullable()); - } - let not_null_columns = (¬_null_columns).into(); - - match validity { - Some(v) if v.null_count() > 0 => { - // all nulls - if v.null_count() == v.len() { - return Ok(()); - } - for (valid, (row, place)) in v.iter().zip(places.iter().enumerate()) { - if valid { - self.set_flag(place.next(offset), 1); - self.nested - .accumulate_row(place.next(offset), not_null_columns, row)?; - } - } - } - _ => { - self.nested - .accumulate_keys(places, offset, not_null_columns, input_rows)?; - places - .iter() - .for_each(|place| self.set_flag(place.next(offset), 1)); - } - } - Ok(()) - } - - fn accumulate_row(&self, place: StateAddr, columns: InputColumns, row: usize) -> Result<()> { - let mut not_null_columns = Vec::with_capacity(columns.len()); - let mut validity = None; - for col in columns.iter() { - validity = column_merge_validity(col, validity); - not_null_columns.push(col.remove_nullable()); - } - let not_null_columns = (¬_null_columns).into(); - - match validity { - Some(v) if v.null_count() > 0 => { - // all nulls - if v.null_count() == v.len() { - return Ok(()); - } - - if unsafe { v.get_bit_unchecked(row) } { - self.set_flag(place, 1); - self.nested.accumulate_row(place, not_null_columns, row)?; - } - } - _ => { - self.nested.accumulate_row(place, not_null_columns, row)?; - self.set_flag(place, 1); - } - } - - Ok(()) - } - - fn serialize(&self, place: StateAddr, writer: &mut Vec) -> Result<()> { - self.nested.serialize(place, writer)?; - if NULLABLE_RESULT { - let flag: u8 = self.get_flag(place); - writer.write_scalar(&flag)?; - } - Ok(()) - } - - fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { - if self.get_flag(place) == 0 { - // initial the state to remove the dirty stats - self.init_state(place); - } - - if NULLABLE_RESULT { - let flag = reader[reader.len() - 1]; - if flag == 1 { - self.set_flag(place, flag); - self.nested.merge(place, &mut &reader[..reader.len() - 1])?; - } - } else { - self.nested.merge(place, reader)?; - } - Ok(()) - } - - fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { - if self.get_flag(place) == 0 { - // initial the state to remove the dirty stats - self.init_state(place); - } - - if self.get_flag(rhs) == 1 { - self.set_flag(place, 1); - self.nested.merge_states(place, rhs)?; - } - Ok(()) - } - - fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> { - if NULLABLE_RESULT { - if self.get_flag(place) == 1 { - match builder { - ColumnBuilder::Nullable(ref mut inner) => { - self.nested.merge_result(place, &mut inner.builder)?; - inner.validity.push(true); - } - _ => unreachable!(), - } - } else { - builder.push_default(); - } - Ok(()) - } else { - self.nested.merge_result(place, builder) - } - } - - fn need_manual_drop_state(&self) -> bool { - self.nested.need_manual_drop_state() - } - - unsafe fn drop_state(&self, place: StateAddr) { - self.nested.drop_state(place) - } - - fn convert_const_to_full(&self) -> bool { - self.nested.convert_const_to_full() - } - - fn get_if_condition(&self, columns: InputColumns) -> Option { - self.nested.get_if_condition(columns) - } -} - -impl fmt::Display for AggregateNullVariadicAdaptor { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "AggregateNullVariadicAdaptor") - } -} diff --git a/src/query/functions/src/aggregates/adaptors/aggregate_ornull_adaptor.rs b/src/query/functions/src/aggregates/adaptors/aggregate_ornull_adaptor.rs index f6c1e915e793a..7890e45225b47 100644 --- a/src/query/functions/src/aggregates/adaptors/aggregate_ornull_adaptor.rs +++ b/src/query/functions/src/aggregates/adaptors/aggregate_ornull_adaptor.rs @@ -12,22 +12,25 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::alloc::Layout; use std::fmt; use std::sync::Arc; use databend_common_exception::Result; use databend_common_expression::types::Bitmap; use databend_common_expression::types::DataType; +use databend_common_expression::AggrStateRegistry; +use databend_common_expression::AggrStateType; use databend_common_expression::ColumnBuilder; use databend_common_expression::InputColumns; use databend_common_expression::Scalar; use databend_common_io::prelude::BinaryWrite; -use crate::aggregates::aggregate_function_factory::AggregateFunctionFeatures; -use crate::aggregates::AggregateFunction; -use crate::aggregates::AggregateFunctionRef; -use crate::aggregates::StateAddr; +use super::AggrState; +use super::AggrStateLoc; +use super::AggregateFunction; +use super::AggregateFunctionFeatures; +use super::AggregateFunctionRef; +use super::StateAddr; /// OrNullAdaptor will use OrNull for aggregate functions. /// If there are no input values, return NULL or a default value, accordingly. @@ -35,7 +38,6 @@ use crate::aggregates::StateAddr; /// 0 means there was no input, 1 means there was some. pub struct AggregateFunctionOrNullAdaptor { inner: AggregateFunctionRef, - size_of_data: usize, inner_nullable: bool, } @@ -50,25 +52,25 @@ impl AggregateFunctionOrNullAdaptor { return Ok(inner); } - let inner_layout = inner.state_layout(); Ok(Arc::new(AggregateFunctionOrNullAdaptor { inner, - size_of_data: inner_layout.size(), - inner_nullable: matches!(inner_return_type, DataType::Nullable(_)), + inner_nullable: inner_return_type.is_nullable(), })) } +} - #[inline] - pub fn set_flag(&self, place: StateAddr, flag: u8) { - let c = place.next(self.size_of_data).get::(); - *c = flag; - } +pub fn set_flag(place: AggrState, flag: bool) { + let c = place.addr.next(flag_offset(place)).get::(); + *c = flag as u8; +} - #[inline] - pub fn get_flag(&self, place: StateAddr) -> u8 { - let c = place.next(self.size_of_data).get::(); - *c - } +pub fn get_flag(place: AggrState) -> bool { + let c = place.addr.next(flag_offset(place)).get::(); + *c != 0 +} + +fn flag_offset(place: AggrState) -> usize { + *place.loc.last().unwrap().as_bool().unwrap().1 } impl AggregateFunction for AggregateFunctionOrNullAdaptor { @@ -81,26 +83,25 @@ impl AggregateFunction for AggregateFunctionOrNullAdaptor { } #[inline] - fn init_state(&self, place: StateAddr) { - let c = place.next(self.size_of_data).get::(); + fn init_state(&self, place: AggrState) { + let c = place.addr.next(flag_offset(place)).get::(); *c = 0; - self.inner.init_state(place) + self.inner.init_state(place.remove_last_loc()) } fn serialize_size_per_row(&self) -> Option { self.inner.serialize_size_per_row().map(|row| row + 1) } - #[inline] - fn state_layout(&self) -> std::alloc::Layout { - let layout = self.inner.state_layout(); - Layout::from_size_align(layout.size() + layout.align(), layout.align()).unwrap() + fn register_state(&self, registry: &mut AggrStateRegistry) { + self.inner.register_state(registry); + registry.register(AggrStateType::Bool); } #[inline] fn accumulate( &self, - place: StateAddr, + place: AggrState, columns: InputColumns, validity: Option<&Bitmap>, input_rows: usize, @@ -123,9 +124,13 @@ impl AggregateFunction for AggregateFunctionOrNullAdaptor { .map(|c| c.null_count() != input_rows) .unwrap_or(true) { - self.set_flag(place, 1); - self.inner - .accumulate(place, columns, validity.as_ref(), input_rows)?; + set_flag(place, true); + self.inner.accumulate( + place.remove_last_loc(), + columns, + validity.as_ref(), + input_rows, + )?; } Ok(()) } @@ -133,12 +138,12 @@ impl AggregateFunction for AggregateFunctionOrNullAdaptor { fn accumulate_keys( &self, places: &[StateAddr], - offset: usize, + loc: &[AggrStateLoc], columns: InputColumns, input_rows: usize, ) -> Result<()> { self.inner - .accumulate_keys(places, offset, columns, input_rows)?; + .accumulate_keys(places, &loc[..loc.len() - 1], columns, input_rows)?; let if_cond = self.inner.get_if_condition(columns); match if_cond { @@ -148,15 +153,15 @@ impl AggregateFunction for AggregateFunctionOrNullAdaptor { return Ok(()); } - for (place, valid) in places.iter().zip(v.iter()) { + for (&addr, valid) in places.iter().zip(v.iter()) { if valid { - self.set_flag(place.next(offset), 1); + set_flag(AggrState::new(addr, loc), true); } } } _ => { - for place in places { - self.set_flag(place.next(offset), 1); + for &addr in places { + set_flag(AggrState::new(addr, loc), true); } } } @@ -165,43 +170,47 @@ impl AggregateFunction for AggregateFunctionOrNullAdaptor { } #[inline] - fn accumulate_row(&self, place: StateAddr, columns: InputColumns, row: usize) -> Result<()> { - self.inner.accumulate_row(place, columns, row)?; - self.set_flag(place, 1); + fn accumulate_row(&self, place: AggrState, columns: InputColumns, row: usize) -> Result<()> { + self.inner + .accumulate_row(place.remove_last_loc(), columns, row)?; + set_flag(place, true); Ok(()) } #[inline] - fn serialize(&self, place: StateAddr, writer: &mut Vec) -> Result<()> { - self.inner.serialize(place, writer)?; - writer.write_scalar(&self.get_flag(place)) + fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { + self.inner.serialize(place.remove_last_loc(), writer)?; + let flag = get_flag(place) as u8; + writer.write_scalar(&flag) } #[inline] - fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { - let flag = self.get_flag(place) > 0 || reader[reader.len() - 1] > 0; - - self.inner.merge(place, &mut &reader[..reader.len() - 1])?; - self.set_flag(place, flag as u8); + fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { + let flag = get_flag(place) || reader[reader.len() - 1] > 0; + self.inner + .merge(place.remove_last_loc(), &mut &reader[..reader.len() - 1])?; + set_flag(place, flag); Ok(()) } - fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { - self.inner.merge_states(place, rhs)?; - let flag = self.get_flag(place) > 0 || self.get_flag(rhs) > 0; - self.set_flag(place, u8::from(flag)); + fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { + self.inner + .merge_states(place.remove_last_loc(), rhs.remove_last_loc())?; + let flag = get_flag(place) || get_flag(rhs); + set_flag(place, flag); Ok(()) } - fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> { + fn merge_result(&self, place: AggrState, builder: &mut ColumnBuilder) -> Result<()> { match builder { ColumnBuilder::Nullable(inner_mut) => { - if self.get_flag(place) == 0 { + if !get_flag(place) { inner_mut.push_null(); } else if self.inner_nullable { - self.inner.merge_result(place, builder)?; + self.inner.merge_result(place.remove_last_loc(), builder)?; } else { - self.inner.merge_result(place, &mut inner_mut.builder)?; + self.inner + .merge_result(place.remove_last_loc(), &mut inner_mut.builder)?; inner_mut.validity.push(true); } } @@ -224,14 +233,15 @@ impl AggregateFunction for AggregateFunctionOrNullAdaptor { self.inner.need_manual_drop_state() } - unsafe fn drop_state(&self, place: StateAddr) { - self.inner.drop_state(place) + unsafe fn drop_state(&self, place: AggrState) { + self.inner.drop_state(place.remove_last_loc()) } fn convert_const_to_full(&self) -> bool { self.inner.convert_const_to_full() } } + impl fmt::Display for AggregateFunctionOrNullAdaptor { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{}", self.inner) diff --git a/src/query/functions/src/aggregates/adaptors/mod.rs b/src/query/functions/src/aggregates/adaptors/mod.rs index 556381682ba7f..0afc7b032bdd7 100644 --- a/src/query/functions/src/aggregates/adaptors/mod.rs +++ b/src/query/functions/src/aggregates/adaptors/mod.rs @@ -12,12 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. +use super::aggregate_function_factory::*; +use super::*; + mod aggregate_null_adaptor; -mod aggregate_null_unary_adaptor; -mod aggregate_null_variadic_adaptor; mod aggregate_ornull_adaptor; pub use aggregate_null_adaptor::*; -pub use aggregate_null_unary_adaptor::*; -pub use aggregate_null_variadic_adaptor::*; pub use aggregate_ornull_adaptor::*; diff --git a/src/query/functions/src/aggregates/aggregate_arg_min_max.rs b/src/query/functions/src/aggregates/aggregate_arg_min_max.rs index 289512ccdceb1..f0592d61e0815 100644 --- a/src/query/functions/src/aggregates/aggregate_arg_min_max.rs +++ b/src/query/functions/src/aggregates/aggregate_arg_min_max.rs @@ -25,6 +25,8 @@ use databend_common_expression::types::number::*; use databend_common_expression::types::Bitmap; use databend_common_expression::types::*; use databend_common_expression::with_number_mapped_type; +use databend_common_expression::AggrStateRegistry; +use databend_common_expression::AggrStateType; use databend_common_expression::ColumnBuilder; use databend_common_expression::InputColumns; use databend_common_expression::Scalar; @@ -42,6 +44,8 @@ use super::borsh_serialize_state; use super::AggregateFunctionRef; use super::StateAddr; use crate::aggregates::assert_binary_arguments; +use crate::aggregates::AggrState; +use crate::aggregates::AggrStateLoc; use crate::aggregates::AggregateFunction; use crate::with_compare_mapped_type; use crate::with_simple_no_number_mapped_type; @@ -216,17 +220,17 @@ where Ok(self.return_data_type.clone()) } - fn init_state(&self, place: StateAddr) { - place.write(|| State::new()); + fn init_state(&self, place: AggrState) { + place.write(State::new); } - fn state_layout(&self) -> Layout { - Layout::new::() + fn register_state(&self, registry: &mut AggrStateRegistry) { + registry.register(AggrStateType::Custom(Layout::new::())); } fn accumulate( &self, - place: StateAddr, + place: AggrState, columns: InputColumns, validity: Option<&Bitmap>, _input_rows: usize, @@ -240,7 +244,7 @@ where fn accumulate_keys( &self, places: &[StateAddr], - offset: usize, + loc: &[AggrStateLoc], columns: InputColumns, _input_rows: usize, ) -> Result<()> { @@ -250,10 +254,9 @@ where val_col_iter .enumerate() - .zip(places.iter()) - .for_each(|((row, val), place)| { - let addr = place.next(offset); - let state = addr.get::(); + .zip(places.iter().cloned()) + .for_each(|((row, val), addr)| { + let state = AggrState::new(addr, loc).get::(); if state.change(&val) { state.update(val, A::index_column(&arg_col, row).unwrap()) } @@ -261,7 +264,7 @@ where Ok(()) } - fn accumulate_row(&self, place: StateAddr, columns: InputColumns, row: usize) -> Result<()> { + fn accumulate_row(&self, place: AggrState, columns: InputColumns, row: usize) -> Result<()> { let arg_col = A::try_downcast_column(&columns[0]).unwrap(); let val_col = V::try_downcast_column(&columns[1]).unwrap(); let state = place.get::(); @@ -273,24 +276,24 @@ where Ok(()) } - fn serialize(&self, place: StateAddr, writer: &mut Vec) -> Result<()> { + fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { let state = place.get::(); borsh_serialize_state(writer, state) } - fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { + fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { let state = place.get::(); let rhs: State = borsh_deserialize_state(reader)?; state.merge_from(rhs) } - fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { + fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { let state = place.get::(); let other = rhs.get::(); state.merge(other) } - fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> { + fn merge_result(&self, place: AggrState, builder: &mut ColumnBuilder) -> Result<()> { let state = place.get::(); state.merge_result(builder) } @@ -299,7 +302,7 @@ where true } - unsafe fn drop_state(&self, place: StateAddr) { + unsafe fn drop_state(&self, place: AggrState) { let state = place.get::(); std::ptr::drop_in_place(state); } diff --git a/src/query/functions/src/aggregates/aggregate_array_agg.rs b/src/query/functions/src/aggregates/aggregate_array_agg.rs index 7400d75d8db01..fc7e7568f862b 100644 --- a/src/query/functions/src/aggregates/aggregate_array_agg.rs +++ b/src/query/functions/src/aggregates/aggregate_array_agg.rs @@ -28,6 +28,8 @@ use databend_common_expression::types::DataType; use databend_common_expression::types::ValueType; use databend_common_expression::types::*; use databend_common_expression::with_number_mapped_type; +use databend_common_expression::AggrStateRegistry; +use databend_common_expression::AggrStateType; use databend_common_expression::Column; use databend_common_expression::ColumnBuilder; use databend_common_expression::InputColumns; @@ -41,6 +43,8 @@ use super::borsh_deserialize_state; use super::borsh_serialize_state; use super::StateAddr; use crate::aggregates::assert_unary_arguments; +use crate::aggregates::AggrState; +use crate::aggregates::AggrStateLoc; use crate::aggregates::AggregateFunction; use crate::with_simple_no_number_mapped_type; @@ -268,17 +272,17 @@ where Ok(self.return_type.clone()) } - fn init_state(&self, place: StateAddr) { - place.write(|| State::new()); + fn init_state(&self, place: AggrState) { + place.write(State::new); } - fn state_layout(&self) -> Layout { - Layout::new::() + fn register_state(&self, registry: &mut AggrStateRegistry) { + registry.register(AggrStateType::Custom(Layout::new::())); } fn accumulate( &self, - place: StateAddr, + place: AggrState, columns: InputColumns, _validity: Option<&Bitmap>, _input_rows: usize, @@ -299,7 +303,7 @@ where fn accumulate_keys( &self, places: &[StateAddr], - offset: usize, + loc: &[AggrStateLoc], columns: InputColumns, _input_rows: usize, ) -> Result<()> { @@ -310,8 +314,7 @@ where column_iter .zip(nullable_column.validity.iter().zip(places.iter())) .for_each(|(v, (valid, place))| { - let addr = place.next(offset); - let state = addr.get::(); + let state = AggrState::new(*place, loc).get::(); if valid { state.add(Some(v.clone())) } else { @@ -323,8 +326,7 @@ where let column = T::try_downcast_column(&columns[0]).unwrap(); let column_iter = T::iter_column(&column); column_iter.zip(places.iter()).for_each(|(v, place)| { - let addr = place.next(offset); - let state = addr.get::(); + let state = AggrState::new(*place, loc).get::(); state.add(Some(v.clone())) }); } @@ -333,7 +335,7 @@ where Ok(()) } - fn accumulate_row(&self, place: StateAddr, columns: InputColumns, row: usize) -> Result<()> { + fn accumulate_row(&self, place: AggrState, columns: InputColumns, row: usize) -> Result<()> { let state = place.get::(); match &columns[0] { Column::Nullable(box nullable_column) => { @@ -356,25 +358,25 @@ where Ok(()) } - fn serialize(&self, place: StateAddr, writer: &mut Vec) -> Result<()> { + fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { let state = place.get::(); borsh_serialize_state(writer, state) } - fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { + fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { let state = place.get::(); let rhs: State = borsh_deserialize_state(reader)?; state.merge(&rhs) } - fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { + fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { let state = place.get::(); let other = rhs.get::(); state.merge(other) } - fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> { + fn merge_result(&self, place: AggrState, builder: &mut ColumnBuilder) -> Result<()> { let state = place.get::(); state.merge_result(builder) } @@ -383,7 +385,7 @@ where true } - unsafe fn drop_state(&self, place: StateAddr) { + unsafe fn drop_state(&self, place: AggrState) { let state = place.get::(); std::ptr::drop_in_place(state); } diff --git a/src/query/functions/src/aggregates/aggregate_array_moving.rs b/src/query/functions/src/aggregates/aggregate_array_moving.rs index 39fc115fd60ee..c9db4d2762fe5 100644 --- a/src/query/functions/src/aggregates/aggregate_array_moving.rs +++ b/src/query/functions/src/aggregates/aggregate_array_moving.rs @@ -37,6 +37,8 @@ use databend_common_expression::types::ValueType; use databend_common_expression::types::F64; use databend_common_expression::utils::arithmetics_type::ResultTypeOfUnary; use databend_common_expression::with_number_mapped_type; +use databend_common_expression::AggrStateRegistry; +use databend_common_expression::AggrStateType; use databend_common_expression::Column; use databend_common_expression::ColumnBuilder; use databend_common_expression::Expr; @@ -56,6 +58,8 @@ use super::StateAddr; use crate::aggregates::aggregate_sum::SumState; use crate::aggregates::assert_unary_arguments; use crate::aggregates::assert_variadic_params; +use crate::aggregates::AggrState; +use crate::aggregates::AggrStateLoc; use crate::BUILTIN_FUNCTIONS; #[derive(Default, Debug, BorshDeserialize, BorshSerialize)] @@ -114,7 +118,7 @@ where Ok(()) } - fn accumulate_keys(places: &[StateAddr], offset: usize, columns: &Column) -> Result<()> { + fn accumulate_keys(places: &[StateAddr], loc: &[AggrStateLoc], columns: &Column) -> Result<()> { let buffer = match columns { Column::Null { len } => Buffer::from(vec![T::default(); *len]), Column::Nullable(box nullable_column) => { @@ -123,8 +127,7 @@ where _ => NumberType::::try_downcast_column(columns).unwrap(), }; buffer.iter().zip(places.iter()).for_each(|(c, place)| { - let place = place.next(offset); - let state = place.get::(); + let state = AggrState::new(*place, loc).get::(); state.values.push(*c); }); Ok(()) @@ -281,7 +284,7 @@ where T: Decimal Ok(()) } - fn accumulate_keys(places: &[StateAddr], offset: usize, columns: &Column) -> Result<()> { + fn accumulate_keys(places: &[StateAddr], loc: &[AggrStateLoc], columns: &Column) -> Result<()> { let buffer = match columns { Column::Null { len } => Buffer::from(vec![T::default(); *len]), Column::Nullable(box nullable_column) => { @@ -290,8 +293,7 @@ where T: Decimal _ => T::try_downcast_column(columns).unwrap().0, }; buffer.iter().zip(places.iter()).for_each(|(c, place)| { - let place = place.next(offset); - let state = place.get::(); + let state = AggrState::new(*place, loc).get::(); state.values.push(*c); }); Ok(()) @@ -403,17 +405,17 @@ where State: SumState Ok(self.return_type.clone()) } - fn init_state(&self, place: StateAddr) { - place.write(|| State::default()); + fn init_state(&self, place: AggrState) { + place.write(State::default); } - fn state_layout(&self) -> Layout { - Layout::new::() + fn register_state(&self, registry: &mut AggrStateRegistry) { + registry.register(AggrStateType::Custom(Layout::new::())); } fn accumulate( &self, - place: StateAddr, + place: AggrState, columns: InputColumns, validity: Option<&Bitmap>, _input_rows: usize, @@ -425,37 +427,37 @@ where State: SumState fn accumulate_keys( &self, places: &[StateAddr], - offset: usize, + loc: &[AggrStateLoc], columns: InputColumns, _input_rows: usize, ) -> Result<()> { - State::accumulate_keys(places, offset, &columns[0]) + State::accumulate_keys(places, loc, &columns[0]) } - fn accumulate_row(&self, place: StateAddr, columns: InputColumns, row: usize) -> Result<()> { + fn accumulate_row(&self, place: AggrState, columns: InputColumns, row: usize) -> Result<()> { let state = place.get::(); state.accumulate_row(&columns[0], row) } - fn serialize(&self, place: StateAddr, writer: &mut Vec) -> Result<()> { + fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { let state = place.get::(); borsh_serialize_state(writer, state) } - fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { + fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { let state = place.get::(); let rhs: State = borsh_deserialize_state(reader)?; state.merge(&rhs) } - fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { + fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { let state = place.get::(); let other = rhs.get::(); state.merge(other) } - fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> { + fn merge_result(&self, place: AggrState, builder: &mut ColumnBuilder) -> Result<()> { let state = place.get::(); state.merge_avg_result(builder, 0_u64, self.scale_add, &self.window_size) } @@ -464,7 +466,7 @@ where State: SumState true } - unsafe fn drop_state(&self, place: StateAddr) { + unsafe fn drop_state(&self, place: AggrState) { let state = place.get::(); std::ptr::drop_in_place(state); } @@ -597,17 +599,17 @@ where State: SumState Ok(self.return_type.clone()) } - fn init_state(&self, place: StateAddr) { - place.write(|| State::default()); + fn init_state(&self, place: AggrState) { + place.write(State::default); } - fn state_layout(&self) -> Layout { - Layout::new::() + fn register_state(&self, registry: &mut AggrStateRegistry) { + registry.register(AggrStateType::Custom(Layout::new::())); } fn accumulate( &self, - place: StateAddr, + place: AggrState, columns: InputColumns, validity: Option<&Bitmap>, _input_rows: usize, @@ -619,37 +621,37 @@ where State: SumState fn accumulate_keys( &self, places: &[StateAddr], - offset: usize, + loc: &[AggrStateLoc], columns: InputColumns, _input_rows: usize, ) -> Result<()> { - State::accumulate_keys(places, offset, &columns[0]) + State::accumulate_keys(places, loc, &columns[0]) } - fn accumulate_row(&self, place: StateAddr, columns: InputColumns, row: usize) -> Result<()> { + fn accumulate_row(&self, place: AggrState, columns: InputColumns, row: usize) -> Result<()> { let state = place.get::(); state.accumulate_row(&columns[0], row) } - fn serialize(&self, place: StateAddr, writer: &mut Vec) -> Result<()> { + fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { let state = place.get::(); borsh_serialize_state(writer, state) } - fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { + fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { let state = place.get::(); let rhs: State = borsh_deserialize_state(reader)?; state.merge(&rhs) } - fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { + fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { let state = place.get::(); let other = rhs.get::(); state.merge(other) } - fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> { + fn merge_result(&self, place: AggrState, builder: &mut ColumnBuilder) -> Result<()> { let state = place.get::(); state.merge_result(builder, &self.window_size) } @@ -658,7 +660,7 @@ where State: SumState true } - unsafe fn drop_state(&self, place: StateAddr) { + unsafe fn drop_state(&self, place: AggrState) { let state = place.get::(); std::ptr::drop_in_place(state); } diff --git a/src/query/functions/src/aggregates/aggregate_bitmap.rs b/src/query/functions/src/aggregates/aggregate_bitmap.rs index 1713d1d02dc29..3b19addcde200 100644 --- a/src/query/functions/src/aggregates/aggregate_bitmap.rs +++ b/src/query/functions/src/aggregates/aggregate_bitmap.rs @@ -30,6 +30,8 @@ use databend_common_expression::types::Bitmap; use databend_common_expression::types::MutableBitmap; use databend_common_expression::types::*; use databend_common_expression::with_number_mapped_type; +use databend_common_expression::AggrStateRegistry; +use databend_common_expression::AggrStateType; use databend_common_expression::ColumnBuilder; use databend_common_expression::Expr; use databend_common_expression::FunctionContext; @@ -45,6 +47,8 @@ use super::StateAddrs; use crate::aggregates::assert_arguments; use crate::aggregates::assert_unary_arguments; use crate::aggregates::assert_variadic_params; +use crate::aggregates::AggrState; +use crate::aggregates::AggrStateLoc; use crate::aggregates::AggregateFunction; use crate::with_simple_no_number_mapped_type; use crate::BUILTIN_FUNCTIONS; @@ -87,7 +91,7 @@ macro_rules! with_bitmap_agg_mapped_type { } trait BitmapAggResult: Send + Sync + 'static { - fn merge_result(place: StateAddr, builder: &mut ColumnBuilder) -> Result<()>; + fn merge_result(place: AggrState, builder: &mut ColumnBuilder) -> Result<()>; fn return_type() -> Result; } @@ -97,7 +101,7 @@ struct BitmapCountResult; struct BitmapRawResult; impl BitmapAggResult for BitmapCountResult { - fn merge_result(place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> { + fn merge_result(place: AggrState, builder: &mut ColumnBuilder) -> Result<()> { let builder = UInt64Type::try_downcast_builder(builder).unwrap(); let state = place.get::(); builder.push(state.rb.as_ref().map(|rb| rb.len()).unwrap_or(0)); @@ -110,7 +114,7 @@ impl BitmapAggResult for BitmapCountResult { } impl BitmapAggResult for BitmapRawResult { - fn merge_result(place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> { + fn merge_result(place: AggrState, builder: &mut ColumnBuilder) -> Result<()> { let builder = BitmapType::try_downcast_builder(builder).unwrap(); let state = place.get::(); if let Some(rb) = state.rb.as_ref() { @@ -214,17 +218,17 @@ where AGG::return_type() } - fn init_state(&self, place: super::StateAddr) { + fn init_state(&self, place: AggrState) { place.write(BitmapAggState::new); } - fn state_layout(&self) -> Layout { - Layout::new::() + fn register_state(&self, registry: &mut AggrStateRegistry) { + registry.register(AggrStateType::Custom(Layout::new::())); } fn accumulate( &self, - place: StateAddr, + place: AggrState, columns: InputColumns, validity: Option<&Bitmap>, _input_rows: usize, @@ -261,22 +265,21 @@ where fn accumulate_keys( &self, places: &[StateAddr], - offset: usize, + loc: &[AggrStateLoc], columns: InputColumns, _input_rows: usize, ) -> Result<()> { let column = BitmapType::try_downcast_column(&columns[0]).unwrap(); - for (data, place) in column.iter().zip(places.iter()) { - let addr = place.next(offset); - let state = addr.get::(); + for (data, addr) in column.iter().zip(places.iter().cloned()) { + let state = AggrState::new(addr, loc).get::(); let rb = deserialize_bitmap(data)?; state.add::(rb); } Ok(()) } - fn accumulate_row(&self, place: StateAddr, columns: InputColumns, row: usize) -> Result<()> { + fn accumulate_row(&self, place: AggrState, columns: InputColumns, row: usize) -> Result<()> { let column = BitmapType::try_downcast_column(&columns[0]).unwrap(); let state = place.get::(); if let Some(data) = BitmapType::index_column(&column, row) { @@ -286,7 +289,7 @@ where Ok(()) } - fn serialize(&self, place: StateAddr, writer: &mut Vec) -> Result<()> { + fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { let state = place.get::(); // flag indicate where bitmap is none let flag: u8 = if state.rb.is_some() { 1 } else { 0 }; @@ -297,7 +300,7 @@ where Ok(()) } - fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { + fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { let state = place.get::(); let flag = reader[0]; @@ -309,7 +312,7 @@ where Ok(()) } - fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { + fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { let state = place.get::(); let other = rhs.get::(); @@ -319,7 +322,7 @@ where Ok(()) } - fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> { + fn merge_result(&self, place: AggrState, builder: &mut ColumnBuilder) -> Result<()> { AGG::merge_result(place, builder) } @@ -327,7 +330,7 @@ where true } - unsafe fn drop_state(&self, place: StateAddr) { + unsafe fn drop_state(&self, place: AggrState) { let state = place.get::(); std::ptr::drop_in_place(state); } @@ -432,17 +435,17 @@ where self.inner.return_type() } - fn init_state(&self, place: StateAddr) { + fn init_state(&self, place: AggrState) { self.inner.init_state(place); } - fn state_layout(&self) -> Layout { - self.inner.state_layout() + fn register_state(&self, registry: &mut AggrStateRegistry) { + self.inner.register_state(registry); } fn accumulate( &self, - place: StateAddr, + place: AggrState, columns: InputColumns, validity: Option<&Bitmap>, input_rows: usize, @@ -459,7 +462,7 @@ where fn accumulate_keys( &self, places: &[StateAddr], - offset: usize, + loc: &[AggrStateLoc], columns: InputColumns, _input_rows: usize, ) -> Result<()> { @@ -472,29 +475,29 @@ where let input = [column]; self.inner - .accumulate_keys(new_places_slice, offset, input.as_slice().into(), row_size) + .accumulate_keys(new_places_slice, loc, input.as_slice().into(), row_size) } - fn accumulate_row(&self, place: StateAddr, columns: InputColumns, row: usize) -> Result<()> { + fn accumulate_row(&self, place: AggrState, columns: InputColumns, row: usize) -> Result<()> { if self.filter_row(columns, row)? { return self.inner.accumulate_row(place, columns, row); } Ok(()) } - fn serialize(&self, place: StateAddr, writer: &mut Vec) -> Result<()> { + fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { self.inner.serialize(place, writer) } - fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { + fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { self.inner.merge(place, reader) } - fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { + fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { self.inner.merge_states(place, rhs) } - fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> { + fn merge_result(&self, place: AggrState, builder: &mut ColumnBuilder) -> Result<()> { self.inner.merge_result(place, builder) } } diff --git a/src/query/functions/src/aggregates/aggregate_combinator_distinct.rs b/src/query/functions/src/aggregates/aggregate_combinator_distinct.rs index c3bdd668f4984..4a85820722a87 100644 --- a/src/query/functions/src/aggregates/aggregate_combinator_distinct.rs +++ b/src/query/functions/src/aggregates/aggregate_combinator_distinct.rs @@ -23,6 +23,8 @@ use databend_common_expression::types::Bitmap; use databend_common_expression::types::DataType; use databend_common_expression::types::NumberDataType; use databend_common_expression::with_number_mapped_type; +use databend_common_expression::AggrStateRegistry; +use databend_common_expression::AggrStateType; use databend_common_expression::ColumnBuilder; use databend_common_expression::InputColumns; use databend_common_expression::Scalar; @@ -38,7 +40,7 @@ use super::aggregate_function_factory::AggregateFunctionDescription; use super::aggregate_function_factory::CombinatorDescription; use super::aggregator_common::assert_variadic_arguments; use super::AggregateCountFunction; -use super::StateAddr; +use crate::aggregates::AggrState; #[derive(Clone)] pub struct AggregateDistinctCombinator { @@ -50,6 +52,22 @@ pub struct AggregateDistinctCombinator { _state: PhantomData, } +impl AggregateDistinctCombinator { + fn get_state(place: AggrState) -> &mut State { + place + .addr + .next(place.loc[0].into_custom().unwrap().1) + .get::() + } + + fn set_state(place: AggrState, state: State) { + place + .addr + .next(place.loc[0].into_custom().unwrap().1) + .write_state(state); + } +} + impl AggregateFunction for AggregateDistinctCombinator where State: DistinctStateFunc { @@ -61,60 +79,53 @@ where State: DistinctStateFunc self.nested.return_type() } - fn init_state(&self, place: StateAddr) { - place.write(|| State::new()); - let layout = Layout::new::(); - let nested_place = place.next(layout.size()); - self.nested.init_state(nested_place); + fn init_state(&self, place: AggrState) { + Self::set_state(place, State::new()); + self.nested.init_state(place.remove_first_loc()); } - fn state_layout(&self) -> Layout { - let layout = Layout::new::(); - - let nested = self.nested.state_layout(); - Layout::from_size_align(layout.size() + nested.size(), layout.align()).unwrap() + fn register_state(&self, registry: &mut AggrStateRegistry) { + registry.register(AggrStateType::Custom(Layout::new::())); + self.nested.register_state(registry); } fn accumulate( &self, - place: StateAddr, + place: AggrState, columns: InputColumns, validity: Option<&Bitmap>, input_rows: usize, ) -> Result<()> { - let state = place.get::(); + let state = Self::get_state(place); state.batch_add(columns, validity, input_rows) } - fn accumulate_row(&self, place: StateAddr, columns: InputColumns, row: usize) -> Result<()> { - let state = place.get::(); + fn accumulate_row(&self, place: AggrState, columns: InputColumns, row: usize) -> Result<()> { + let state = Self::get_state(place); state.add(columns, row) } - fn serialize(&self, place: StateAddr, writer: &mut Vec) -> Result<()> { - let state = place.get::(); + fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { + let state = Self::get_state(place); state.serialize(writer) } - fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { - let state = place.get::(); + fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { + let state = Self::get_state(place); let rhs = State::deserialize(reader)?; state.merge(&rhs) } - fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { - let state = place.get::(); - let other = rhs.get::(); + fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { + let state = Self::get_state(place); + let other = Self::get_state(rhs); state.merge(other) } - #[allow(unused_mut)] - fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> { - let state = place.get::(); - - let layout = Layout::new::(); - let nested_place = place.next(layout.size()); + fn merge_result(&self, place: AggrState, builder: &mut ColumnBuilder) -> Result<()> { + let state = Self::get_state(place); + let nested_place = place.remove_first_loc(); // faster path for count if self.nested.name() == "AggregateCountFunction" { @@ -141,14 +152,12 @@ where State: DistinctStateFunc true } - unsafe fn drop_state(&self, place: StateAddr) { - let state = place.get::(); + unsafe fn drop_state(&self, place: AggrState) { + let state = Self::get_state(place); std::ptr::drop_in_place(state); if self.nested.need_manual_drop_state() { - let layout = Layout::new::(); - let nested_place = place.next(layout.size()); - self.nested.drop_state(nested_place); + self.nested.drop_state(place.remove_first_loc()); } } diff --git a/src/query/functions/src/aggregates/aggregate_combinator_if.rs b/src/query/functions/src/aggregates/aggregate_combinator_if.rs index 490976762be1f..45abaabe5d07e 100644 --- a/src/query/functions/src/aggregates/aggregate_combinator_if.rs +++ b/src/query/functions/src/aggregates/aggregate_combinator_if.rs @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::alloc::Layout; use std::fmt; use std::sync::Arc; @@ -22,6 +21,7 @@ use databend_common_expression::types::Bitmap; use databend_common_expression::types::BooleanType; use databend_common_expression::types::DataType; use databend_common_expression::types::ValueType; +use databend_common_expression::AggrStateRegistry; use databend_common_expression::Column; use databend_common_expression::ColumnBuilder; use databend_common_expression::InputColumns; @@ -30,6 +30,8 @@ use databend_common_expression::Scalar; use super::StateAddr; use crate::aggregates::aggregate_function_factory::AggregateFunctionCreator; use crate::aggregates::aggregate_function_factory::CombinatorDescription; +use crate::aggregates::AggrState; +use crate::aggregates::AggrStateLoc; use crate::aggregates::AggregateFunction; use crate::aggregates::AggregateFunctionRef; use crate::aggregates::StateAddrs; @@ -92,17 +94,17 @@ impl AggregateFunction for AggregateIfCombinator { self.nested.return_type() } - fn init_state(&self, place: StateAddr) { + fn init_state(&self, place: AggrState) { self.nested.init_state(place); } - fn state_layout(&self) -> Layout { - self.nested.state_layout() + fn register_state(&self, registry: &mut AggrStateRegistry) { + self.nested.register_state(registry); } fn accumulate( &self, - place: StateAddr, + place: AggrState, columns: InputColumns, validity: Option<&Bitmap>, input_rows: usize, @@ -125,7 +127,7 @@ impl AggregateFunction for AggregateIfCombinator { fn accumulate_keys( &self, places: &[StateAddr], - offset: usize, + loc: &[AggrStateLoc], columns: InputColumns, _input_rows: usize, ) -> Result<()> { @@ -137,10 +139,10 @@ impl AggregateFunction for AggregateIfCombinator { let new_places_slice = new_places.as_slice(); self.nested - .accumulate_keys(new_places_slice, offset, (&columns).into(), row_size) + .accumulate_keys(new_places_slice, loc, (&columns).into(), row_size) } - fn accumulate_row(&self, place: StateAddr, columns: InputColumns, row: usize) -> Result<()> { + fn accumulate_row(&self, place: AggrState, columns: InputColumns, row: usize) -> Result<()> { let predicate: Bitmap = BooleanType::try_downcast_column(&columns[self.argument_len - 1]).unwrap(); if predicate.get_bit(row) { @@ -150,19 +152,19 @@ impl AggregateFunction for AggregateIfCombinator { Ok(()) } - fn serialize(&self, place: StateAddr, writer: &mut Vec) -> Result<()> { + fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { self.nested.serialize(place, writer) } - fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { + fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { self.nested.merge(place, reader) } - fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { + fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { self.nested.merge_states(place, rhs) } - fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> { + fn merge_result(&self, place: AggrState, builder: &mut ColumnBuilder) -> Result<()> { self.nested.merge_result(place, builder) } @@ -170,7 +172,7 @@ impl AggregateFunction for AggregateIfCombinator { self.nested.need_manual_drop_state() } - unsafe fn drop_state(&self, place: StateAddr) { + unsafe fn drop_state(&self, place: AggrState) { self.nested.drop_state(place); } diff --git a/src/query/functions/src/aggregates/aggregate_combinator_state.rs b/src/query/functions/src/aggregates/aggregate_combinator_state.rs index 5163b0dee77b0..88cece53cb8cb 100644 --- a/src/query/functions/src/aggregates/aggregate_combinator_state.rs +++ b/src/query/functions/src/aggregates/aggregate_combinator_state.rs @@ -12,13 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::alloc::Layout; use std::fmt; use std::sync::Arc; use databend_common_exception::Result; use databend_common_expression::types::Bitmap; use databend_common_expression::types::DataType; +use databend_common_expression::AggrStateRegistry; use databend_common_expression::ColumnBuilder; use databend_common_expression::InputColumns; use databend_common_expression::Scalar; @@ -27,6 +27,8 @@ use super::AggregateFunctionFactory; use super::StateAddr; use crate::aggregates::aggregate_function_factory::AggregateFunctionCreator; use crate::aggregates::aggregate_function_factory::CombinatorDescription; +use crate::aggregates::AggrState; +use crate::aggregates::AggrStateLoc; use crate::aggregates::AggregateFunction; use crate::aggregates::AggregateFunctionRef; @@ -50,9 +52,7 @@ impl AggregateStateCombinator { .join(", "); let name = format!("StateCombinator({nested_name}, {arg_name})"); - let nested = AggregateFunctionFactory::instance().get(nested_name, params, arguments)?; - Ok(Arc::new(AggregateStateCombinator { name, nested })) } @@ -70,21 +70,17 @@ impl AggregateFunction for AggregateStateCombinator { Ok(DataType::Binary) } - fn init_state(&self, place: StateAddr) { + fn init_state(&self, place: AggrState) { self.nested.init_state(place); } - fn is_state(&self) -> bool { - true - } - - fn state_layout(&self) -> Layout { - self.nested.state_layout() + fn register_state(&self, registry: &mut AggrStateRegistry) { + self.nested.register_state(registry); } fn accumulate( &self, - place: StateAddr, + place: AggrState, columns: InputColumns, validity: Option<&Bitmap>, input_rows: usize, @@ -95,34 +91,35 @@ impl AggregateFunction for AggregateStateCombinator { fn accumulate_keys( &self, places: &[StateAddr], - offset: usize, + loc: &[AggrStateLoc], columns: InputColumns, input_rows: usize, ) -> Result<()> { self.nested - .accumulate_keys(places, offset, columns, input_rows) + .accumulate_keys(places, loc, columns, input_rows) } - fn accumulate_row(&self, place: StateAddr, columns: InputColumns, row: usize) -> Result<()> { + fn accumulate_row(&self, place: AggrState, columns: InputColumns, row: usize) -> Result<()> { self.nested.accumulate_row(place, columns, row) } - fn serialize(&self, place: StateAddr, writer: &mut Vec) -> Result<()> { + fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { self.nested.serialize(place, writer) } - fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { + #[inline] + fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { self.nested.merge(place, reader) } - fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { + fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { self.nested.merge_states(place, rhs) } - fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> { - let str_builder = builder.as_binary_mut().unwrap(); - self.serialize(place, &mut str_builder.data)?; - str_builder.commit_row(); + fn merge_result(&self, place: AggrState, builder: &mut ColumnBuilder) -> Result<()> { + let builder = builder.as_binary_mut().unwrap(); + self.nested.serialize(place, &mut builder.data)?; + builder.commit_row(); Ok(()) } @@ -130,7 +127,7 @@ impl AggregateFunction for AggregateStateCombinator { self.nested.need_manual_drop_state() } - unsafe fn drop_state(&self, place: StateAddr) { + unsafe fn drop_state(&self, place: AggrState) { self.nested.drop_state(place); } diff --git a/src/query/functions/src/aggregates/aggregate_count.rs b/src/query/functions/src/aggregates/aggregate_count.rs index 7cabda3a9d1ad..282b25dae8461 100644 --- a/src/query/functions/src/aggregates/aggregate_count.rs +++ b/src/query/functions/src/aggregates/aggregate_count.rs @@ -22,6 +22,8 @@ use databend_common_expression::types::Bitmap; use databend_common_expression::types::DataType; use databend_common_expression::types::NumberDataType; use databend_common_expression::utils::column_merge_validity; +use databend_common_expression::AggrStateRegistry; +use databend_common_expression::AggrStateType; use databend_common_expression::Column; use databend_common_expression::ColumnBuilder; use databend_common_expression::InputColumns; @@ -33,6 +35,8 @@ use super::borsh_deserialize_state; use super::borsh_serialize_state; use super::StateAddr; use crate::aggregates::aggregator_common::assert_variadic_arguments; +use crate::aggregates::AggrState; +use crate::aggregates::AggrStateLoc; struct AggregateCountState { count: u64, @@ -74,19 +78,19 @@ impl AggregateFunction for AggregateCountFunction { Ok(DataType::Number(NumberDataType::UInt64)) } - fn init_state(&self, place: StateAddr) { + fn init_state(&self, place: AggrState) { place.write(|| AggregateCountState { count: 0 }); } - fn state_layout(&self) -> Layout { - Layout::new::() + fn register_state(&self, registry: &mut AggrStateRegistry) { + registry.register(AggrStateType::Custom(Layout::new::())); } // columns may be nullable // if not we use validity as the null signs fn accumulate( &self, - place: StateAddr, + place: AggrState, columns: InputColumns, validity: Option<&Bitmap>, input_rows: usize, @@ -110,7 +114,7 @@ impl AggregateFunction for AggregateCountFunction { fn accumulate_keys( &self, places: &[StateAddr], - offset: usize, + loc: &[AggrStateLoc], columns: InputColumns, _input_rows: usize, ) -> Result<()> { @@ -124,9 +128,9 @@ impl AggregateFunction for AggregateCountFunction { if v.null_count() == v.len() { return Ok(()); } - for (valid, place) in v.iter().zip(places.iter()) { + for (valid, &place) in v.iter().zip(places.iter()) { if valid { - let state = place.next(offset).get::(); + let state = AggrState::new(place, loc).get::(); state.count += 1; } } @@ -134,7 +138,7 @@ impl AggregateFunction for AggregateCountFunction { _ => { for place in places { - let state = place.next(offset).get::(); + let state = AggrState::new(*place, loc).get::(); state.count += 1; } } @@ -143,25 +147,25 @@ impl AggregateFunction for AggregateCountFunction { Ok(()) } - fn accumulate_row(&self, place: StateAddr, _columns: InputColumns, _row: usize) -> Result<()> { + fn accumulate_row(&self, place: AggrState, _columns: InputColumns, _row: usize) -> Result<()> { let state = place.get::(); state.count += 1; Ok(()) } - fn serialize(&self, place: StateAddr, writer: &mut Vec) -> Result<()> { + fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { let state = place.get::(); borsh_serialize_state(writer, &state.count) } - fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { + fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { let state = place.get::(); let other: u64 = borsh_deserialize_state(reader)?; state.count += other; Ok(()) } - fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { + fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { let state = place.get::(); let other = rhs.get::(); state.count += other.count; @@ -171,13 +175,13 @@ impl AggregateFunction for AggregateCountFunction { fn batch_merge_result( &self, places: &[StateAddr], - offset: usize, + loc: Box<[AggrStateLoc]>, builder: &mut ColumnBuilder, ) -> Result<()> { match builder { ColumnBuilder::Number(NumberColumnBuilder::UInt64(builder)) => { for place in places { - let state = place.next(offset).get::(); + let state = AggrState::new(*place, &loc).get::(); builder.push(state.count); } } @@ -186,7 +190,7 @@ impl AggregateFunction for AggregateCountFunction { Ok(()) } - fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> { + fn merge_result(&self, place: AggrState, builder: &mut ColumnBuilder) -> Result<()> { match builder { ColumnBuilder::Number(NumberColumnBuilder::UInt64(builder)) => { let state = place.get::(); diff --git a/src/query/functions/src/aggregates/aggregate_covariance.rs b/src/query/functions/src/aggregates/aggregate_covariance.rs index 230b964b24b1b..ea90a7e463866 100644 --- a/src/query/functions/src/aggregates/aggregate_covariance.rs +++ b/src/query/functions/src/aggregates/aggregate_covariance.rs @@ -29,6 +29,8 @@ use databend_common_expression::types::NumberDataType; use databend_common_expression::types::NumberType; use databend_common_expression::types::ValueType; use databend_common_expression::with_number_mapped_type; +use databend_common_expression::AggrStateRegistry; +use databend_common_expression::AggrStateType; use databend_common_expression::ColumnBuilder; use databend_common_expression::InputColumns; use databend_common_expression::Scalar; @@ -39,6 +41,8 @@ use super::borsh_serialize_state; use super::StateAddr; use crate::aggregates::aggregate_function_factory::AggregateFunctionDescription; use crate::aggregates::aggregator_common::assert_binary_arguments; +use crate::aggregates::AggrState; +use crate::aggregates::AggrStateLoc; use crate::aggregates::AggregateFunction; use crate::aggregates::AggregateFunctionRef; @@ -149,7 +153,7 @@ where Ok(DataType::Number(NumberDataType::Float64)) } - fn init_state(&self, place: StateAddr) { + fn init_state(&self, place: AggrState) { place.write(|| AggregateCovarianceState { count: 0, left_mean: 0.0, @@ -158,13 +162,15 @@ where }); } - fn state_layout(&self) -> Layout { - Layout::new::() + fn register_state(&self, registry: &mut AggrStateRegistry) { + registry.register(AggrStateType::Custom( + Layout::new::(), + )); } fn accumulate( &self, - place: StateAddr, + place: AggrState, columns: InputColumns, validity: Option<&Bitmap>, _input_rows: usize, @@ -197,7 +203,7 @@ where fn accumulate_keys( &self, places: &[StateAddr], - offset: usize, + loc: &[AggrStateLoc], columns: InputColumns, _input_rows: usize, ) -> Result<()> { @@ -206,15 +212,14 @@ where left.iter().zip(right.iter()).zip(places.iter()).for_each( |((left_val, right_val), place)| { - let place = place.next(offset); - let state = place.get::(); + let state = AggrState::new(*place, loc).get::(); state.add(left_val.as_(), right_val.as_()); }, ); Ok(()) } - fn accumulate_row(&self, place: StateAddr, columns: InputColumns, row: usize) -> Result<()> { + fn accumulate_row(&self, place: AggrState, columns: InputColumns, row: usize) -> Result<()> { let left = NumberType::::try_downcast_column(&columns[0]).unwrap(); let right = NumberType::::try_downcast_column(&columns[1]).unwrap(); @@ -226,19 +231,19 @@ where Ok(()) } - fn serialize(&self, place: StateAddr, writer: &mut Vec) -> Result<()> { + fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { let state = place.get::(); borsh_serialize_state(writer, state) } - fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { + fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { let state = place.get::(); let rhs: AggregateCovarianceState = borsh_deserialize_state(reader)?; state.merge(&rhs); Ok(()) } - fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { + fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { let state = place.get::(); let other = rhs.get::(); state.merge(other); @@ -246,7 +251,7 @@ where } #[allow(unused_mut)] - fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> { + fn merge_result(&self, place: AggrState, builder: &mut ColumnBuilder) -> Result<()> { let state = place.get::(); let builder = NumberType::::try_downcast_builder(builder).unwrap(); builder.push(R::apply(state).into()); diff --git a/src/query/functions/src/aggregates/aggregate_function_factory.rs b/src/query/functions/src/aggregates/aggregate_function_factory.rs index e8cbd31ef2bbc..78c669ba4c317 100644 --- a/src/query/functions/src/aggregates/aggregate_function_factory.rs +++ b/src/query/functions/src/aggregates/aggregate_function_factory.rs @@ -189,33 +189,30 @@ impl AggregateFunctionFactory { return Ok(agg); } - if !arguments.is_empty() && arguments.iter().any(|f| f.is_nullable_or_null()) { - let (new_params, new_arguments) = match name.to_lowercase().strip_suffix(STATE_SUFFIX) { - Some(_) => (params.clone(), arguments.clone()), - None => { - let new_params = AggregateFunctionCombinatorNull::transform_params(¶ms)?; - let new_arguments = - AggregateFunctionCombinatorNull::transform_arguments(&arguments)?; - (new_params, new_arguments) - } - }; - - let nested = self.get_impl(name, new_params, new_arguments, &mut features)?; - let agg = AggregateFunctionCombinatorNull::try_create( - name, - params, - arguments, - nested, - features.clone(), - )?; - if or_null { - return AggregateFunctionOrNullAdaptor::create(agg, features); + if arguments.iter().all(|f| !f.is_nullable_or_null()) { + let agg = self.get_impl(name, params, arguments, &mut features)?; + return if or_null { + AggregateFunctionOrNullAdaptor::create(agg, features) } else { - return Ok(agg); - } + Ok(agg) + }; } - let agg = self.get_impl(name, params, arguments, &mut features)?; + let nested = if name.to_lowercase().strip_suffix(STATE_SUFFIX).is_some() { + self.get_impl(name, params.clone(), arguments.clone(), &mut features)? + } else { + let new_params = AggregateFunctionCombinatorNull::transform_params(¶ms)?; + let new_arguments = AggregateFunctionCombinatorNull::transform_arguments(&arguments)?; + self.get_impl(name, new_params, new_arguments, &mut features)? + }; + + let agg = AggregateFunctionCombinatorNull::try_create( + name, + params, + arguments, + nested, + features.clone(), + )?; if or_null { AggregateFunctionOrNullAdaptor::create(agg, features) } else { diff --git a/src/query/functions/src/aggregates/aggregate_json_array_agg.rs b/src/query/functions/src/aggregates/aggregate_json_array_agg.rs index 0365e989f7974..85a527cc8578d 100644 --- a/src/query/functions/src/aggregates/aggregate_json_array_agg.rs +++ b/src/query/functions/src/aggregates/aggregate_json_array_agg.rs @@ -26,6 +26,8 @@ use databend_common_expression::types::Bitmap; use databend_common_expression::types::DataType; use databend_common_expression::types::ValueType; use databend_common_expression::types::*; +use databend_common_expression::AggrStateRegistry; +use databend_common_expression::AggrStateType; use databend_common_expression::Column; use databend_common_expression::ColumnBuilder; use databend_common_expression::InputColumns; @@ -38,6 +40,8 @@ use super::borsh_deserialize_state; use super::borsh_serialize_state; use super::StateAddr; use crate::aggregates::assert_unary_arguments; +use crate::aggregates::AggrState; +use crate::aggregates::AggrStateLoc; use crate::aggregates::AggregateFunction; #[derive(BorshSerialize, BorshDeserialize, Debug)] @@ -144,17 +148,17 @@ where Ok(self.return_type.clone()) } - fn init_state(&self, place: StateAddr) { - place.write(|| State::new()); + fn init_state(&self, place: AggrState) { + place.write(State::new); } - fn state_layout(&self) -> Layout { - Layout::new::() + fn register_state(&self, registry: &mut AggrStateRegistry) { + registry.register(AggrStateType::Custom(Layout::new::())); } fn accumulate( &self, - place: StateAddr, + place: AggrState, columns: InputColumns, _validity: Option<&Bitmap>, _input_rows: usize, @@ -175,7 +179,7 @@ where fn accumulate_keys( &self, places: &[StateAddr], - offset: usize, + loc: &[AggrStateLoc], columns: InputColumns, _input_rows: usize, ) -> Result<()> { @@ -186,8 +190,7 @@ where column_iter .zip(nullable_column.validity.iter().zip(places.iter())) .for_each(|(v, (valid, place))| { - let addr = place.next(offset); - let state = addr.get::(); + let state = AggrState::new(*place, loc).get::(); if valid { state.add(Some(v.clone())) } else { @@ -199,8 +202,7 @@ where let column = T::try_downcast_column(&columns[0]).unwrap(); let column_iter = T::iter_column(&column); column_iter.zip(places.iter()).for_each(|(v, place)| { - let addr = place.next(offset); - let state = addr.get::(); + let state = AggrState::new(*place, loc).get::(); state.add(Some(v.clone())) }); } @@ -209,7 +211,7 @@ where Ok(()) } - fn accumulate_row(&self, place: StateAddr, columns: InputColumns, row: usize) -> Result<()> { + fn accumulate_row(&self, place: AggrState, columns: InputColumns, row: usize) -> Result<()> { let state = place.get::(); match &columns[0] { Column::Nullable(box nullable_column) => { @@ -232,25 +234,25 @@ where Ok(()) } - fn serialize(&self, place: StateAddr, writer: &mut Vec) -> Result<()> { + fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { let state = place.get::(); borsh_serialize_state(writer, state) } - fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { + fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { let state = place.get::(); let rhs: State = borsh_deserialize_state(reader)?; state.merge(&rhs) } - fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { + fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { let state = place.get::(); let other = rhs.get::(); state.merge(other) } - fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> { + fn merge_result(&self, place: AggrState, builder: &mut ColumnBuilder) -> Result<()> { let state = place.get::(); state.merge_result(builder) } @@ -259,7 +261,7 @@ where true } - unsafe fn drop_state(&self, place: StateAddr) { + unsafe fn drop_state(&self, place: AggrState) { let state = place.get::(); std::ptr::drop_in_place(state); } diff --git a/src/query/functions/src/aggregates/aggregate_json_object_agg.rs b/src/query/functions/src/aggregates/aggregate_json_object_agg.rs index 80f81e23dc823..6c80c7984eb66 100644 --- a/src/query/functions/src/aggregates/aggregate_json_object_agg.rs +++ b/src/query/functions/src/aggregates/aggregate_json_object_agg.rs @@ -29,6 +29,8 @@ use databend_common_expression::types::Bitmap; use databend_common_expression::types::DataType; use databend_common_expression::types::ValueType; use databend_common_expression::types::*; +use databend_common_expression::AggrStateRegistry; +use databend_common_expression::AggrStateType; use databend_common_expression::Column; use databend_common_expression::ColumnBuilder; use databend_common_expression::InputColumns; @@ -40,6 +42,8 @@ use super::borsh_deserialize_state; use super::borsh_serialize_state; use super::StateAddr; use crate::aggregates::assert_binary_arguments; +use crate::aggregates::AggrState; +use crate::aggregates::AggrStateLoc; use crate::aggregates::AggregateFunction; pub trait BinaryScalarStateFunc: @@ -209,17 +213,17 @@ where Ok(self.return_type.clone()) } - fn init_state(&self, place: StateAddr) { - place.write(|| State::new()); + fn init_state(&self, place: AggrState) { + place.write(State::new); } - fn state_layout(&self) -> Layout { - Layout::new::() + fn register_state(&self, registry: &mut AggrStateRegistry) { + registry.register(AggrStateType::Custom(Layout::new::())); } fn accumulate( &self, - place: StateAddr, + place: AggrState, columns: InputColumns, _validity: Option<&Bitmap>, _input_rows: usize, @@ -234,7 +238,7 @@ where fn accumulate_keys( &self, places: &[StateAddr], - offset: usize, + loc: &[AggrStateLoc], columns: InputColumns, _input_rows: usize, ) -> Result<()> { @@ -246,8 +250,7 @@ where for (k, (v, (valid, place))) in key_column_iter.zip(val_column_iter.zip(validity.iter().zip(places.iter()))) { - let addr = place.next(offset); - let state = addr.get::(); + let state = AggrState::new(*place, loc).get::(); if valid { state.add(Some((k, v.clone())))?; } else { @@ -256,8 +259,7 @@ where } } else { for (k, (v, place)) in key_column_iter.zip(val_column_iter.zip(places.iter())) { - let addr = place.next(offset); - let state = addr.get::(); + let state = AggrState::new(*place, loc).get::(); state.add(Some((k, v.clone())))?; } } @@ -265,7 +267,7 @@ where Ok(()) } - fn accumulate_row(&self, place: StateAddr, columns: InputColumns, row: usize) -> Result<()> { + fn accumulate_row(&self, place: AggrState, columns: InputColumns, row: usize) -> Result<()> { let state = place.get::(); let (key_column, val_column, validity) = self.downcast_columns(columns)?; @@ -285,25 +287,25 @@ where Ok(()) } - fn serialize(&self, place: StateAddr, writer: &mut Vec) -> Result<()> { + fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { let state = place.get::(); borsh_serialize_state(writer, state) } - fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { + fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { let state = place.get::(); let rhs: State = borsh_deserialize_state(reader)?; state.merge(&rhs) } - fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { + fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { let state = place.get::(); let other = rhs.get::(); state.merge(other) } - fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> { + fn merge_result(&self, place: AggrState, builder: &mut ColumnBuilder) -> Result<()> { let state = place.get::(); state.merge_result(builder) } @@ -312,7 +314,7 @@ where true } - unsafe fn drop_state(&self, place: StateAddr) { + unsafe fn drop_state(&self, place: AggrState) { let state = place.get::(); std::ptr::drop_in_place(state); } diff --git a/src/query/functions/src/aggregates/aggregate_null_result.rs b/src/query/functions/src/aggregates/aggregate_null_result.rs index 522e6dd4824ae..d07d9bd045a42 100644 --- a/src/query/functions/src/aggregates/aggregate_null_result.rs +++ b/src/query/functions/src/aggregates/aggregate_null_result.rs @@ -21,11 +21,15 @@ use databend_common_expression::types::AnyType; use databend_common_expression::types::Bitmap; use databend_common_expression::types::DataType; use databend_common_expression::types::ValueType; +use databend_common_expression::AggrStateRegistry; +use databend_common_expression::AggrStateType; use databend_common_expression::ColumnBuilder; use databend_common_expression::InputColumns; use super::aggregate_function::AggregateFunction; use super::StateAddr; +use crate::aggregates::AggrState; +use crate::aggregates::AggrStateLoc; #[derive(Clone)] pub struct AggregateNullResultFunction { @@ -47,15 +51,15 @@ impl AggregateFunction for AggregateNullResultFunction { Ok(self.data_type.clone()) } - fn init_state(&self, __place: StateAddr) {} + fn init_state(&self, _place: AggrState) {} - fn state_layout(&self) -> Layout { - Layout::new::() + fn register_state(&self, registry: &mut AggrStateRegistry) { + registry.register(AggrStateType::Custom(Layout::new::())); } fn accumulate( &self, - __place: StateAddr, + _place: AggrState, _columns: InputColumns, _validity: Option<&Bitmap>, _input_rows: usize, @@ -66,30 +70,30 @@ impl AggregateFunction for AggregateNullResultFunction { fn accumulate_keys( &self, _places: &[StateAddr], - _offset: usize, + _loc: &[AggrStateLoc], _columns: InputColumns, _input_rows: usize, ) -> Result<()> { Ok(()) } - fn accumulate_row(&self, _place: StateAddr, _columns: InputColumns, _row: usize) -> Result<()> { + fn accumulate_row(&self, _place: AggrState, _columns: InputColumns, _row: usize) -> Result<()> { Ok(()) } - fn serialize(&self, _place: StateAddr, _writer: &mut Vec) -> Result<()> { + fn serialize(&self, _place: AggrState, _writer: &mut Vec) -> Result<()> { Ok(()) } - fn merge(&self, _place: StateAddr, _reader: &mut &[u8]) -> Result<()> { + fn merge(&self, _place: AggrState, _reader: &mut &[u8]) -> Result<()> { Ok(()) } - fn merge_states(&self, _place: StateAddr, _rhs: StateAddr) -> Result<()> { + fn merge_states(&self, _place: AggrState, _rhs: AggrState) -> Result<()> { Ok(()) } - fn merge_result(&self, _place: StateAddr, array: &mut ColumnBuilder) -> Result<()> { + fn merge_result(&self, _place: AggrState, array: &mut ColumnBuilder) -> Result<()> { AnyType::push_default(array); Ok(()) } diff --git a/src/query/functions/src/aggregates/aggregate_quantile_tdigest.rs b/src/query/functions/src/aggregates/aggregate_quantile_tdigest.rs index 0cd2ab6110865..d999a6c448045 100644 --- a/src/query/functions/src/aggregates/aggregate_quantile_tdigest.rs +++ b/src/query/functions/src/aggregates/aggregate_quantile_tdigest.rs @@ -28,6 +28,8 @@ use databend_common_expression::types::number::*; use databend_common_expression::types::Bitmap; use databend_common_expression::types::*; use databend_common_expression::with_number_mapped_type; +use databend_common_expression::AggrStateRegistry; +use databend_common_expression::AggrStateType; use databend_common_expression::ColumnBuilder; use databend_common_expression::Expr; use databend_common_expression::FunctionContext; @@ -42,6 +44,8 @@ use super::borsh_serialize_state; use crate::aggregates::aggregate_function_factory::AggregateFunctionDescription; use crate::aggregates::assert_params; use crate::aggregates::assert_unary_arguments; +use crate::aggregates::AggrState; +use crate::aggregates::AggrStateLoc; use crate::aggregates::AggregateFunction; use crate::aggregates::AggregateFunctionRef; use crate::aggregates::StateAddr; @@ -302,15 +306,15 @@ where T: Number + AsPrimitive fn return_type(&self) -> Result { Ok(self.return_type.clone()) } - fn init_state(&self, place: StateAddr) { + fn init_state(&self, place: AggrState) { place.write(QuantileTDigestState::new) } - fn state_layout(&self) -> Layout { - Layout::new::() + fn register_state(&self, registry: &mut AggrStateRegistry) { + registry.register(AggrStateType::Custom(Layout::new::())); } fn accumulate( &self, - place: StateAddr, + place: AggrState, columns: InputColumns, validity: Option<&Bitmap>, _input_rows: usize, @@ -334,7 +338,7 @@ where T: Number + AsPrimitive Ok(()) } - fn accumulate_row(&self, place: StateAddr, columns: InputColumns, row: usize) -> Result<()> { + fn accumulate_row(&self, place: AggrState, columns: InputColumns, row: usize) -> Result<()> { let column = NumberType::::try_downcast_column(&columns[0]).unwrap(); let v = NumberType::::index_column(&column, row); if let Some(v) = v { @@ -346,36 +350,35 @@ where T: Number + AsPrimitive fn accumulate_keys( &self, places: &[StateAddr], - offset: usize, + loc: &[AggrStateLoc], columns: InputColumns, _input_rows: usize, ) -> Result<()> { let column = NumberType::::try_downcast_column(&columns[0]).unwrap(); column.iter().zip(places.iter()).for_each(|(v, place)| { - let addr = place.next(offset); - let state = addr.get::(); + let state = AggrState::new(*place, loc).get::(); state.add(v.as_(), None) }); Ok(()) } - fn serialize(&self, place: StateAddr, writer: &mut Vec) -> Result<()> { + fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { let state = place.get::(); borsh_serialize_state(writer, state) } - fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { + fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { let state = place.get::(); let mut rhs: QuantileTDigestState = borsh_deserialize_state(reader)?; state.merge(&mut rhs) } - fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { + fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { let state = place.get::(); let other = rhs.get::(); state.merge(other) } - fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> { + fn merge_result(&self, place: AggrState, builder: &mut ColumnBuilder) -> Result<()> { let state = place.get::(); state.merge_result(builder, self.levels.clone()) } @@ -384,7 +387,7 @@ where T: Number + AsPrimitive true } - unsafe fn drop_state(&self, place: StateAddr) { + unsafe fn drop_state(&self, place: AggrState) { let state = place.get::(); std::ptr::drop_in_place(state); } diff --git a/src/query/functions/src/aggregates/aggregate_quantile_tdigest_weighted.rs b/src/query/functions/src/aggregates/aggregate_quantile_tdigest_weighted.rs index c46226ec64a9f..761f49714974b 100644 --- a/src/query/functions/src/aggregates/aggregate_quantile_tdigest_weighted.rs +++ b/src/query/functions/src/aggregates/aggregate_quantile_tdigest_weighted.rs @@ -26,6 +26,8 @@ use databend_common_expression::types::Bitmap; use databend_common_expression::types::*; use databend_common_expression::with_number_mapped_type; use databend_common_expression::with_unsigned_integer_mapped_type; +use databend_common_expression::AggrStateRegistry; +use databend_common_expression::AggrStateType; use databend_common_expression::ColumnBuilder; use databend_common_expression::Expr; use databend_common_expression::FunctionContext; @@ -41,6 +43,8 @@ use crate::aggregates::aggregate_quantile_tdigest::MEDIAN; use crate::aggregates::aggregate_quantile_tdigest::QUANTILE; use crate::aggregates::assert_binary_arguments; use crate::aggregates::assert_params; +use crate::aggregates::AggrState; +use crate::aggregates::AggrStateLoc; use crate::aggregates::AggregateFunction; use crate::aggregates::AggregateFunctionRef; use crate::aggregates::StateAddr; @@ -77,15 +81,15 @@ where fn return_type(&self) -> Result { Ok(self.return_type.clone()) } - fn init_state(&self, place: StateAddr) { + fn init_state(&self, place: AggrState) { place.write(QuantileTDigestState::new) } - fn state_layout(&self) -> Layout { - Layout::new::() + fn register_state(&self, registry: &mut AggrStateRegistry) { + registry.register(AggrStateType::Custom(Layout::new::())); } fn accumulate( &self, - place: StateAddr, + place: AggrState, columns: InputColumns, validity: Option<&Bitmap>, _input_rows: usize, @@ -112,7 +116,7 @@ where Ok(()) } - fn accumulate_row(&self, place: StateAddr, columns: InputColumns, row: usize) -> Result<()> { + fn accumulate_row(&self, place: AggrState, columns: InputColumns, row: usize) -> Result<()> { let column = NumberType::::try_downcast_column(&columns[0]).unwrap(); let weighted = NumberType::::try_downcast_column(&columns[1]).unwrap(); let value = unsafe { column.get_unchecked(row) }; @@ -125,7 +129,7 @@ where fn accumulate_keys( &self, places: &[StateAddr], - offset: usize, + loc: &[AggrStateLoc], columns: InputColumns, _input_rows: usize, ) -> Result<()> { @@ -136,30 +140,29 @@ where .zip(weighted.iter()) .zip(places.iter()) .for_each(|((value, weight), place)| { - let addr = place.next(offset); - let state = addr.get::(); + let state = AggrState::new(*place, loc).get::(); state.add(value.as_(), Some(weight.as_())) }); Ok(()) } - fn serialize(&self, place: StateAddr, writer: &mut Vec) -> Result<()> { + fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { let state = place.get::(); borsh_serialize_state(writer, state) } - fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { + fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { let state = place.get::(); let mut rhs: QuantileTDigestState = borsh_deserialize_state(reader)?; state.merge(&mut rhs) } - fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { + fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { let state = place.get::(); let other = rhs.get::(); state.merge(other) } - fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> { + fn merge_result(&self, place: AggrState, builder: &mut ColumnBuilder) -> Result<()> { let state = place.get::(); state.merge_result(builder, self.levels.clone()) } @@ -168,7 +171,7 @@ where true } - unsafe fn drop_state(&self, place: StateAddr) { + unsafe fn drop_state(&self, place: AggrState) { let state = place.get::(); std::ptr::drop_in_place(state); } diff --git a/src/query/functions/src/aggregates/aggregate_retention.rs b/src/query/functions/src/aggregates/aggregate_retention.rs index d3d2d427beb9e..afc12910c58a8 100644 --- a/src/query/functions/src/aggregates/aggregate_retention.rs +++ b/src/query/functions/src/aggregates/aggregate_retention.rs @@ -25,6 +25,8 @@ use databend_common_expression::types::BooleanType; use databend_common_expression::types::DataType; use databend_common_expression::types::NumberDataType; use databend_common_expression::types::ValueType; +use databend_common_expression::AggrStateRegistry; +use databend_common_expression::AggrStateType; use databend_common_expression::ColumnBuilder; use databend_common_expression::InputColumns; use databend_common_expression::Scalar; @@ -36,6 +38,8 @@ use super::borsh_deserialize_state; use super::borsh_serialize_state; use super::StateAddr; use crate::aggregates::aggregator_common::assert_variadic_arguments; +use crate::aggregates::AggrState; +use crate::aggregates::AggrStateLoc; #[derive(BorshSerialize, BorshDeserialize)] struct AggregateRetentionState { @@ -70,17 +74,19 @@ impl AggregateFunction for AggregateRetentionFunction { )))) } - fn init_state(&self, place: StateAddr) { + fn init_state(&self, place: AggrState) { place.write(|| AggregateRetentionState { events: 0 }); } - fn state_layout(&self) -> std::alloc::Layout { - Layout::new::() + fn register_state(&self, registry: &mut AggrStateRegistry) { + registry.register(AggrStateType::Custom( + Layout::new::(), + )); } fn accumulate( &self, - place: StateAddr, + place: AggrState, columns: InputColumns, _validity: Option<&Bitmap>, input_rows: usize, @@ -103,7 +109,7 @@ impl AggregateFunction for AggregateRetentionFunction { fn accumulate_keys( &self, places: &[StateAddr], - offset: usize, + loc: &[AggrStateLoc], columns: InputColumns, _input_rows: usize, ) -> Result<()> { @@ -112,8 +118,7 @@ impl AggregateFunction for AggregateRetentionFunction { .map(|col| BooleanType::try_downcast_column(col).unwrap()) .collect::>(); for (row, place) in places.iter().enumerate() { - let place = place.next(offset); - let state = place.get::(); + let state = AggrState::new(*place, loc).get::(); for j in 0..self.events_size { if new_columns[j as usize].get_bit(row) { state.add(j); @@ -123,7 +128,7 @@ impl AggregateFunction for AggregateRetentionFunction { Ok(()) } - fn accumulate_row(&self, place: StateAddr, columns: InputColumns, row: usize) -> Result<()> { + fn accumulate_row(&self, place: AggrState, columns: InputColumns, row: usize) -> Result<()> { let state = place.get::(); let new_columns = columns .iter() @@ -137,19 +142,19 @@ impl AggregateFunction for AggregateRetentionFunction { Ok(()) } - fn serialize(&self, place: StateAddr, writer: &mut Vec) -> Result<()> { + fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { let state = place.get::(); borsh_serialize_state(writer, state) } - fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { + fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { let state = place.get::(); let rhs: AggregateRetentionState = borsh_deserialize_state(reader)?; state.merge(&rhs); Ok(()) } - fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { + fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { let state = place.get::(); let other = rhs.get::(); state.merge(other); @@ -157,7 +162,7 @@ impl AggregateFunction for AggregateRetentionFunction { } #[allow(unused_mut)] - fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> { + fn merge_result(&self, place: AggrState, builder: &mut ColumnBuilder) -> Result<()> { let state = place.get::(); let builder = builder.as_array_mut().unwrap(); let inner = builder diff --git a/src/query/functions/src/aggregates/aggregate_st_collect.rs b/src/query/functions/src/aggregates/aggregate_st_collect.rs index f3078976cdf84..4a98c4415cd72 100644 --- a/src/query/functions/src/aggregates/aggregate_st_collect.rs +++ b/src/query/functions/src/aggregates/aggregate_st_collect.rs @@ -26,6 +26,8 @@ use databend_common_expression::types::Bitmap; use databend_common_expression::types::DataType; use databend_common_expression::types::GeometryType; use databend_common_expression::types::ValueType; +use databend_common_expression::AggrStateRegistry; +use databend_common_expression::AggrStateType; use databend_common_expression::Column; use databend_common_expression::ColumnBuilder; use databend_common_expression::InputColumns; @@ -49,6 +51,8 @@ use super::borsh_deserialize_state; use super::borsh_serialize_state; use super::StateAddr; use crate::aggregates::assert_unary_arguments; +use crate::aggregates::AggrState; +use crate::aggregates::AggrStateLoc; use crate::aggregates::AggregateFunction; #[derive(BorshSerialize, BorshDeserialize, Debug)] @@ -208,17 +212,17 @@ where Ok(self.return_type.clone()) } - fn init_state(&self, place: StateAddr) { - place.write(|| State::new()); + fn init_state(&self, place: AggrState) { + place.write(State::new); } - fn state_layout(&self) -> Layout { - Layout::new::() + fn register_state(&self, registry: &mut AggrStateRegistry) { + registry.register(AggrStateType::Custom(Layout::new::())); } fn accumulate( &self, - place: StateAddr, + place: AggrState, columns: InputColumns, _validity: Option<&Bitmap>, _input_rows: usize, @@ -242,7 +246,7 @@ where fn accumulate_keys( &self, places: &[StateAddr], - offset: usize, + loc: &[AggrStateLoc], columns: InputColumns, _input_rows: usize, ) -> Result<()> { @@ -253,8 +257,7 @@ where column_iter .zip(nullable_column.validity.iter().zip(places.iter())) .for_each(|(v, (valid, place))| { - let addr = place.next(offset); - let state = addr.get::(); + let state = AggrState::new(*place, loc).get::(); if valid { state.add(Some(v.clone())) } else { @@ -266,8 +269,7 @@ where if let Some(column) = T::try_downcast_column(&columns[0]) { let column_iter = T::iter_column(&column); column_iter.zip(places.iter()).for_each(|(v, place)| { - let addr = place.next(offset); - let state = addr.get::(); + let state = AggrState::new(*place, loc).get::(); state.add(Some(v.clone())) }); } @@ -277,7 +279,7 @@ where Ok(()) } - fn accumulate_row(&self, place: StateAddr, columns: InputColumns, row: usize) -> Result<()> { + fn accumulate_row(&self, place: AggrState, columns: InputColumns, row: usize) -> Result<()> { let state = place.get::(); match &columns[0] { Column::Nullable(box nullable_column) => { @@ -301,25 +303,25 @@ where Ok(()) } - fn serialize(&self, place: StateAddr, writer: &mut Vec) -> Result<()> { + fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { let state = place.get::(); borsh_serialize_state(writer, state) } - fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { + fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { let state = place.get::(); let rhs: State = borsh_deserialize_state(reader)?; state.merge(&rhs) } - fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { + fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { let state = place.get::(); let other = rhs.get::(); state.merge(other) } - fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> { + fn merge_result(&self, place: AggrState, builder: &mut ColumnBuilder) -> Result<()> { let state = place.get::(); state.merge_result(builder) } @@ -328,7 +330,7 @@ where true } - unsafe fn drop_state(&self, place: StateAddr) { + unsafe fn drop_state(&self, place: AggrState) { let state = place.get::(); std::ptr::drop_in_place(state); } diff --git a/src/query/functions/src/aggregates/aggregate_string_agg.rs b/src/query/functions/src/aggregates/aggregate_string_agg.rs index 2dc27dff33c9e..daa8c7dd77f1f 100644 --- a/src/query/functions/src/aggregates/aggregate_string_agg.rs +++ b/src/query/functions/src/aggregates/aggregate_string_agg.rs @@ -24,6 +24,8 @@ use databend_common_expression::types::Bitmap; use databend_common_expression::types::DataType; use databend_common_expression::types::StringType; use databend_common_expression::types::ValueType; +use databend_common_expression::AggrStateRegistry; +use databend_common_expression::AggrStateType; use databend_common_expression::ColumnBuilder; use databend_common_expression::InputColumns; use databend_common_expression::Scalar; @@ -33,6 +35,8 @@ use super::borsh_deserialize_state; use super::borsh_serialize_state; use super::StateAddr; use crate::aggregates::assert_variadic_arguments; +use crate::aggregates::AggrState; +use crate::aggregates::AggrStateLoc; use crate::aggregates::AggregateFunction; #[derive(BorshSerialize, BorshDeserialize, Debug)] @@ -55,19 +59,19 @@ impl AggregateFunction for AggregateStringAggFunction { Ok(DataType::String) } - fn init_state(&self, place: StateAddr) { + fn init_state(&self, place: AggrState) { place.write(|| StringAggState { values: String::new(), }); } - fn state_layout(&self) -> Layout { - Layout::new::() + fn register_state(&self, registry: &mut AggrStateRegistry) { + registry.register(AggrStateType::Custom(Layout::new::())); } fn accumulate( &self, - place: StateAddr, + place: AggrState, columns: InputColumns, validity: Option<&Bitmap>, _input_rows: usize, @@ -96,22 +100,21 @@ impl AggregateFunction for AggregateStringAggFunction { fn accumulate_keys( &self, places: &[StateAddr], - offset: usize, + loc: &[AggrStateLoc], columns: InputColumns, _input_rows: usize, ) -> Result<()> { let column = StringType::try_downcast_column(&columns[0]).unwrap(); let column_iter = StringType::iter_column(&column); column_iter.zip(places.iter()).for_each(|(v, place)| { - let addr = place.next(offset); - let state = addr.get::(); + let state = AggrState::new(*place, loc).get::(); state.values.push_str(v); state.values.push_str(&self.delimiter); }); Ok(()) } - fn accumulate_row(&self, place: StateAddr, columns: InputColumns, row: usize) -> Result<()> { + fn accumulate_row(&self, place: AggrState, columns: InputColumns, row: usize) -> Result<()> { let column = StringType::try_downcast_column(&columns[0]).unwrap(); let v = StringType::index_column(&column, row); if let Some(v) = v { @@ -122,27 +125,27 @@ impl AggregateFunction for AggregateStringAggFunction { Ok(()) } - fn serialize(&self, place: StateAddr, writer: &mut Vec) -> Result<()> { + fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { let state = place.get::(); borsh_serialize_state(writer, state)?; Ok(()) } - fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { + fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { let state = place.get::(); let rhs: StringAggState = borsh_deserialize_state(reader)?; state.values.push_str(&rhs.values); Ok(()) } - fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { + fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { let state = place.get::(); let other = rhs.get::(); state.values.push_str(&other.values); Ok(()) } - fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> { + fn merge_result(&self, place: AggrState, builder: &mut ColumnBuilder) -> Result<()> { let state = place.get::(); let builder = StringType::try_downcast_builder(builder).unwrap(); if !state.values.is_empty() { @@ -158,7 +161,7 @@ impl AggregateFunction for AggregateStringAggFunction { true } - unsafe fn drop_state(&self, place: StateAddr) { + unsafe fn drop_state(&self, place: AggrState) { let state = place.get::(); std::ptr::drop_in_place(state); } diff --git a/src/query/functions/src/aggregates/aggregate_sum.rs b/src/query/functions/src/aggregates/aggregate_sum.rs index 7f39a503c956e..2cd7bd8c03867 100644 --- a/src/query/functions/src/aggregates/aggregate_sum.rs +++ b/src/query/functions/src/aggregates/aggregate_sum.rs @@ -36,6 +36,7 @@ use super::assert_unary_arguments; use super::FunctionData; use crate::aggregates::aggregate_function_factory::AggregateFunctionDescription; use crate::aggregates::aggregate_unary::UnaryState; +use crate::aggregates::AggrStateLoc; use crate::aggregates::AggregateUnaryFunction; pub trait SumState: BorshSerialize + BorshDeserialize + Send + Sync + Default + 'static { @@ -47,7 +48,7 @@ pub trait SumState: BorshSerialize + BorshDeserialize + Send + Sync + Default + fn accumulate(&mut self, column: &Column, validity: Option<&Bitmap>) -> Result<()>; fn accumulate_row(&mut self, column: &Column, row: usize) -> Result<()>; - fn accumulate_keys(places: &[StateAddr], offset: usize, columns: &Column) -> Result<()>; + fn accumulate_keys(places: &[StateAddr], loc: &[AggrStateLoc], columns: &Column) -> Result<()>; fn merge_result( &mut self, diff --git a/src/query/functions/src/aggregates/aggregate_unary.rs b/src/query/functions/src/aggregates/aggregate_unary.rs index 66aaa29e2ff18..b6e4db40bd397 100644 --- a/src/query/functions/src/aggregates/aggregate_unary.rs +++ b/src/query/functions/src/aggregates/aggregate_unary.rs @@ -25,6 +25,8 @@ use databend_common_expression::types::Bitmap; use databend_common_expression::types::DataType; use databend_common_expression::types::DecimalSize; use databend_common_expression::types::ValueType; +use databend_common_expression::AggrStateRegistry; +use databend_common_expression::AggrStateType; use databend_common_expression::AggregateFunction; use databend_common_expression::AggregateFunctionRef; use databend_common_expression::ColumnBuilder; @@ -32,6 +34,9 @@ use databend_common_expression::InputColumns; use databend_common_expression::Scalar; use databend_common_expression::StateAddr; +use crate::aggregates::AggrState; +use crate::aggregates::AggrStateLoc; + pub trait UnaryState: Send + Sync + Default + borsh::BorshSerialize + borsh::BorshDeserialize where @@ -201,17 +206,17 @@ where Ok(self.return_type.clone()) } - fn init_state(&self, place: StateAddr) { - place.write_state(S::default()) + fn init_state(&self, place: AggrState) { + place.write(S::default); } - fn state_layout(&self) -> Layout { - Layout::new::() + fn register_state(&self, registry: &mut AggrStateRegistry) { + registry.register(AggrStateType::Custom(Layout::new::())); } fn accumulate( &self, - place: StateAddr, + place: AggrState, columns: InputColumns, validity: Option<&Bitmap>, _input_rows: usize, @@ -222,7 +227,7 @@ where state.add_batch(column, validity, self.function_data.as_deref()) } - fn accumulate_row(&self, place: StateAddr, columns: InputColumns, row: usize) -> Result<()> { + fn accumulate_row(&self, place: AggrState, columns: InputColumns, row: usize) -> Result<()> { let column = T::try_downcast_column(&columns[0]).unwrap(); let value = T::index_column(&column, row); @@ -234,14 +239,14 @@ where fn accumulate_keys( &self, places: &[StateAddr], - offset: usize, + loc: &[AggrStateLoc], columns: InputColumns, _input_rows: usize, ) -> Result<()> { let column = T::try_downcast_column(&columns[0]).unwrap(); for (i, place) in places.iter().enumerate() { - let state: &mut S = place.next(offset).get::(); + let state: &mut S = AggrState::new(*place, loc).get::(); state.add( T::index_column(&column, i).unwrap(), self.function_data.as_deref(), @@ -251,24 +256,24 @@ where Ok(()) } - fn serialize(&self, place: StateAddr, writer: &mut Vec) -> Result<()> { + fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { let state: &mut S = place.get::(); Ok(borsh::to_writer(writer, state)?) } - fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { + fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { let state: &mut S = place.get::(); let rhs = S::deserialize_reader(reader)?; state.merge(&rhs) } - fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { + fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { let state: &mut S = place.get::(); let other: &mut S = rhs.get::(); state.merge(other) } - fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> { + fn merge_result(&self, place: AggrState, builder: &mut ColumnBuilder) -> Result<()> { let state: &mut S = place.get::(); self.do_merge_result(state, builder) } @@ -276,11 +281,11 @@ where fn batch_merge_result( &self, places: &[StateAddr], - offset: usize, + loc: Box<[AggrStateLoc]>, builder: &mut ColumnBuilder, ) -> Result<()> { for place in places { - let state: &mut S = place.next(offset).get::(); + let state: &mut S = AggrState::new(*place, &loc).get::(); self.do_merge_result(state, builder)?; } Ok(()) @@ -290,7 +295,7 @@ where self.need_drop } - unsafe fn drop_state(&self, place: StateAddr) { + unsafe fn drop_state(&self, place: AggrState) { let state = place.get::(); std::ptr::drop_in_place(state); } diff --git a/src/query/functions/src/aggregates/aggregate_window_funnel.rs b/src/query/functions/src/aggregates/aggregate_window_funnel.rs index 88ab9c04cf4cb..e242869f32623 100644 --- a/src/query/functions/src/aggregates/aggregate_window_funnel.rs +++ b/src/query/functions/src/aggregates/aggregate_window_funnel.rs @@ -36,6 +36,8 @@ use databend_common_expression::types::NumberType; use databend_common_expression::types::TimestampType; use databend_common_expression::types::ValueType; use databend_common_expression::with_integer_mapped_type; +use databend_common_expression::AggrStateRegistry; +use databend_common_expression::AggrStateType; use databend_common_expression::ColumnBuilder; use databend_common_expression::Expr; use databend_common_expression::FunctionContext; @@ -51,6 +53,8 @@ use super::StateAddr; use crate::aggregates::aggregate_function_factory::AggregateFunctionDescription; use crate::aggregates::assert_unary_params; use crate::aggregates::assert_variadic_arguments; +use crate::aggregates::AggrState; +use crate::aggregates::AggrStateLoc; use crate::aggregates::AggregateFunction; use crate::BUILTIN_FUNCTIONS; @@ -177,17 +181,19 @@ where Ok(DataType::Number(NumberDataType::UInt8)) } - fn init_state(&self, place: StateAddr) { + fn init_state(&self, place: AggrState) { place.write(AggregateWindowFunnelState::::new); } - fn state_layout(&self) -> Layout { - Layout::new::>() + fn register_state(&self, registry: &mut AggrStateRegistry) { + registry.register(AggrStateType::Custom(Layout::new::< + AggregateWindowFunnelState, + >())); } fn accumulate( &self, - place: StateAddr, + place: AggrState, columns: InputColumns, validity: Option<&Bitmap>, _input_rows: usize, @@ -235,7 +241,7 @@ where fn accumulate_keys( &self, places: &[StateAddr], - offset: usize, + loc: &[AggrStateLoc], columns: InputColumns, _input_rows: usize, ) -> Result<()> { @@ -248,7 +254,7 @@ where let tcolumn = T::try_downcast_column(&columns[0]).unwrap(); for ((row, timestamp), place) in T::iter_column(&tcolumn).enumerate().zip(places.iter()) { - let state = (place.next(offset)).get::>(); + let state = AggrState::new(*place, loc).get::>(); let timestamp = T::to_owned_scalar(timestamp); for (i, filter) in dcolumns.iter().enumerate() { if filter.get_bit(row) { @@ -259,7 +265,7 @@ where Ok(()) } - fn accumulate_row(&self, place: StateAddr, columns: InputColumns, row: usize) -> Result<()> { + fn accumulate_row(&self, place: AggrState, columns: InputColumns, row: usize) -> Result<()> { let tcolumn = T::try_downcast_column(&columns[0]).unwrap(); let timestamp = unsafe { T::index_column_unchecked(&tcolumn, row) }; let timestamp = T::to_owned_scalar(timestamp); @@ -274,19 +280,19 @@ where Ok(()) } - fn serialize(&self, place: StateAddr, writer: &mut Vec) -> Result<()> { + fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { let state = place.get::>(); borsh_serialize_state(writer, state) } - fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { + fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { let state = place.get::>(); let mut rhs: AggregateWindowFunnelState = borsh_deserialize_state(reader)?; state.merge(&mut rhs); Ok(()) } - fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { + fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { let state = place.get::>(); let other = rhs.get::>(); state.merge(other); @@ -294,7 +300,7 @@ where } #[allow(unused_mut)] - fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> { + fn merge_result(&self, place: AggrState, builder: &mut ColumnBuilder) -> Result<()> { let builder = UInt8Type::try_downcast_builder(builder).unwrap(); let result = self.get_event_level(place); builder.push(result); @@ -305,7 +311,7 @@ where true } - unsafe fn drop_state(&self, place: StateAddr) { + unsafe fn drop_state(&self, place: AggrState) { let state = place.get::>(); std::ptr::drop_in_place(state); } @@ -370,7 +376,7 @@ where /// The level path must be 1---2---3---...---check_events_size, find the max event level that satisfied the path in the sliding window. /// If found, returns the max event level, else return 0. /// The Algorithm complexity is O(n). - fn get_event_level(&self, place: StateAddr) -> u8 { + fn get_event_level(&self, place: AggrState) -> u8 { let state = place.get::>(); if state.events_list.is_empty() { return 0; diff --git a/src/query/functions/src/aggregates/aggregator_common.rs b/src/query/functions/src/aggregates/aggregator_common.rs index 775a057c7112b..9a2bceaa122d4 100644 --- a/src/query/functions/src/aggregates/aggregator_common.rs +++ b/src/query/functions/src/aggregates/aggregator_common.rs @@ -24,10 +24,13 @@ use databend_common_expression::types::DataType; use databend_common_expression::Column; use databend_common_expression::ColumnBuilder; use databend_common_expression::Scalar; +use databend_common_expression::StateAddr; +use super::get_states_layout; +use super::AggrState; use super::AggregateFunctionFactory; use super::AggregateFunctionRef; -use super::StateAddr; +use crate::aggregates::StatesLayout; pub fn assert_unary_params(name: D, actual: usize) -> Result<()> { if actual != 1 { @@ -109,18 +112,29 @@ pub fn assert_variadic_arguments( struct EvalAggr { addr: StateAddr, + state_layout: StatesLayout, _arena: Bump, func: AggregateFunctionRef, } impl EvalAggr { fn new(func: AggregateFunctionRef) -> Self { + let funcs = [func]; + let state_layout = get_states_layout(&funcs).unwrap(); + let [func] = funcs; + let _arena = Bump::new(); - let place = _arena.alloc_layout(func.state_layout()); - let addr = place.into(); - func.init_state(addr); + let addr = _arena.alloc_layout(state_layout.layout).into(); + + let state = AggrState::new(addr, &state_layout.states_loc[0]); + func.init_state(state); - Self { _arena, func, addr } + Self { + addr, + state_layout, + _arena, + func, + } } } @@ -129,7 +143,8 @@ impl Drop for EvalAggr { drop_guard(move || { if self.func.need_manual_drop_state() { unsafe { - self.func.drop_state(self.addr); + self.func + .drop_state(AggrState::new(self.addr, &self.state_layout.states_loc[0])); } } }) @@ -141,6 +156,16 @@ pub fn eval_aggr( params: Vec, columns: &[Column], rows: usize, +) -> Result<(Column, DataType)> { + eval_aggr_for_test(name, params, columns, rows, false) +} + +pub fn eval_aggr_for_test( + name: &str, + params: Vec, + columns: &[Column], + rows: usize, + with_serialize: bool, ) -> Result<(Column, DataType)> { let factory = AggregateFunctionFactory::instance(); let arguments = columns.iter().map(|x| x.data_type()).collect(); @@ -149,9 +174,16 @@ pub fn eval_aggr( let data_type = func.return_type()?; let eval = EvalAggr::new(func.clone()); - func.accumulate(eval.addr, columns.into(), None, rows)?; + let state = AggrState::new(eval.addr, &eval.state_layout.states_loc[0]); + func.accumulate(state, columns.into(), None, rows)?; + if with_serialize { + let mut buf = vec![]; + func.serialize(state, &mut buf)?; + func.init_state(state); + func.merge(state, &mut buf.as_slice())?; + } let mut builder = ColumnBuilder::with_capacity(&data_type, 1024); - func.merge_result(eval.addr, &mut builder)?; + func.merge_result(state, &mut builder)?; Ok((builder.build(), data_type)) } diff --git a/src/query/functions/tests/it/aggregates/agg.rs b/src/query/functions/tests/it/aggregates/agg.rs index fd4bc22e50e67..454045cac3186 100644 --- a/src/query/functions/tests/it/aggregates/agg.rs +++ b/src/query/functions/tests/it/aggregates/agg.rs @@ -14,17 +14,19 @@ use std::io::Write; +use databend_common_exception::Result; use databend_common_expression::types::decimal::Decimal128Type; use databend_common_expression::types::number::Int64Type; use databend_common_expression::types::number::UInt64Type; use databend_common_expression::types::BitmapType; use databend_common_expression::types::BooleanType; +use databend_common_expression::types::DataType; use databend_common_expression::types::DecimalSize; use databend_common_expression::types::StringType; use databend_common_expression::types::TimestampType; use databend_common_expression::Column; use databend_common_expression::FromData; -use databend_common_functions::aggregates::eval_aggr; +use databend_common_functions::aggregates::eval_aggr_for_test; use goldenfile::Mint; use itertools::Itertools; use roaring::RoaringTreemap; @@ -33,6 +35,15 @@ use super::run_agg_ast; use super::simulate_two_groups_group_by; use super::AggregationSimulator; +fn eval_aggr( + name: &str, + params: Vec, + columns: &[Column], + rows: usize, +) -> Result<(Column, DataType)> { + eval_aggr_for_test(name, params, columns, rows, true) +} + #[test] fn test_agg() { let mut mint = Mint::new("tests/it/aggregates/testdata"); @@ -60,7 +71,10 @@ fn test_agg() { test_agg_quantile_disc(file, eval_aggr); test_agg_quantile_cont(file, eval_aggr); test_agg_quantile_tdigest(file, eval_aggr); - test_agg_quantile_tdigest_weighted(file, eval_aggr); + // FIXME + test_agg_quantile_tdigest_weighted(file, |name, params, columns, rows| { + eval_aggr_for_test(name, params, columns, rows, false) + }); test_agg_median(file, eval_aggr); test_agg_median_tdigest(file, eval_aggr); test_agg_array_agg(file, eval_aggr); diff --git a/src/query/functions/tests/it/aggregates/agg_hashtable.rs b/src/query/functions/tests/it/aggregates/agg_hashtable.rs index b6e9786aa220c..2a3ef567e4c7c 100644 --- a/src/query/functions/tests/it/aggregates/agg_hashtable.rs +++ b/src/query/functions/tests/it/aggregates/agg_hashtable.rs @@ -31,6 +31,7 @@ use std::sync::Arc; use bumpalo::Bump; use databend_common_expression::block_debug::assert_block_value_sort_eq; +use databend_common_expression::get_states_layout; use databend_common_expression::types::ArgType; use databend_common_expression::types::BooleanType; use databend_common_expression::types::DataType; @@ -193,9 +194,11 @@ fn test_layout() { type S = DecimalSumState>; type M = DecimalSumState>; + let states_layout = get_states_layout(&[aggrs.clone()]).unwrap(); + assert_eq!( - aggrs.state_layout(), - Layout::from_size_align(24, 8).unwrap() + states_layout.layout, + Layout::from_size_align(17, 8).unwrap() ); assert_eq!(Layout::new::(), Layout::from_size_align(16, 8).unwrap()); assert_eq!(Layout::new::(), Layout::from_size_align(32, 8).unwrap()); diff --git a/src/query/functions/tests/it/aggregates/mod.rs b/src/query/functions/tests/it/aggregates/mod.rs index 813c59de6871a..0573d687d299d 100644 --- a/src/query/functions/tests/it/aggregates/mod.rs +++ b/src/query/functions/tests/it/aggregates/mod.rs @@ -20,9 +20,11 @@ use std::io::Write; use bumpalo::Bump; use comfy_table::Table; use databend_common_exception::Result; +use databend_common_expression::get_states_layout; use databend_common_expression::type_check; use databend_common_expression::types::AnyType; use databend_common_expression::types::DataType; +use databend_common_expression::AggrState; use databend_common_expression::BlockEntry; use databend_common_expression::Column; use databend_common_expression::ColumnBuilder; @@ -190,30 +192,28 @@ pub fn simulate_two_groups_group_by( let func = factory.get(name, params, arguments)?; let data_type = func.return_type()?; + let states_layout = get_states_layout(&[func.clone()])?; + let loc = states_layout.states_loc[0].clone(); let arena = Bump::new(); // init state for two groups - let addr1 = arena.alloc_layout(func.state_layout()); - func.init_state(addr1.into()); - let addr2 = arena.alloc_layout(func.state_layout()); - func.init_state(addr2.into()); + let addr1 = arena.alloc_layout(states_layout.layout).into(); + let state1 = AggrState::new(addr1, &loc); + func.init_state(state1); + let addr2 = arena.alloc_layout(states_layout.layout).into(); + let state2 = AggrState::new(addr2, &loc); + func.init_state(state2); let places = (0..rows) - .map(|i| { - if i % 2 == 0 { - addr1.into() - } else { - addr2.into() - } - }) + .map(|i| if i % 2 == 0 { addr1 } else { addr2 }) .collect::>(); - func.accumulate_keys(&places, 0, columns.into(), rows)?; + func.accumulate_keys(&places, &loc, columns.into(), rows)?; let mut builder = ColumnBuilder::with_capacity(&data_type, 1024); - func.merge_result(addr1.into(), &mut builder)?; - func.merge_result(addr2.into(), &mut builder)?; + func.merge_result(state1, &mut builder)?; + func.merge_result(state2, &mut builder)?; Ok((builder.build(), data_type)) } diff --git a/src/query/service/src/pipelines/builders/builder_aggregate.rs b/src/query/service/src/pipelines/builders/builder_aggregate.rs index 06748396f3e82..2b799d5073557 100644 --- a/src/query/service/src/pipelines/builders/builder_aggregate.rs +++ b/src/query/service/src/pipelines/builders/builder_aggregate.rs @@ -296,6 +296,8 @@ impl PipelineBuilder { max_spill_io_requests, )?; + log::debug!("aggregate states layout: {:?}", params.states_layout); + Ok(params) } } diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/aggregate_exchange_injector.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/aggregate_exchange_injector.rs index dacfcc0826b08..55688a4347259 100644 --- a/src/query/service/src/pipelines/processors/transforms/aggregator/aggregate_exchange_injector.rs +++ b/src/query/service/src/pipelines/processors/transforms/aggregator/aggregate_exchange_injector.rs @@ -86,7 +86,12 @@ fn scatter_payload(mut payload: Payload, buckets: usize) -> Result> let mut state = PayloadFlushState::default(); for _ in 0..buckets.capacity() { - let p = Payload::new(payload.arena.clone(), group_types.clone(), aggrs.clone()); + let p = Payload::new( + payload.arena.clone(), + group_types.clone(), + aggrs.clone(), + payload.states_layout.clone(), + ); buckets.push(p); } @@ -134,6 +139,7 @@ fn scatter_partitioned_payload( Arc::new(Bump::new()), group_types.clone(), aggrs.clone(), + partitioned_payload.states_layout.clone(), )); } diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/aggregate_meta.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/aggregate_meta.rs index 74b6cf41dbfd5..2ae3cc620b928 100644 --- a/src/query/service/src/pipelines/processors/transforms/aggregator/aggregate_meta.rs +++ b/src/query/service/src/pipelines/processors/transforms/aggregator/aggregate_meta.rs @@ -49,6 +49,7 @@ impl SerializedPayload { &self, group_types: Vec, aggrs: Vec>, + num_states: usize, radix_bits: u64, arena: Arc, need_init_entry: bool, @@ -57,7 +58,6 @@ impl SerializedPayload { let capacity = AggregateHashTable::get_capacity_for_count(rows_num); let config = HashTableConfig::default().with_initial_radix_bits(radix_bits); let mut state = ProbeState::default(); - let agg_len = aggrs.len(); let group_len = group_types.len(); let mut hashtable = AggregateHashTable::new_directly( group_types, @@ -68,10 +68,10 @@ impl SerializedPayload { need_init_entry, ); - let states_index: Vec = (0..agg_len).collect(); + let states_index: Vec = (0..num_states).collect(); let agg_states = InputColumns::new_block_proxy(&states_index, &self.data_block); - let group_index: Vec = (agg_len..(agg_len + group_len)).collect(); + let group_index: Vec = (num_states..(num_states + group_len)).collect(); let group_columns = InputColumns::new_block_proxy(&group_index, &self.data_block); let _ = hashtable.add_groups( @@ -90,11 +90,18 @@ impl SerializedPayload { &self, group_types: Vec, aggrs: Vec>, + num_states: usize, radix_bits: u64, arena: Arc, ) -> Result { - let hashtable = - self.convert_to_aggregate_table(group_types, aggrs, radix_bits, arena, false)?; + let hashtable = self.convert_to_aggregate_table( + group_types, + aggrs, + num_states, + radix_bits, + arena, + false, + )?; Ok(hashtable.payload) } } diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/aggregator_params.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/aggregator_params.rs index 21571182f2449..ce73b65e164da 100644 --- a/src/query/service/src/pipelines/processors/transforms/aggregator/aggregator_params.rs +++ b/src/query/service/src/pipelines/processors/transforms/aggregator/aggregator_params.rs @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::alloc::Layout; use std::sync::Arc; use databend_common_exception::Result; @@ -20,8 +19,9 @@ use databend_common_expression::types::DataType; use databend_common_expression::ColumnBuilder; use databend_common_expression::DataBlock; use databend_common_expression::DataSchemaRef; -use databend_common_functions::aggregates::get_layout_offsets; +use databend_common_functions::aggregates::get_states_layout; use databend_common_functions::aggregates::AggregateFunctionRef; +use databend_common_functions::aggregates::StatesLayout; use databend_common_sql::IndexType; use itertools::Itertools; @@ -35,8 +35,7 @@ pub struct AggregatorParams { // about function state memory layout // If there is no aggregate function, layout is None - pub layout: Option, - pub offsets_aggregate_states: Vec, + pub states_layout: Option, pub enable_experimental_aggregate_hashtable: bool, pub cluster_aggregator: bool, @@ -56,12 +55,11 @@ impl AggregatorParams { max_block_size: usize, max_spill_io_requests: usize, ) -> Result> { - let mut states_offsets: Vec = Vec::with_capacity(agg_funcs.len()); - let mut states_layout = None; - if !agg_funcs.is_empty() { - states_offsets = Vec::with_capacity(agg_funcs.len()); - states_layout = Some(get_layout_offsets(agg_funcs, &mut states_offsets)?); - } + let states_layout = if !agg_funcs.is_empty() { + Some(get_states_layout(agg_funcs)?) + } else { + None + }; Ok(Arc::new(AggregatorParams { input_schema, @@ -69,8 +67,7 @@ impl AggregatorParams { group_data_types, aggregate_functions: agg_funcs.to_vec(), aggregate_functions_arguments: agg_args.to_vec(), - layout: states_layout, - offsets_aggregate_states: states_offsets, + states_layout, enable_experimental_aggregate_hashtable, cluster_aggregator, max_block_size, @@ -97,4 +94,8 @@ impl AggregatorParams { .collect_vec(); DataBlock::new_from_columns(columns) } + + pub fn num_states(&self) -> usize { + self.aggregate_functions.len() + } } diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/mod.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/mod.rs index 89eaede89b97a..bdd17a88364fc 100644 --- a/src/query/service/src/pipelines/processors/transforms/aggregator/mod.rs +++ b/src/query/service/src/pipelines/processors/transforms/aggregator/mod.rs @@ -22,7 +22,6 @@ mod transform_aggregate_final; mod transform_aggregate_partial; mod transform_single_key; mod udaf_script; -mod utils; pub use aggregate_exchange_injector::AggregateInjector; pub use aggregate_meta::*; @@ -34,7 +33,6 @@ pub use transform_aggregate_partial::TransformPartialAggregate; pub use transform_single_key::FinalSingleStateAggregator; pub use transform_single_key::PartialSingleStateAggregator; pub use udaf_script::*; -pub use utils::*; pub use self::serde::*; use super::runtime_pool; diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/new_transform_partition_bucket.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/new_transform_partition_bucket.rs index 39db4c52ca273..5c5cddc4258fd 100644 --- a/src/query/service/src/pipelines/processors/transforms/aggregator/new_transform_partition_bucket.rs +++ b/src/query/service/src/pipelines/processors/transforms/aggregator/new_transform_partition_bucket.rs @@ -366,6 +366,7 @@ impl NewTransformPartitionBucket { let p = payload.convert_to_partitioned_payload( self.params.group_data_types.clone(), self.params.aggregate_functions.clone(), + self.params.num_states(), 0, Arc::new(Bump::new()), )?; diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/transform_aggregate_final.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/transform_aggregate_final.rs index beded12043100..048d7e6ed5a1c 100644 --- a/src/query/service/src/pipelines/processors/transforms/aggregator/transform_aggregate_final.rs +++ b/src/query/service/src/pipelines/processors/transforms/aggregator/transform_aggregate_final.rs @@ -63,6 +63,7 @@ impl TransformFinalAggregate { let payload = payload.convert_to_partitioned_payload( self.params.group_data_types.clone(), self.params.aggregate_functions.clone(), + self.params.num_states(), 0, Arc::new(Bump::new()), )?; @@ -73,6 +74,7 @@ impl TransformFinalAggregate { agg_hashtable = Some(payload.convert_to_aggregate_table( self.params.group_data_types.clone(), self.params.aggregate_functions.clone(), + self.params.num_states(), 0, Arc::new(Bump::new()), true, diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/transform_aggregate_partial.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/transform_aggregate_partial.rs index 481b7d092c2cf..039154393a598 100644 --- a/src/query/service/src/pipelines/processors/transforms/aggregator/transform_aggregate_partial.rs +++ b/src/query/service/src/pipelines/processors/transforms/aggregator/transform_aggregate_partial.rs @@ -182,10 +182,15 @@ impl TransformPartialAggregate { HashTable::AggregateHashTable(hashtable) => { let (params_columns, states_index) = if is_agg_index_block { let num_columns = block.num_columns(); - let functions_count = self.params.aggregate_functions.len(); + let states_count = self + .params + .states_layout + .as_ref() + .map(|layout| layout.states_loc.len()) + .unwrap_or(0); ( vec![], - (num_columns - functions_count..num_columns).collect::>(), + (num_columns - states_count..num_columns).collect::>(), ) } else { ( diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/transform_single_key.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/transform_single_key.rs index 1de36a979b1fb..ab62453fb053e 100644 --- a/src/query/service/src/pipelines/processors/transforms/aggregator/transform_single_key.rs +++ b/src/query/service/src/pipelines/processors/transforms/aggregator/transform_single_key.rs @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::alloc::Layout; -use std::borrow::BorrowMut; use std::sync::Arc; use std::time::Instant; use std::vec; @@ -24,15 +22,13 @@ use databend_common_base::base::convert_number_size; use databend_common_catalog::plan::AggIndexMeta; use databend_common_exception::ErrorCode; use databend_common_exception::Result; -use databend_common_expression::types::DataType; -use databend_common_expression::BlockEntry; +use databend_common_expression::AggrState; use databend_common_expression::BlockMetaInfoDowncast; use databend_common_expression::Column; use databend_common_expression::ColumnBuilder; use databend_common_expression::DataBlock; use databend_common_expression::InputColumns; -use databend_common_expression::Scalar; -use databend_common_expression::Value; +use databend_common_expression::StatesLayout; use databend_common_functions::aggregates::AggregateFunctionRef; use databend_common_functions::aggregates::StateAddr; use databend_common_pipeline_core::processors::InputPort; @@ -47,7 +43,8 @@ use crate::pipelines::processors::transforms::aggregator::AggregatorParams; pub struct PartialSingleStateAggregator { #[allow(dead_code)] arena: Bump, - places: Vec, + addr: StateAddr, + states_layout: StatesLayout, arg_indices: Vec>, funcs: Vec, @@ -59,29 +56,26 @@ pub struct PartialSingleStateAggregator { impl PartialSingleStateAggregator { pub fn try_new(params: &Arc) -> Result { - assert!(!params.offsets_aggregate_states.is_empty()); - let arena = Bump::new(); - let layout = params - .layout - .ok_or_else(|| ErrorCode::LayoutError("layout shouldn't be None"))?; - - let place: StateAddr = arena.alloc_layout(layout).into(); - let temp_place: StateAddr = arena.alloc_layout(layout).into(); - let mut places = Vec::with_capacity(params.offsets_aggregate_states.len()); - - for (idx, func) in params.aggregate_functions.iter().enumerate() { - let arg_place = place.next(params.offsets_aggregate_states[idx]); - func.init_state(arg_place); - places.push(arg_place); - - let state_place = temp_place.next(params.offsets_aggregate_states[idx]); - func.init_state(state_place); + let state_layout = params + .states_layout + .as_ref() + .ok_or_else(|| ErrorCode::LayoutError("layout shouldn't be None"))? + .clone(); + + let addr: StateAddr = arena.alloc_layout(state_layout.layout).into(); + for (func, loc) in params + .aggregate_functions + .iter() + .zip(state_layout.states_loc.iter()) + { + func.init_state(AggrState::new(addr, loc)); } Ok(PartialSingleStateAggregator { arena, - places, + addr, + states_layout: state_layout, funcs: params.aggregate_functions.clone(), arg_indices: params.aggregate_functions_arguments.clone(), start: Instant::now(), @@ -108,17 +102,36 @@ impl AccumulatingTransform for PartialSingleStateAggregator { let block = block.consume_convert_to_full(); - for (idx, func) in self.funcs.iter().enumerate() { - let place = self.places[idx]; - if is_agg_index_block { - // Aggregation states are in the back of the block. - let agg_index = block.num_columns() - self.funcs.len() + idx; - let agg_state = block.get_by_offset(agg_index).value.as_column().unwrap(); - - func.batch_merge_single(place, agg_state)?; - } else { - let columns = - InputColumns::new_block_proxy(self.arg_indices[idx].as_slice(), &block); + if is_agg_index_block { + // Aggregation states are in the back of the block. + let states_indices = (block.num_columns() - self.states_layout.states_loc.len() + ..block.num_columns()) + .collect::>(); + let states = InputColumns::new_block_proxy(&states_indices, &block); + + for ((place, func), state) in self + .states_layout + .states_loc + .iter() + .map(|loc| AggrState::new(self.addr, loc)) + .zip(self.funcs.iter()) + .zip(states.iter()) + { + func.batch_merge_single(place, state)?; + } + } else { + for ((place, columns), func) in self + .states_layout + .states_loc + .iter() + .map(|loc| AggrState::new(self.addr, loc)) + .zip( + self.arg_indices + .iter() + .map(|indices| InputColumns::new_block_proxy(indices.as_slice(), &block)), + ) + .zip(self.funcs.iter()) + { func.accumulate(place, columns, None, block.num_rows())?; } } @@ -130,29 +143,37 @@ impl AccumulatingTransform for PartialSingleStateAggregator { } fn on_finish(&mut self, generate_data: bool) -> Result> { - let mut generate_data_block = vec![]; - - if generate_data { - let mut columns = Vec::with_capacity(self.funcs.len()); - - for (idx, func) in self.funcs.iter().enumerate() { - let place = self.places[idx]; - - let mut data = Vec::with_capacity(4); - func.serialize(place, &mut data)?; - columns.push(BlockEntry::new( - DataType::Binary, - Value::Scalar(Scalar::Binary(data)), - )); + let blocks = if generate_data { + let mut builders = self.states_layout.serialize_builders(1); + + for ((func, place), builder) in self + .funcs + .iter() + .zip( + self.states_layout + .states_loc + .iter() + .map(|loc| AggrState::new(self.addr, loc)), + ) + .zip(builders.iter_mut()) + { + func.serialize(place, &mut builder.data)?; + builder.commit_row(); } - generate_data_block = vec![DataBlock::new(columns, 1)]; - } + let columns = builders + .into_iter() + .map(|b| Column::Binary(b.build())) + .collect(); + vec![DataBlock::new_from_columns(columns)] + } else { + vec![] + }; // destroy states - for (place, func) in self.places.iter().zip(self.funcs.iter()) { + for (loc, func) in self.states_layout.states_loc.iter().zip(self.funcs.iter()) { if func.need_manual_drop_state() { - unsafe { func.drop_state(*place) } + unsafe { func.drop_state(AggrState::new(self.addr, loc)) } } } @@ -170,17 +191,16 @@ impl AccumulatingTransform for PartialSingleStateAggregator { convert_byte_size(self.bytes as _), ); - Ok(generate_data_block) + Ok(blocks) } } /// SELECT COUNT | SUM FROM table; pub struct FinalSingleStateAggregator { arena: Bump, - layout: Layout, - to_merge_data: Vec>, + states_layout: StatesLayout, + to_merge_data: Vec, funcs: Vec, - offsets_aggregate_states: Vec, } impl FinalSingleStateAggregator { @@ -189,38 +209,26 @@ impl FinalSingleStateAggregator { output: Arc, params: &Arc, ) -> Result> { - assert!(!params.offsets_aggregate_states.is_empty()); - let arena = Bump::new(); - let layout = params - .layout - .ok_or_else(|| ErrorCode::LayoutError("layout shouldn't be None"))?; + let states_layout = params + .states_layout + .as_ref() + .ok_or_else(|| ErrorCode::LayoutError("layout shouldn't be None"))? + .clone(); + + assert!(!states_layout.states_loc.is_empty()); Ok(AccumulatingTransformer::create( input, output, FinalSingleStateAggregator { arena, - layout, + states_layout, funcs: params.aggregate_functions.clone(), - to_merge_data: vec![vec![]; params.aggregate_functions.len()], - offsets_aggregate_states: params.offsets_aggregate_states.clone(), + to_merge_data: Vec::new(), }, )) } - - fn new_places(&self) -> Vec { - let place: StateAddr = self.arena.alloc_layout(self.layout).into(); - self.funcs - .iter() - .enumerate() - .map(|(idx, func)| { - let arg_place = place.next(self.offsets_aggregate_states[idx]); - func.init_state(arg_place); - arg_place - }) - .collect() - } } impl AccumulatingTransform for FinalSingleStateAggregator { @@ -229,54 +237,63 @@ impl AccumulatingTransform for FinalSingleStateAggregator { fn transform(&mut self, block: DataBlock) -> Result> { if !block.is_empty() { let block = block.consume_convert_to_full(); - - for (index, _) in self.funcs.iter().enumerate() { - let binary_array = block.get_by_offset(index).value.as_column().unwrap(); - self.to_merge_data[index].push(binary_array.clone()); - } + self.to_merge_data.push(block); } Ok(vec![]) } fn on_finish(&mut self, generate_data: bool) -> Result> { - let mut generate_data_block = vec![]; - - if generate_data { - let mut aggr_values = { - let mut builders = vec![]; - for func in &self.funcs { - let data_type = func.return_type()?; - builders.push(ColumnBuilder::with_capacity(&data_type, 1)); - } - builders - }; - - let main_places = self.new_places(); - for (index, func) in self.funcs.iter().enumerate() { - let main_place = main_places[index]; - for col in self.to_merge_data[index].iter() { - func.batch_merge_single(main_place, col)?; - } - let array = aggr_values[index].borrow_mut(); - func.merge_result(main_place, array)?; - } + if !generate_data { + return Ok(vec![]); + } - let mut columns = Vec::with_capacity(self.funcs.len()); - for builder in aggr_values { - columns.push(builder.build()); - } + let main_addr: StateAddr = self.arena.alloc_layout(self.states_layout.layout).into(); + + let main_places = self + .funcs + .iter() + .zip( + self.states_layout + .states_loc + .iter() + .map(|loc| AggrState::new(main_addr, loc)), + ) + .map(|(func, place)| { + func.init_state(place); + place + }) + .collect::>(); - // destroy states - for (place, func) in main_places.iter().zip(self.funcs.iter()) { - if func.need_manual_drop_state() { - unsafe { func.drop_state(*place) } - } + let mut result_builders = self + .funcs + .iter() + .map(|f| Ok(ColumnBuilder::with_capacity(&f.return_type()?, 1))) + .collect::>>()?; + + for (idx, ((func, place), builder)) in self + .funcs + .iter() + .zip(main_places.iter().copied()) + .zip(result_builders.iter_mut()) + .enumerate() + { + for block in self.to_merge_data.iter() { + let state = block.get_by_offset(idx).value.as_column().unwrap(); + func.batch_merge_single(place, state)?; } + func.merge_result(place, builder)?; + } - generate_data_block = vec![DataBlock::new_from_columns(columns)]; + let columns = result_builders.into_iter().map(|b| b.build()).collect(); + + // destroy states + for (place, func) in main_places.iter().copied().zip(self.funcs.iter()) { + if func.need_manual_drop_state() { + unsafe { func.drop_state(place) } + } } - Ok(generate_data_block) + Ok(vec![DataBlock::new_from_columns(columns)]) } } diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/udaf_script.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/udaf_script.rs index 0b324b6c1ef63..0a50f9d09909d 100644 --- a/src/query/service/src/pipelines/processors/transforms/aggregator/udaf_script.rs +++ b/src/query/service/src/pipelines/processors/transforms/aggregator/udaf_script.rs @@ -28,13 +28,15 @@ use databend_common_expression::converts::arrow::ARROW_EXT_TYPE_VARIANT; use databend_common_expression::converts::arrow::EXTENSION_KEY; use databend_common_expression::types::Bitmap; use databend_common_expression::types::DataType; +use databend_common_expression::AggrState; +use databend_common_expression::AggrStateRegistry; +use databend_common_expression::AggrStateType; use databend_common_expression::Column; use databend_common_expression::ColumnBuilder; use databend_common_expression::DataBlock; use databend_common_expression::DataField; use databend_common_expression::DataSchema; use databend_common_expression::InputColumns; -use databend_common_expression::StateAddr; use databend_common_functions::aggregates::AggregateFunction; use databend_common_sql::plans::UDFLanguage; use databend_common_sql::plans::UDFScriptCode; @@ -58,17 +60,17 @@ impl AggregateFunction for AggregateUdfScript { Ok(self.runtime.return_type()) } - fn init_state(&self, place: StateAddr) { - place.write_state(UdfAggState(self.init_state.0.clone())); + fn init_state(&self, place: AggrState) { + place.write(|| UdfAggState(self.init_state.0.clone())); } - fn state_layout(&self) -> Layout { - Layout::new::() + fn register_state(&self, registry: &mut AggrStateRegistry) { + registry.register(AggrStateType::Custom(Layout::new::())); } fn accumulate( &self, - place: StateAddr, + place: AggrState, columns: InputColumns, validity: Option<&Bitmap>, _input_rows: usize, @@ -79,29 +81,29 @@ impl AggregateFunction for AggregateUdfScript { .runtime .accumulate(state, &input_batch) .map_err(|e| ErrorCode::UDFRuntimeError(format!("failed to accumulate: {e}")))?; - place.write_state(state); + place.write(|| state); Ok(()) } - fn accumulate_row(&self, place: StateAddr, columns: InputColumns, row: usize) -> Result<()> { + fn accumulate_row(&self, place: AggrState, columns: InputColumns, row: usize) -> Result<()> { let input_batch = self.create_input_batch_row(columns, row)?; let state = place.get::(); let state = self .runtime .accumulate(state, &input_batch) .map_err(|e| ErrorCode::UDFRuntimeError(format!("failed to accumulate_row: {e}")))?; - place.write_state(state); + place.write(|| state); Ok(()) } - fn serialize(&self, place: StateAddr, writer: &mut Vec) -> Result<()> { + fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { let state = place.get::(); state .serialize(writer) .map_err(|e| ErrorCode::Internal(format!("state failed to serialize: {e}"))) } - fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { + fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { let state = place.get::(); let rhs = UdfAggState::deserialize(reader).map_err(|e| ErrorCode::Internal(e.to_string()))?; @@ -110,11 +112,11 @@ impl AggregateFunction for AggregateUdfScript { .runtime .merge(&states) .map_err(|e| ErrorCode::UDFRuntimeError(format!("failed to merge: {e}")))?; - place.write_state(state); + place.write(|| state); Ok(()) } - fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { + fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { let state = place.get::(); let other = rhs.get::(); let states = arrow_select::concat::concat(&[&state.0, &other.0]) @@ -123,11 +125,11 @@ impl AggregateFunction for AggregateUdfScript { .runtime .merge(&states) .map_err(|e| ErrorCode::UDFRuntimeError(format!("failed to merge_states: {e}")))?; - place.write_state(state); + place.write(|| state); Ok(()) } - fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> { + fn merge_result(&self, place: AggrState, builder: &mut ColumnBuilder) -> Result<()> { let state = place.get::(); let array = self .runtime @@ -142,7 +144,7 @@ impl AggregateFunction for AggregateUdfScript { true } - unsafe fn drop_state(&self, place: StateAddr) { + unsafe fn drop_state(&self, place: AggrState) { let state = place.get::(); std::ptr::drop_in_place(state); } diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/utils.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/utils.rs deleted file mode 100644 index bc702f4dbd346..0000000000000 --- a/src/query/service/src/pipelines/processors/transforms/aggregator/utils.rs +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright 2021 Datafuse Labs -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use databend_common_expression::types::binary::BinaryColumnBuilder; -use databend_common_functions::aggregates::AggregateFunctionRef; -use databend_common_hashtable::HashtableLike; - -#[inline] -pub fn estimated_key_size(table: &Table) -> usize { - table.unsize_key_size().unwrap_or_default() -} - -pub fn create_state_serializer(func: &AggregateFunctionRef, row: usize) -> BinaryColumnBuilder { - let size = func.serialize_size_per_row().unwrap_or(4); - BinaryColumnBuilder::with_capacity(row, row * size) -} diff --git a/src/query/service/src/pipelines/processors/transforms/window/window_function.rs b/src/query/service/src/pipelines/processors/transforms/window/window_function.rs index f2feb1690758e..53d9ffc65c673 100644 --- a/src/query/service/src/pipelines/processors/transforms/window/window_function.rs +++ b/src/query/service/src/pipelines/processors/transforms/window/window_function.rs @@ -16,16 +16,18 @@ use std::sync::Arc; use databend_common_base::runtime::drop_guard; use databend_common_exception::Result; +use databend_common_expression::get_states_layout; use databend_common_expression::types::DataType; use databend_common_expression::types::NumberDataType; +use databend_common_expression::AggrState; +use databend_common_expression::AggrStateLoc; use databend_common_expression::ColumnBuilder; use databend_common_expression::DataBlock; use databend_common_expression::DataSchema; use databend_common_expression::InputColumns; -use databend_common_functions::aggregates::get_layout_offsets; +use databend_common_expression::StateAddr; use databend_common_functions::aggregates::AggregateFunction; use databend_common_functions::aggregates::AggregateFunctionFactory; -use databend_common_functions::aggregates::StateAddr; use databend_common_sql::executor::physical_plans::LagLeadDefault; use databend_common_sql::executor::physical_plans::WindowFunction; @@ -48,14 +50,15 @@ pub struct WindowFuncAggImpl { // Need to hold arena until `drop`. _arena: Arena, agg: Arc, - place: StateAddr, + addr: StateAddr, + loc: Box<[AggrStateLoc]>, args: Vec, } impl WindowFuncAggImpl { #[inline] pub fn reset(&self) { - self.agg.init_state(self.place); + self.agg.init_state(AggrState::new(self.addr, &self.loc)); } #[inline] @@ -65,12 +68,14 @@ impl WindowFuncAggImpl { #[inline] pub fn accumulate_row(&self, args: InputColumns, row: usize) -> Result<()> { - self.agg.accumulate_row(self.place, args, row) + self.agg + .accumulate_row(AggrState::new(self.addr, &self.loc), args, row) } #[inline] pub fn merge_result(&self, builder: &mut ColumnBuilder) -> Result<()> { - self.agg.merge_result(self.place, builder) + self.agg + .merge_result(AggrState::new(self.addr, &self.loc), builder) } } @@ -79,7 +84,7 @@ impl Drop for WindowFuncAggImpl { drop_guard(move || { if self.agg.need_manual_drop_state() { unsafe { - self.agg.drop_state(self.place); + self.agg.drop_state(AggrState::new(self.addr, &self.loc)); } } }) @@ -233,15 +238,15 @@ impl WindowFunctionImpl { Ok(match window { WindowFunctionInfo::Aggregate(agg, args) => { let arena = Arena::new(); - let mut state_offset = Vec::with_capacity(1); - let layout = get_layout_offsets(&[agg.clone()], &mut state_offset)?; - let place: StateAddr = arena.alloc_layout(layout).into(); - let place = place.next(state_offset[0]); + let mut states_layout = get_states_layout(&[agg.clone()])?; + let addr = arena.alloc_layout(states_layout.layout).into(); + let loc = states_layout.states_loc.pop().unwrap(); let agg = WindowFuncAggImpl { - _arena: arena, agg, - place, + addr, + loc, args, + _arena: arena, }; agg.reset(); Self::Aggregate(agg) diff --git a/src/query/sql/src/executor/physical_plans/physical_aggregate_final.rs b/src/query/sql/src/executor/physical_plans/physical_aggregate_final.rs index 8f1d00ed51dc3..ae0b21b626656 100644 --- a/src/query/sql/src/executor/physical_plans/physical_aggregate_final.rs +++ b/src/query/sql/src/executor/physical_plans/physical_aggregate_final.rs @@ -277,19 +277,16 @@ impl PhysicalPlanBuilder { let keys = { let schema = aggregate_partial.output_schema()?; - let start = aggregate_partial.agg_funcs.len(); let end = schema.num_fields(); - let mut groups = Vec::with_capacity(end - start); - for idx in start..end { - let group_key = RemoteExpr::ColumnRef { + let start = end - aggregate_partial.group_by.len(); + (start..end) + .map(|id| RemoteExpr::ColumnRef { span: None, - id: idx, - data_type: schema.field(idx).data_type().clone(), - display_name: (idx - start).to_string(), - }; - groups.push(group_key); - } - groups + id, + data_type: schema.field(id).data_type().clone(), + display_name: (id - start).to_string(), + }) + .collect() }; PhysicalPlan::Exchange(Exchange { diff --git a/src/query/sql/src/executor/physical_plans/physical_aggregate_partial.rs b/src/query/sql/src/executor/physical_plans/physical_aggregate_partial.rs index 805d83af9bbb9..a8a73071aaa57 100644 --- a/src/query/sql/src/executor/physical_plans/physical_aggregate_partial.rs +++ b/src/query/sql/src/executor/physical_plans/physical_aggregate_partial.rs @@ -47,27 +47,20 @@ impl AggregatePartial { let input_schema = self.input.output_schema()?; let mut fields = Vec::with_capacity(self.agg_funcs.len() + self.group_by.len()); - for agg in self.agg_funcs.iter() { - fields.push(DataField::new( - &agg.output_column.to_string(), - DataType::Binary, - )); - } - let group_types = self - .group_by - .iter() - .map(|index| { - Ok(input_schema - .field_with_name(&index.to_string())? - .data_type() - .clone()) - }) - .collect::>>()?; + fields.extend(self.agg_funcs.iter().map(|func| { + let name = func.output_column.to_string(); + DataField::new(&name, DataType::Binary) + })); - for (idx, data_type) in self.group_by.iter().zip(group_types.iter()) { - fields.push(DataField::new(&idx.to_string(), data_type.clone())); + for (idx, field) in self.group_by.iter().zip( + self.group_by + .iter() + .map(|index| input_schema.field_with_name(&index.to_string())), + ) { + fields.push(DataField::new(&idx.to_string(), field?.data_type().clone())); } + Ok(DataSchemaRefExt::create(fields)) } } diff --git a/tests/sqllogictests/suites/base/03_common/03_0047_select_udaf.test b/tests/sqllogictests/suites/base/03_common/03_0047_select_udaf.test index 40fa9b8026ca6..3d80615304da2 100644 --- a/tests/sqllogictests/suites/base/03_common/03_0047_select_udaf.test +++ b/tests/sqllogictests/suites/base/03_common/03_0047_select_udaf.test @@ -38,3 +38,10 @@ query R select a + b from ( select weighted_avg(number+1, number*2) a, avg(number) b from numbers(10) ); ---- 11.833333492279053 + +query IR +select number % 3, weighted_avg(number, 1) from numbers(10) group by 1 order by 1; +---- +0 4.5 +1 4.0 +2 5.0 diff --git a/tests/sqllogictests/suites/query/functions/02_0000_function_aggregate_state.test b/tests/sqllogictests/suites/query/functions/02_0000_function_aggregate_state.test index bdefed62d81e4..bd2c249067d25 100644 --- a/tests/sqllogictests/suites/query/functions/02_0000_function_aggregate_state.test +++ b/tests/sqllogictests/suites/query/functions/02_0000_function_aggregate_state.test @@ -1,4 +1,4 @@ -query T +query IT select length(max_state(number)), typeof(max_state(number)) from numbers(100); ---- 10 BINARY