Commit 5eb78bf

fix
1 parent 564a034 commit 5eb78bf

File tree: 6 files changed, +150 -64 lines changed

src/query/expression/src/aggregate/aggregate_function.rs

Lines changed: 63 additions & 17 deletions
@@ -25,6 +25,7 @@ use super::StateAddr;
 use crate::types::AnyPairType;
 use crate::types::AnyQuaternaryType;
 use crate::types::AnyTernaryType;
+use crate::types::AnyType;
 use crate::types::AnyUnaryType;
 use crate::types::DataType;
 use crate::BlockEntry;
@@ -91,45 +92,90 @@ pub trait AggregateFunction: fmt::Display + Sync + Send {
         places: &[StateAddr],
         loc: &[AggrStateLoc],
         state: &BlockEntry,
+        filter: Option<&Bitmap>,
     ) -> Result<()> {
         match state.data_type().as_tuple().unwrap().len() {
             1 => {
                 let view = state.downcast::<AnyUnaryType>().unwrap();
-                for (place, data) in places.iter().zip(view.iter()) {
-                    self.merge(AggrState::new(*place, loc), &[data])?;
+                let iter = places.iter().zip(view.iter());
+                if let Some(filter) = filter {
+                    for (place, data) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v))
+                    {
+                        self.merge(AggrState::new(*place, loc), &[data])?;
+                    }
+                } else {
+                    for (place, data) in iter {
+                        self.merge(AggrState::new(*place, loc), &[data])?;
+                    }
                 }
             }
             2 => {
                 let view = state.downcast::<AnyPairType>().unwrap();
-                for (place, data) in places.iter().zip(view.iter()) {
-                    self.merge(AggrState::new(*place, loc), &[data.0, data.1])?;
+                let iter = places.iter().zip(view.iter());
+                if let Some(filter) = filter {
+                    for (place, data) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v))
+                    {
+                        self.merge(AggrState::new(*place, loc), &[data.0, data.1])?;
+                    }
+                } else {
+                    for (place, data) in iter {
+                        self.merge(AggrState::new(*place, loc), &[data.0, data.1])?;
+                    }
                 }
             }
             3 => {
                 let view = state.downcast::<AnyTernaryType>().unwrap();
-                for (place, data) in places.iter().zip(view.iter()) {
-                    self.merge(AggrState::new(*place, loc), &[data.0, data.1, data.2])?;
+                let iter = places.iter().zip(view.iter());
+                if let Some(filter) = filter {
+                    for (place, data) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v))
+                    {
+                        self.merge(AggrState::new(*place, loc), &[data.0, data.1, data.2])?;
+                    }
+                } else {
+                    for (place, data) in iter {
+                        self.merge(AggrState::new(*place, loc), &[data.0, data.1, data.2])?;
+                    }
                 }
             }
             4 => {
                 let view = state.downcast::<AnyQuaternaryType>().unwrap();
-                for (place, data) in places.iter().zip(view.iter()) {
-                    self.merge(AggrState::new(*place, loc), &[
-                        data.0, data.1, data.2, data.3,
-                    ])?;
+                let iter = places.iter().zip(view.iter());
+                if let Some(filter) = filter {
+                    for (place, data) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v))
+                    {
+                        self.merge(AggrState::new(*place, loc), &[
+                            data.0, data.1, data.2, data.3,
+                        ])?;
+                    }
+                } else {
+                    for (place, data) in iter {
+                        self.merge(AggrState::new(*place, loc), &[
+                            data.0, data.1, data.2, data.3,
+                        ])?;
+                    }
                 }
             }
             _ => {
-                let state = state.to_column();
-                for (place, data) in places.iter().zip(state.iter()) {
-                    self.merge(
-                        AggrState::new(*place, loc),
-                        data.as_tuple().unwrap().as_slice(),
-                    )?;
+                let view = state.downcast::<AnyType>().unwrap();
+                let iter = places.iter().zip(view.iter());
+                if let Some(filter) = filter {
+                    for (place, data) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v))
+                    {
+                        self.merge(
+                            AggrState::new(*place, loc),
+                            data.as_tuple().unwrap().as_slice(),
+                        )?;
+                    }
+                } else {
+                    for (place, data) in iter {
+                        self.merge(
+                            AggrState::new(*place, loc),
+                            data.as_tuple().unwrap().as_slice(),
+                        )?;
+                    }
                 }
             }
         }
-
         Ok(())
     }
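
The new `filter` parameter threads an optional row-selection bitmap through the default `batch_merge` loop: when a bitmap is present, only rows whose bit is set are merged; when it is `None`, every row is merged as before (existing callers such as `AggregateHashTable` below simply pass `None`). A minimal standalone sketch of that pattern, using hypothetical `Place`/`merge` stand-ins for `StateAddr`/`AggregateFunction::merge` and a plain `Vec<bool>` in place of `Bitmap`:

// Hypothetical stand-ins: `Place` plays the role of StateAddr/AggrState and
// `merge` the role of AggregateFunction::merge; a Vec<bool> stands in for Bitmap.
#[derive(Debug)]
struct Place(usize);

fn merge(place: &Place, value: i64) {
    // The real code would fold `value` into the aggregation state at `place`.
    println!("merge value {} into state {}", value, place.0);
}

fn batch_merge(places: &[Place], data: &[i64], filter: Option<&[bool]>) {
    let iter = places.iter().zip(data.iter().copied());
    if let Some(filter) = filter {
        // Zip the selection bitmap alongside and keep only rows whose bit is set.
        for (place, value) in iter
            .zip(filter.iter().copied())
            .filter_map(|(v, b)| b.then_some(v))
        {
            merge(place, value);
        }
    } else {
        // No filter: merge every row, as the old code did unconditionally.
        for (place, value) in iter {
            merge(place, value);
        }
    }
}

fn main() {
    let places = vec![Place(0), Place(1), Place(2)];
    let data = vec![10, 20, 30];
    let filter = vec![true, false, true];
    // Only rows 0 and 2 are selected, so state 1 is left untouched.
    batch_merge(&places, &data, Some(filter.as_slice()));
    // Passing None merges all three rows.
    batch_merge(&places, &data, None);
}

The `iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v))` shape is the same one repeated for each tuple arity in the hunk above.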

src/query/expression/src/aggregate/aggregate_hashtable.rs

Lines changed: 1 addition & 1 deletion
@@ -218,7 +218,7 @@ impl AggregateHashTable {
                 .zip(agg_states.iter())
                 .zip(states_layout.states_loc.iter())
             {
-                func.batch_merge(state_places, loc, state)?;
+                func.batch_merge(state_places, loc, state, None)?;
             }
         }
     }

src/query/expression/src/types/tuple.rs

Lines changed: 27 additions & 18 deletions
@@ -35,57 +35,66 @@ pub type AnyQuaternaryType = QuaternaryType<AnyType, AnyType, AnyType, AnyType>;
 #[derive(Debug, Clone, PartialEq, Eq)]
 pub struct UnaryType<T>(PhantomData<T>);
 
-impl<T> AccessType for UnaryType<T>
-where T: AccessType
+impl<A> AccessType for UnaryType<A>
+where A: AccessType
 {
-    type Scalar = T::Scalar;
-    type ScalarRef<'a> = T::ScalarRef<'a>;
-    type Column = T::Column;
-    type Domain = T::Domain;
-    type ColumnIterator<'a> = T::ColumnIterator<'a>;
+    type Scalar = A::Scalar;
+    type ScalarRef<'a> = A::ScalarRef<'a>;
+    type Column = A::Column;
+    type Domain = A::Domain;
+    type ColumnIterator<'a> = A::ColumnIterator<'a>;
 
     fn to_owned_scalar(scalar: Self::ScalarRef<'_>) -> Self::Scalar {
-        T::to_owned_scalar(scalar)
+        A::to_owned_scalar(scalar)
    }
 
     fn to_scalar_ref(scalar: &Self::Scalar) -> Self::ScalarRef<'_> {
-        T::to_scalar_ref(scalar)
+        A::to_scalar_ref(scalar)
    }
 
     fn try_downcast_scalar<'a>(scalar: &ScalarRef<'a>) -> Option<Self::ScalarRef<'a>> {
-        T::try_downcast_scalar(scalar)
+        let [a] = scalar.as_tuple()?.as_slice() else {
+            return None;
+        };
+        A::try_downcast_scalar(a)
     }
 
     fn try_downcast_domain(domain: &Domain) -> Option<Self::Domain> {
-        T::try_downcast_domain(domain)
+        let [a] = domain.as_tuple()?.as_slice() else {
+            return None;
+        };
+        A::try_downcast_domain(a)
     }
 
     fn try_downcast_column(col: &Column) -> Option<Self::Column> {
-        T::try_downcast_column(col)
+        let [a] = col.as_tuple()?.as_slice() else {
+            return None;
+        };
+        A::try_downcast_column(a)
     }
 
     fn column_len(col: &Self::Column) -> usize {
-        T::column_len(col)
+        A::column_len(col)
    }
 
     fn index_column(col: &Self::Column, index: usize) -> Option<Self::ScalarRef<'_>> {
-        T::index_column(col, index)
+        A::index_column(col, index)
    }
 
     unsafe fn index_column_unchecked(col: &Self::Column, index: usize) -> Self::ScalarRef<'_> {
-        T::index_column_unchecked(col, index)
+        A::index_column_unchecked(col, index)
    }
 
     fn slice_column(col: &Self::Column, range: std::ops::Range<usize>) -> Self::Column {
-        T::slice_column(col, range)
+        A::slice_column(col, range)
    }
 
     fn iter_column(col: &Self::Column) -> Self::ColumnIterator<'_> {
-        T::iter_column(col)
+        A::iter_column(col)
    }
 
     fn compare(lhs: Self::ScalarRef<'_>, rhs: Self::ScalarRef<'_>) -> Ordering {
-        T::compare(lhs, rhs)
+        A::compare(lhs, rhs)
    }
 }
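
The rewritten downcasts make `UnaryType<A>` peel a one-element tuple and delegate the inner value to `A`, instead of forwarding the scalar, domain, or column untouched. A minimal sketch of the slice-pattern-with-`else` shape used above, with a hypothetical `Scalar` enum standing in for the crate's `ScalarRef`:

// Hypothetical stand-in for the crate's scalar representation; only the
// control flow (slice pattern plus early return) is the point here.
enum Scalar {
    Int(i64),
    Tuple(Vec<Scalar>),
}

fn try_downcast_unary(scalar: &Scalar) -> Option<i64> {
    // Accept only a tuple value...
    let Scalar::Tuple(fields) = scalar else {
        return None;
    };
    // ...with exactly one field, then delegate to the inner downcast.
    let [a] = fields.as_slice() else {
        return None;
    };
    match a {
        Scalar::Int(v) => Some(*v),
        _ => None,
    }
}

fn main() {
    assert_eq!(try_downcast_unary(&Scalar::Tuple(vec![Scalar::Int(7)])), Some(7));
    assert_eq!(try_downcast_unary(&Scalar::Int(7)), None);
    assert_eq!(
        try_downcast_unary(&Scalar::Tuple(vec![Scalar::Int(1), Scalar::Int(2)])),
        None
    );
    println!("ok");
}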

src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs

Lines changed: 29 additions & 19 deletions
@@ -16,6 +16,7 @@ use std::fmt;
 use std::sync::Arc;
 
 use databend_common_exception::Result;
+use databend_common_expression::types::AnyType;
 use databend_common_expression::types::Bitmap;
 use databend_common_expression::types::DataType;
 use databend_common_expression::types::NumberDataType;
@@ -320,8 +321,9 @@ impl<const NULLABLE_RESULT: bool> AggregateFunction
         places: &[StateAddr],
         loc: &[AggrStateLoc],
         state: &BlockEntry,
+        filter: Option<&Bitmap>,
     ) -> Result<()> {
-        self.0.batch_merge(places, loc, state)
+        self.0.batch_merge(places, loc, state, filter)
     }
 
     fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> {
@@ -546,38 +548,46 @@ impl<const NULLABLE_RESULT: bool> CommonNullAdaptor<NULLABLE_RESULT> {
         places: &[StateAddr],
         loc: &[AggrStateLoc],
         state: &BlockEntry,
+        filter: Option<&Bitmap>,
     ) -> Result<()> {
         if !NULLABLE_RESULT {
-            return self.nested.batch_merge(places, loc, state);
+            return self.nested.batch_merge(places, loc, state, filter);
        }
 
         match state {
             BlockEntry::Column(Column::Tuple(tuple)) => {
                 let nested_state = tuple[0..tuple.len() - 1].to_vec();
                 let flag = tuple.last().unwrap().as_boolean().unwrap();
+                let flag = match filter {
+                    Some(filter) => filter & flag,
+                    None => flag.clone(),
+                };
+                if flag.null_count() == 0 {
+                    return self.nested.batch_merge(
+                        places,
+                        loc,
+                        &Column::Tuple(nested_state).into(),
+                        filter,
+                    );
+                }
 
-                let places = places
-                    .iter()
-                    .zip(flag.iter())
-                    .filter_map(|(place, flag)| {
-                        if flag {
-                            let addr = AggrState::new(*place, loc);
-                            if !get_flag(addr) {
-                                // initial the state to remove the dirty stats
-                                self.init_state(AggrState::new(*place, loc));
-                            }
-                            set_flag(addr, true);
+                for (place, flag) in places.iter().zip(flag.iter()) {
+                    if flag {
+                        let addr = AggrState::new(*place, loc);
+                        if !get_flag(addr) {
+                            // initial the state to remove the dirty stats
+                            self.init_state(AggrState::new(*place, loc));
                         }
-                        flag.then_some(*place)
-                    })
-                    .collect::<Vec<_>>();
+                        set_flag(addr, true);
+                    }
+                }
 
-                let nested_state = Column::Tuple(nested_state).filter(flag).into();
+                let nested_state = Column::Tuple(nested_state).into();
                 self.nested
-                    .batch_merge(&places, &loc[..loc.len() - 1], &nested_state)
+                    .batch_merge(places, &loc[..loc.len() - 1], &nested_state, Some(&flag))
             }
             _ => {
-                let state = state.to_column();
+                let state = state.downcast::<AnyType>().unwrap();
                 for (place, data) in places.iter().zip(state.iter()) {
                     self.merge(
                         AggrState::new(*place, loc),
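
In the nullable path the external filter is folded into the validity flag stored as the last tuple column, so a row is merged only when it is both valid and selected; the combined bitmap is then handed down as the nested call's filter instead of materializing filtered `places` and columns. A small sketch of that combination step, with `Vec<bool>` standing in for `Bitmap`:

// Sketch of combining the stored validity flag with an optional selection
// filter, using Vec<bool> as a stand-in for Bitmap.
fn combine_flag(flag: &[bool], filter: Option<&[bool]>) -> Vec<bool> {
    match filter {
        // Bitwise AND: a row survives only if it is valid *and* selected.
        Some(filter) => flag
            .iter()
            .zip(filter.iter())
            .map(|(f, s)| *f && *s)
            .collect(),
        // No filter: the validity flag is used as-is.
        None => flag.to_vec(),
    }
}

fn main() {
    let flag = vec![true, true, false, true];
    let filter = vec![true, false, true, true];
    // Rows 0 and 3 survive: valid in the serialized state and kept by the filter.
    assert_eq!(
        combine_flag(&flag, Some(filter.as_slice())),
        vec![true, false, false, true]
    );
    assert_eq!(combine_flag(&flag, None), flag);
    println!("ok");
}

When the combined flag has no unset bits, the `flag.null_count() == 0` fast path above delegates directly to the nested function with the original filter and skips the per-row flag bookkeeping.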

src/query/functions/src/aggregates/adaptors/aggregate_ornull_adaptor.rs

Lines changed: 28 additions & 8 deletions
@@ -216,24 +216,44 @@ impl AggregateFunction for AggregateFunctionOrNullAdaptor {
         places: &[StateAddr],
         loc: &[AggrStateLoc],
         state: &BlockEntry,
+        filter: Option<&Bitmap>,
     ) -> Result<()> {
         match state {
             BlockEntry::Column(Column::Tuple(tuple)) => {
                 let flag = tuple.last().unwrap().as_boolean().unwrap();
-                for (place, flag) in places.iter().zip(flag.iter()) {
-                    merge_flag(AggrState::new(*place, loc), flag);
+                let iter = places.iter().zip(flag.iter());
+                if let Some(filter) = filter {
+                    for (place, flag) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v))
+                    {
+                        merge_flag(AggrState::new(*place, loc), flag);
+                    }
+                } else {
+                    for (place, flag) in iter {
+                        merge_flag(AggrState::new(*place, loc), flag);
+                    }
                 }
                 let inner_state = Column::Tuple(tuple[0..tuple.len() - 1].to_vec()).into();
                 self.inner
-                    .batch_merge(places, &loc[0..loc.len() - 1], &inner_state)?;
+                    .batch_merge(places, &loc[0..loc.len() - 1], &inner_state, filter)?;
             }
             _ => {
                 let state = state.to_column();
-                for (place, data) in places.iter().zip(state.iter()) {
-                    self.merge(
-                        AggrState::new(*place, loc),
-                        data.as_tuple().unwrap().as_slice(),
-                    )?;
+                let iter = places.iter().zip(state.iter());
+                if let Some(filter) = filter {
+                    for (place, data) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v))
+                    {
+                        self.merge(
+                            AggrState::new(*place, loc),
+                            data.as_tuple().unwrap().as_slice(),
+                        )?;
+                    }
+                } else {
+                    for (place, data) in iter {
+                        self.merge(
+                            AggrState::new(*place, loc),
+                            data.as_tuple().unwrap().as_slice(),
+                        )?;
+                    }
                 }
             }
         }
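
The or-null adaptor also has to merge its extra boolean flag, the last tuple column, which tracks whether any row has ever been merged so that an empty group can report NULL rather than the nested function's default. A minimal sketch of that idea, with a hypothetical `State` struct in place of the real in-memory aggregation state and `merge_flag` as an OR-combine:

// Hypothetical stand-in for an "or-null" aggregation state: a running value
// plus a flag that remembers whether any row was merged.
#[derive(Default)]
struct State {
    sum: i64,
    seen: bool, // the "flag": true once at least one row has been merged
}

fn merge_flag(state: &mut State, other_seen: bool) {
    // OR-combine: the merged state has seen a row if either side has.
    state.seen |= other_seen;
}

fn finalize(state: &State) -> Option<i64> {
    // With no rows merged, report None (NULL) instead of the default value.
    state.seen.then_some(state.sum)
}

fn main() {
    let mut state = State::default();
    assert_eq!(finalize(&state), None);
    merge_flag(&mut state, true);
    state.sum += 42;
    assert_eq!(finalize(&state), Some(42));
    println!("ok");
}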

src/query/functions/src/aggregates/aggregate_combinator_state.rs

Lines changed: 2 additions & 1 deletion
@@ -126,8 +126,9 @@ impl AggregateFunction for AggregateStateCombinator {
         places: &[StateAddr],
         loc: &[AggrStateLoc],
         state: &BlockEntry,
+        filter: Option<&Bitmap>,
     ) -> Result<()> {
-        self.nested.batch_merge(places, loc, state)
+        self.nested.batch_merge(places, loc, state, filter)
     }
 
     fn batch_merge_single(&self, place: AggrState, state: &BlockEntry) -> Result<()> {
