
Commit 9310a9c

chore: make writers lazy init
1 parent ccfc8d0

File tree

1 file changed: +66 -17 lines changed


src/query/service/src/pipelines/processors/transforms/aggregator/new_aggregate/new_aggregate_spiller.rs

Lines changed: 66 additions & 17 deletions
@@ -46,8 +46,6 @@ use crate::spillers::SpillsDataWriter;
 
 struct PayloadWriter {
     path: String,
-    // TODO: this may change to lazy init, for now it will create 128*thread_num files at most even
-    // if the writer not used to write.
     writer: SpillsDataWriter,
 }
 
@@ -109,7 +107,7 @@ impl WriteStats {
 struct AggregatePayloadWriters {
     spill_prefix: String,
     partition_count: usize,
-    writers: Vec<PayloadWriter>,
+    writers: Vec<Option<PayloadWriter>>,
     write_stats: WriteStats,
     ctx: Arc<QueryContext>,
     is_local: bool,
@@ -125,38 +123,40 @@ impl AggregatePayloadWriters {
         AggregatePayloadWriters {
             spill_prefix: prefix.to_string(),
             partition_count,
-            writers: vec![],
+            writers: Self::empty_writers(partition_count),
             write_stats: WriteStats::default(),
             ctx,
             is_local,
         }
     }
 
-    fn ensure_writers(&mut self) -> Result<()> {
-        if self.writers.is_empty() {
-            let mut writers = Vec::with_capacity(self.partition_count);
-            for _ in 0..self.partition_count {
-                writers.push(PayloadWriter::try_create(&self.spill_prefix)?);
-            }
-            self.writers = writers;
+    fn empty_writers(partition_count: usize) -> Vec<Option<PayloadWriter>> {
+        std::iter::repeat_with(|| None)
+            .take(partition_count)
+            .collect::<Vec<_>>()
+    }
+
+    fn ensure_writer(&mut self, bucket: usize) -> Result<&mut PayloadWriter> {
+        if self.writers[bucket].is_none() {
+            self.writers[bucket] = Some(PayloadWriter::try_create(&self.spill_prefix)?);
         }
-        Ok(())
+
+        Ok(self.writers[bucket].as_mut().unwrap())
     }
 
     pub fn write_ready_blocks(&mut self, ready_blocks: Vec<(usize, DataBlock)>) -> Result<()> {
         if ready_blocks.is_empty() {
             return Ok(());
         }
 
-        self.ensure_writers()?;
-
         for (bucket, block) in ready_blocks {
             if block.is_empty() {
                 continue;
             }
 
             let start = Instant::now();
-            self.writers[bucket].write_block(block)?;
+            let writer = self.ensure_writer(bucket)?;
+            writer.write_block(block)?;
 
             let elapsed = start.elapsed();
             self.write_stats.accumulate(elapsed);
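
Note: the TODO removed in the first hunk explains the motivation. Eager initialization could open up to 128 * thread_num spill files even for writers that never received a block; with 32 threads that is already 4096 file creations. Below is a minimal standalone sketch of the per-bucket lazy pattern this hunk adopts, with a hypothetical `Writer` standing in for `PayloadWriter` and `String` as a simplified error type. The explicit `is_none` check, rather than `Option::get_or_insert_with`, lets the fallible `try_create` propagate errors with `?`, and `repeat_with(|| None)` avoids the `Clone` bound that `vec![None; n]` would demand on the element type.

    // Sketch only: `Writer` is a stand-in for `PayloadWriter`.
    struct Writer {
        path: String,
    }

    impl Writer {
        fn try_create(prefix: &str) -> Result<Writer, String> {
            // Stand-in for opening a spill file under `prefix`; this is the
            // expensive step that lazy init defers.
            Ok(Writer { path: format!("{prefix}/spill.data") })
        }
    }

    struct PartitionWriters {
        prefix: String,
        slots: Vec<Option<Writer>>,
    }

    impl PartitionWriters {
        fn new(prefix: &str, partition_count: usize) -> Self {
            PartitionWriters {
                prefix: prefix.to_string(),
                // `vec![None; partition_count]` would require `Writer: Clone`.
                slots: std::iter::repeat_with(|| None).take(partition_count).collect(),
            }
        }

        // Create the writer for `bucket` on first use only.
        fn ensure(&mut self, bucket: usize) -> Result<&mut Writer, String> {
            if self.slots[bucket].is_none() {
                self.slots[bucket] = Some(Writer::try_create(&self.prefix)?);
            }
            Ok(self.slots[bucket].as_mut().unwrap())
        }
    }

    fn main() -> Result<(), String> {
        let mut writers = PartitionWriters::new("/tmp/spill", 4);
        writers.ensure(2)?.path.push_str(".0"); // only bucket 2 is created
        assert!(writers.slots[0].is_none());
        assert!(writers.slots[2].is_some());
        Ok(())
    }
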
@@ -166,13 +166,18 @@ impl AggregatePayloadWriters {
     }
 
     pub fn finalize(&mut self) -> Result<Vec<NewSpilledPayload>> {
-        let writers = mem::take(&mut self.writers);
-        if writers.is_empty() {
+        let writers = mem::replace(&mut self.writers, Self::empty_writers(self.partition_count));
+
+        if writers.iter().all(|writer| writer.is_none()) {
             return Ok(Vec::new());
         }
 
         let mut spilled_payloads = Vec::new();
         for (partition_id, writer) in writers.into_iter().enumerate() {
+            let Some(writer) = writer else {
+                continue;
+            };
+
             let (path, written_size, row_groups) = writer.close()?;
 
             if written_size != 0 {
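
Note: in `finalize`, `mem::take` would leave a zero-length `Vec` behind, so a later spill would panic when indexing `self.writers[bucket]`; swapping in a fresh `empty_writers(partition_count)` keeps the table sized for another round, and the `let ... else` skips partitions whose writer was never created, so only files that were actually opened get closed and reported. A hedged sketch of that drain-and-reset shape, with `Some(path)` strings standing in for opened writers:

    // Sketch only: each `Some(path)` stands in for an opened PayloadWriter.
    fn finalize(slots: &mut Vec<Option<String>>) -> Vec<(usize, String)> {
        let n = slots.len();
        // Swap in a fresh same-length table so later spills can still index safely.
        let taken = std::mem::replace(slots, std::iter::repeat_with(|| None).take(n).collect());
        let mut closed = Vec::new();
        for (partition_id, writer) in taken.into_iter().enumerate() {
            // Skip partitions that never received a block.
            let Some(writer) = writer else {
                continue;
            };
            closed.push((partition_id, writer));
        }
        closed
    }

    fn main() {
        let mut slots = vec![None, Some("f1".to_string()), None, Some("f3".to_string())];
        let closed = finalize(&mut slots);
        assert_eq!(closed, vec![(1, "f1".to_string()), (3, "f3".to_string())]);
        assert_eq!(slots.len(), 4); // table keeps its size after finalize
    }
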
@@ -404,3 +409,47 @@ fn flush_write_profile(ctx: &Arc<QueryContext>, stats: WriteStats) {
         ctx.get_aggregate_spill_progress().incr(&progress_val);
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use std::collections::HashSet;
+
+    use databend_common_base::base::tokio;
+    use databend_common_exception::Result;
+    use databend_common_expression::types::Int32Type;
+    use databend_common_expression::DataBlock;
+    use databend_common_expression::FromData;
+
+    use crate::pipelines::processors::transforms::aggregator::new_aggregate::SharedPartitionStream;
+    use crate::pipelines::processors::transforms::aggregator::NewAggregateSpiller;
+    use crate::test_kits::TestFixture;
+
+    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
+    async fn test_aggregate_payload_writers_lazy_init() -> Result<()> {
+        let fixture = TestFixture::setup().await?;
+        let ctx = fixture.new_query_ctx().await?;
+
+        let partition_count = 4;
+        let partition_stream = SharedPartitionStream::new(1, 1024, 1024 * 1024, partition_count);
+        let mut spiller =
+            NewAggregateSpiller::try_create(ctx.clone(), partition_count, partition_stream, true)?;
+
+        let block = DataBlock::new_from_columns(vec![Int32Type::from_data(vec![1i32, 2, 3])]);
+
+        spiller.spill(0, block.clone())?;
+        spiller.spill(2, block)?;
+
+        let payloads = spiller.spill_finish()?;
+
+        assert_eq!(payloads.len(), 2);
+
+        let spilled_files = ctx.get_spilled_files();
+        assert_eq!(spilled_files.len(), 2);
+
+        let buckets: HashSet<_> = payloads.iter().map(|p| p.bucket).collect();
+        assert!(buckets.contains(&0));
+        assert!(buckets.contains(&2));
+
+        Ok(())
+    }
+}
