@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::BTreeSet;
use std::collections::HashMap;
use std::sync::Arc;

use databend_common_exception::Result;
@@ -22,9 +24,14 @@ use crate::optimizer::ir::SExpr;
use crate::optimizer::optimizers::rule::Rule;
use crate::optimizer::optimizers::rule::RuleID;
use crate::optimizer::optimizers::rule::TransformResult;
use crate::plans::walk_expr_mut;
use crate::plans::EvalScalar;
use crate::plans::RelOp;
use crate::plans::ScalarItem;
use crate::plans::VisitorMut;
use crate::ColumnSet;
use crate::IndexType;
use crate::ScalarExpr;

// Merge two adjacent `EvalScalar`s into one
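// For example (illustrative sketch only; the column indexes and notation are
// hypothetical):
//   EvalScalar [#2 := #1 + 1]
//   └─ EvalScalar [#1 := #0 * 2]
// becomes a single
//   EvalScalar [#1 := #0 * 2, #2 := (#0 * 2) + 1]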
pub struct RuleMergeEvalScalar {
@@ -60,25 +67,19 @@ impl Rule for RuleMergeEvalScalar {
    fn apply(&self, s_expr: &SExpr, state: &mut TransformResult) -> Result<()> {
        let up_eval_scalar: EvalScalar = s_expr.plan().clone().try_into()?;
        let down_eval_scalar: EvalScalar = s_expr.child(0)?.plan().clone().try_into()?;

        let mut used_columns = ColumnSet::new();
        for item in up_eval_scalar.items.iter() {
            used_columns = used_columns
                .union(&item.scalar.used_columns())
                .cloned()
                .collect();
        }
        let merged_items = Self::merge_items(up_eval_scalar, down_eval_scalar, &mut used_columns)?;

        let rel_expr = RelExpr::with_s_expr(s_expr.child(0)?);
        let input_prop = rel_expr.derive_relational_prop_child(0)?;
        // Check if the up EvalScalar depends on the down EvalScalar

        // Check that all used columns are available
        if used_columns.is_subset(&input_prop.output_columns) {
            // TODO(leiysky): eliminate duplicated scalars
            let items = up_eval_scalar
                .items
                .into_iter()
                .chain(down_eval_scalar.items)
                .collect();
            let merged = EvalScalar { items };
            let merged = EvalScalar {
                items: merged_items,
            };

            let new_expr = SExpr::create_unary(
                Arc::new(merged.into()),
@@ -100,3 +101,60 @@ impl Default for RuleMergeEvalScalar {
        Self::new()
    }
}

impl RuleMergeEvalScalar {
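    // Inline the lower EvalScalar's items into the upper one's expressions and
    // return the combined item list. `used_columns` collects the columns that
    // the rewritten upper items still reference, so the caller can verify they
    // are available from the plan below.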
    fn merge_items(
        up_eval_scalar: EvalScalar,
        down_eval_scalar: EvalScalar,
        used_columns: &mut BTreeSet<IndexType>,
    ) -> Result<Vec<ScalarItem>> {
        let mut replace_set = HashMap::with_capacity(down_eval_scalar.items.len());

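        // Map each column index produced by the lower EvalScalar to its defining
        // scalar, so references to it in the upper EvalScalar can be inlined.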
        for item in &down_eval_scalar.items {
            replace_set.insert(item.index, item.scalar.clone());
        }

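        // Rewrites a bound column reference into the lower scalar that defines
        // it, if any.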
        struct ReplaceColumnVisitor {
            replace_set: HashMap<IndexType, ScalarExpr>,
        }

        impl VisitorMut<'_> for ReplaceColumnVisitor {
            fn visit(&mut self, expr: &'_ mut ScalarExpr) -> Result<()> {
                if let ScalarExpr::BoundColumnRef(column_ref) = expr {
                    if let Some(v) = self.replace_set.get(&column_ref.column.index) {
                        *expr = v.clone();
                    }

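                    // A bare column reference has no children to walk; returning
                    // here also avoids revisiting the substituted expression.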
                    return Ok(());
                }

                walk_expr_mut(self, expr)
            }
        }

        let mut visitor = ReplaceColumnVisitor { replace_set };

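        // Keep the lower items as-is and append the rewritten upper items.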
        let mut new_items = down_eval_scalar.items;
        for mut item in up_eval_scalar.items {
            // Skip identity items (`#X AS #X`) whose column is already produced
            // by the lower EvalScalar; the lower item makes them redundant.
            if let ScalarExpr::BoundColumnRef(column_ref) = &item.scalar {
                if column_ref.column.index == item.index
                    && visitor.replace_set.contains_key(&item.index)
                {
                    continue;
                }
            }

            visitor.visit(&mut item.scalar)?;

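            // Record the columns the rewritten item depends on; the caller
            // verifies they are all available before applying the rewrite.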
            *used_columns = used_columns
                .union(&item.scalar.used_columns())
                .cloned()
                .collect();

            new_items.push(item);
        }

        Ok(new_items)
    }
}