Skip to content

Commit b7fe665

Browse files
committed
fix
1 parent 020d6c9 commit b7fe665

File tree

9 files changed

+64
-36
lines changed

9 files changed

+64
-36
lines changed

src/query/expression/src/aggregate/aggregate_function.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,6 @@ pub trait AggregateFunction: fmt::Display + Sync + Send {
3636

3737
fn init_state(&self, place: &AggrState);
3838

39-
fn is_state(&self) -> bool {
40-
false
41-
}
42-
4339
fn state_layout(&self) -> Layout;
4440

4541
fn register_state(&self, register: &mut AggrStateRegister) {
@@ -257,6 +253,10 @@ impl AggrStateRegister {
257253
pub fn commit(&mut self) {
258254
self.offsets.push(self.states.len());
259255
}
256+
257+
pub fn states(&self) -> &[AggrStateType] {
258+
&self.states
259+
}
260260
}
261261

262262
impl Default for AggrStateRegister {

src/query/expression/src/aggregate/aggregate_function_state.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ impl From<StateAddr> for usize {
116116
}
117117
}
118118

119-
pub fn get_state_layout(funcs: &[AggregateFunctionRef]) -> Result<StatesLayout> {
119+
pub fn get_states_layout(funcs: &[AggregateFunctionRef]) -> Result<StatesLayout> {
120120
let mut register = AggrStateRegister::new();
121121
for func in funcs {
122122
func.register_state(&mut register);

src/query/expression/src/aggregate/payload.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ use strength_reduce::StrengthReducedU64;
2222

2323
use super::payload_row::rowformat_size;
2424
use super::payload_row::serialize_column_to_rowformat;
25-
use crate::get_state_layout;
25+
use crate::get_states_layout;
2626
use crate::read;
2727
use crate::store;
2828
use crate::types::DataType;
@@ -90,7 +90,7 @@ impl Payload {
9090
aggrs: Vec<AggregateFunctionRef>,
9191
) -> Self {
9292
let states_layout = if !aggrs.is_empty() {
93-
Some(get_state_layout(&aggrs).unwrap())
93+
Some(get_states_layout(&aggrs).unwrap())
9494
} else {
9595
None
9696
};

src/query/functions/src/aggregates/aggregate_combinator_state.rs

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ use databend_common_exception::Result;
2020
use databend_common_expression::types::Bitmap;
2121
use databend_common_expression::types::DataType;
2222
use databend_common_expression::AggrStateRegister;
23+
use databend_common_expression::AggrStateType;
2324
use databend_common_expression::ColumnBuilder;
2425
use databend_common_expression::InputColumns;
2526
use databend_common_expression::Scalar;
@@ -36,6 +37,7 @@ use crate::aggregates::AggregateFunctionRef;
3637
#[derive(Clone)]
3738
pub struct AggregateStateCombinator {
3839
name: String,
40+
data_type: DataType,
3941
nested: AggregateFunctionRef,
4042
}
4143

@@ -56,7 +58,25 @@ impl AggregateStateCombinator {
5658

5759
let nested = AggregateFunctionFactory::instance().get(nested_name, params, arguments)?;
5860

59-
Ok(Arc::new(AggregateStateCombinator { name, nested }))
61+
let mut register = AggrStateRegister::default();
62+
nested.register_state(&mut register);
63+
64+
let sub_types = register
65+
.states()
66+
.iter()
67+
.map(|typ| match typ {
68+
AggrStateType::Bool => DataType::Boolean,
69+
AggrStateType::Custom(_) => DataType::Binary,
70+
})
71+
.collect();
72+
73+
let data_type = DataType::Tuple(sub_types);
74+
75+
Ok(Arc::new(AggregateStateCombinator {
76+
name,
77+
data_type,
78+
nested,
79+
}))
6080
}
6181

6282
pub fn combinator_desc() -> CombinatorDescription {
@@ -70,17 +90,13 @@ impl AggregateFunction for AggregateStateCombinator {
7090
}
7191

7292
fn return_type(&self) -> Result<DataType> {
73-
Ok(DataType::Binary)
93+
Ok(self.data_type.clone())
7494
}
7595

7696
fn init_state(&self, place: &AggrState) {
7797
self.nested.init_state(place);
7898
}
7999

80-
fn is_state(&self) -> bool {
81-
true
82-
}
83-
84100
fn state_layout(&self) -> Layout {
85101
unreachable!()
86102
}
@@ -146,12 +162,21 @@ impl AggregateFunction for AggregateStateCombinator {
146162
self.nested.merge_states(place, rhs)
147163
}
148164

149-
fn merge_result(&self, _place: &AggrState, _builder: &mut ColumnBuilder) -> Result<()> {
150-
todo!()
151-
// let str_builder = builder.as_binary_mut().unwrap();
152-
// self.serialize(place, &mut str_builder.data)?;
153-
// str_builder.commit_row();
154-
// Ok(())
165+
fn merge_result(&self, place: &AggrState, builder: &mut ColumnBuilder) -> Result<()> {
166+
let builders = builder.as_tuple_mut().unwrap();
167+
168+
let loc = place
169+
.loc()
170+
.iter()
171+
.enumerate()
172+
.map(|(i, loc)| match loc {
173+
AggrStateLoc::Bool(_, offset) => AggrStateLoc::Bool(i, *offset),
174+
AggrStateLoc::Custom(_, offset) => AggrStateLoc::Custom(i, *offset),
175+
})
176+
.collect::<Vec<_>>()
177+
.into_boxed_slice();
178+
let place = AggrState::with_loc(place.addr, loc);
179+
self.nested.serialize_builder(&place, builders)
155180
}
156181

157182
fn need_manual_drop_state(&self) -> bool {

src/query/functions/src/aggregates/aggregator_common.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ use databend_common_expression::Column;
2525
use databend_common_expression::ColumnBuilder;
2626
use databend_common_expression::Scalar;
2727

28-
use super::get_state_layout;
28+
use super::get_states_layout;
2929
use super::AggrState;
3030
use super::AggregateFunctionFactory;
3131
use super::AggregateFunctionRef;
@@ -117,7 +117,7 @@ struct EvalAggr {
117117
impl EvalAggr {
118118
fn new(func: AggregateFunctionRef) -> Self {
119119
let funcs = [func];
120-
let state_layout = get_state_layout(&funcs).unwrap();
120+
let state_layout = get_states_layout(&funcs).unwrap();
121121

122122
let _arena = Bump::new();
123123
let place = _arena.alloc_layout(state_layout.layout);

src/query/functions/tests/it/aggregates/agg_hashtable.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ use std::sync::Arc;
3131

3232
use bumpalo::Bump;
3333
use databend_common_expression::block_debug::assert_block_value_sort_eq;
34+
use databend_common_expression::get_states_layout;
3435
use databend_common_expression::types::ArgType;
3536
use databend_common_expression::types::BooleanType;
3637
use databend_common_expression::types::DataType;
@@ -193,9 +194,11 @@ fn test_layout() {
193194
type S = DecimalSumState<false, DecimalType<i128>>;
194195
type M = DecimalSumState<false, DecimalType<I256>>;
195196

197+
let states_layout = get_states_layout(&[aggrs.clone()]).unwrap();
198+
196199
assert_eq!(
197-
aggrs.state_layout(),
198-
Layout::from_size_align(24, 8).unwrap()
200+
states_layout.layout,
201+
Layout::from_size_align(17, 8).unwrap()
199202
);
200203
assert_eq!(Layout::new::<S>(), Layout::from_size_align(16, 8).unwrap());
201204
assert_eq!(Layout::new::<M>(), Layout::from_size_align(32, 8).unwrap());

src/query/functions/tests/it/aggregates/mod.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,11 @@ use std::io::Write;
2020
use bumpalo::Bump;
2121
use comfy_table::Table;
2222
use databend_common_exception::Result;
23+
use databend_common_expression::get_states_layout;
2324
use databend_common_expression::type_check;
2425
use databend_common_expression::types::AnyType;
2526
use databend_common_expression::types::DataType;
2627
use databend_common_expression::AggrState;
27-
use databend_common_expression::AggrStateLoc;
2828
use databend_common_expression::BlockEntry;
2929
use databend_common_expression::Column;
3030
use databend_common_expression::ColumnBuilder;
@@ -192,23 +192,23 @@ pub fn simulate_two_groups_group_by(
192192

193193
let func = factory.get(name, params, arguments)?;
194194
let data_type = func.return_type()?;
195+
let states_layout = get_states_layout(&[func.clone()])?;
196+
let loc = states_layout.loc[0].clone();
195197

196198
let arena = Bump::new();
197199

198200
// init state for two groups
199-
let addr1 = arena.alloc_layout(func.state_layout()).into();
200-
let state1 = AggrState::new(addr1, 0);
201+
let addr1 = arena.alloc_layout(states_layout.layout.clone()).into();
202+
let state1 = AggrState::with_loc(addr1, loc.clone());
201203
func.init_state(&state1);
202-
let addr2 = arena.alloc_layout(func.state_layout()).into();
203-
let state2 = AggrState::new(addr2, 0);
204+
let addr2 = arena.alloc_layout(states_layout.layout.clone()).into();
205+
let state2 = AggrState::with_loc(addr2, loc.clone());
204206
func.init_state(&state2);
205207

206208
let places = (0..rows)
207209
.map(|i| if i % 2 == 0 { addr1 } else { addr2 })
208210
.collect::<Vec<_>>();
209211

210-
let loc = vec![AggrStateLoc::Custom(0, 0)].into_boxed_slice();
211-
212212
func.accumulate_keys(&places, loc, columns.into(), rows)?;
213213

214214
let mut builder = ColumnBuilder::with_capacity(&data_type, 1024);

src/query/service/src/pipelines/processors/transforms/aggregator/aggregator_params.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ use databend_common_expression::types::DataType;
1919
use databend_common_expression::ColumnBuilder;
2020
use databend_common_expression::DataBlock;
2121
use databend_common_expression::DataSchemaRef;
22-
use databend_common_functions::aggregates::get_state_layout;
22+
use databend_common_functions::aggregates::get_states_layout;
2323
use databend_common_functions::aggregates::AggregateFunctionRef;
2424
use databend_common_functions::aggregates::StatesLayout;
2525
use databend_common_sql::IndexType;
@@ -56,7 +56,7 @@ impl AggregatorParams {
5656
max_spill_io_requests: usize,
5757
) -> Result<Arc<AggregatorParams>> {
5858
let states_layout = if !agg_funcs.is_empty() {
59-
Some(get_state_layout(agg_funcs)?)
59+
Some(get_states_layout(agg_funcs)?)
6060
} else {
6161
None
6262
};

src/query/service/src/pipelines/processors/transforms/window/window_function.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ use std::sync::Arc;
1616

1717
use databend_common_base::runtime::drop_guard;
1818
use databend_common_exception::Result;
19-
use databend_common_expression::get_state_layout;
19+
use databend_common_expression::get_states_layout;
2020
use databend_common_expression::types::DataType;
2121
use databend_common_expression::types::NumberDataType;
2222
use databend_common_expression::AggrState;
@@ -234,10 +234,10 @@ impl WindowFunctionImpl {
234234
WindowFunctionInfo::Aggregate(agg, args) => {
235235
let arena = Arena::new();
236236

237-
let state_layout = get_state_layout(&[agg.clone()])?;
237+
let states_layout = get_states_layout(&[agg.clone()])?;
238238
let place = AggrState::with_loc(
239-
arena.alloc_layout(state_layout.layout).into(),
240-
state_layout.loc[0].clone(),
239+
arena.alloc_layout(states_layout.layout).into(),
240+
states_layout.loc[0].clone(),
241241
);
242242
let agg = WindowFuncAggImpl {
243243
_arena: arena,

0 commit comments

Comments
 (0)