Commit abbeb1e
Only allow schema evolution for the case where a new field is the target of an assignment whose value is the same-named field in the source
1 parent 24b1a51 commit abbeb1e

File tree: 7 files changed, +132 -149 lines changed
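To illustrate the rule this commit enforces, here is a hedged sketch (the tables `target` and `source`, the column `extra_col`, and the surrounding SparkSession are hypothetical, not part of this commit). With MERGE ... WITH SCHEMA EVOLUTION, a new source column is now added to the target schema only when it is the assignment target and its value is the same-named field from the source:

  // Assumes a SparkSession `spark`, v2 tables `target` and `source`, and a
  // column `extra_col` that exists only in `source` (all names hypothetical).

  // Qualifies for schema evolution after this commit: the new field is assigned
  // directly from the same-named source column (col = s.col).
  spark.sql("""
    MERGE WITH SCHEMA EVOLUTION INTO target t
    USING source s
    ON t.id = s.id
    WHEN MATCHED THEN UPDATE SET t.extra_col = s.extra_col
  """)

  // No longer triggers schema evolution: the value is not the same-named source
  // column, so `extra_col` is not added to the target, and the assignment to
  // t.extra_col instead fails to resolve.
  spark.sql("""
    MERGE WITH SCHEMA EVOLUTION INTO target t
    USING source s
    ON t.id = s.id
    WHEN MATCHED THEN UPDATE SET t.extra_col = s.some_other_col
  """)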

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 2 additions & 3 deletions
@@ -1669,7 +1669,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor

       case u: UpdateTable => resolveReferencesInUpdate(u)

-      case m @ MergeIntoTable(targetTable, sourceTable, _, _, _, _, _, _)
+      case m @ MergeIntoTable(targetTable, sourceTable, _, _, _, _, _)
           if !m.resolved && targetTable.resolved && sourceTable.resolved && !m.needSchemaEvolution =>

         EliminateSubqueryAliases(targetTable) match {
@@ -1749,8 +1749,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
             m.copy(mergeCondition = resolvedMergeCondition,
               matchedActions = newMatchedActions,
               notMatchedActions = newNotMatchedActions,
-              notMatchedBySourceActions = newNotMatchedBySourceActions,
-              originalSourceActions = newMatchedActions ++ newNotMatchedActions)
+              notMatchedBySourceActions = newNotMatchedBySourceActions)
         }

       // UnresolvedHaving can host grouping expressions and aggregate functions. We should resolve

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveMergeIntoSchemaEvolution.scala

Lines changed: 2 additions & 2 deletions
@@ -34,7 +34,7 @@ import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
 object ResolveMergeIntoSchemaEvolution extends Rule[LogicalPlan] {

   override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
-    case m @ MergeIntoTable(_, _, _, _, _, _, _, _)
+    case m @ MergeIntoTable(_, _, _, _, _, _, _)
         if m.needSchemaEvolution =>
       val newTarget = m.targetTable.transform {
         case r : DataSourceV2Relation => performSchemaEvolution(r, m)
@@ -46,7 +46,7 @@ object ResolveMergeIntoSchemaEvolution extends Rule[LogicalPlan] {
       : DataSourceV2Relation = {
     (relation.catalog, relation.identifier) match {
       case (Some(c: TableCatalog), Some(i)) =>
-        val referencedSourceSchema = MergeIntoTable.referencedSourceSchema(m)
+        val referencedSourceSchema = MergeIntoTable.sourceSchemaForSchemaEvolution(m)

         val changes = MergeIntoTable.schemaChanges(relation.schema, referencedSourceSchema)
         c.alterTable(i, changes: _*)
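For context, the evolution step applies the computed schema changes through the catalog API. A minimal sketch of what a single added column amounts to at that level (the method and column name are illustrative, not from this commit; the real changes come from MergeIntoTable.schemaChanges):

  import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog, TableChange}
  import org.apache.spark.sql.types.StringType

  // Illustrative only: one AddColumn change of the kind schemaChanges produces
  // for a pruned source field that is missing from the target schema.
  def evolveWithOneColumn(catalog: TableCatalog, ident: Identifier): Unit = {
    val change = TableChange.addColumn(Array("extra_col"), StringType) // hypothetical field
    catalog.alterTable(ident, change)
  }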

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala

Lines changed: 3 additions & 3 deletions
@@ -45,7 +45,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper

   override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
     case m @ MergeIntoTable(aliasedTable, source, cond, matchedActions, notMatchedActions,
-        notMatchedBySourceActions, _, _) if m.resolved && m.rewritable && m.aligned &&
+        notMatchedBySourceActions, _) if m.resolved && m.rewritable && m.aligned &&
         !m.needSchemaEvolution && matchedActions.isEmpty && notMatchedActions.size == 1 &&
         notMatchedBySourceActions.isEmpty =>

@@ -79,7 +79,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper
     }

     case m @ MergeIntoTable(aliasedTable, source, cond, matchedActions, notMatchedActions,
-        notMatchedBySourceActions, _, _)
+        notMatchedBySourceActions, _)
         if m.resolved && m.rewritable && m.aligned && !m.needSchemaEvolution &&
           matchedActions.isEmpty && notMatchedBySourceActions.isEmpty =>

@@ -121,7 +121,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper
     }

     case m @ MergeIntoTable(aliasedTable, source, cond, matchedActions, notMatchedActions,
-        notMatchedBySourceActions, _, _)
+        notMatchedBySourceActions, _)
         if m.resolved && m.rewritable && m.aligned && !m.needSchemaEvolution =>

       EliminateSubqueryAliases(aliasedTable) match {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala

Lines changed: 31 additions & 43 deletions
@@ -860,9 +860,7 @@ case class MergeIntoTable(
     matchedActions: Seq[MergeAction],
     notMatchedActions: Seq[MergeAction],
     notMatchedBySourceActions: Seq[MergeAction],
-    withSchemaEvolution: Boolean,
-    // Preserves original pre-aligned actions for source matches
-    originalSourceActions: Seq[MergeAction])
+    withSchemaEvolution: Boolean)
   extends BinaryCommand with SupportsSubquery {

   lazy val aligned: Boolean = {
@@ -895,14 +893,12 @@ case class MergeIntoTable(
     case _ => false
   }

-  // a pruned version of source schema that only contains columns/nested fields
-  // explicitly assigned by MERGE INTO actions
-  private lazy val referencedSourceSchema: StructType =
-    MergeIntoTable.referencedSourceSchema(this)
+  private lazy val sourceSchemaForEvolution: StructType =
+    MergeIntoTable.sourceSchemaForSchemaEvolution(this)

   lazy val needSchemaEvolution: Boolean = {
     schemaEvolutionEnabled &&
-      MergeIntoTable.schemaChanges(targetTable.schema, referencedSourceSchema).nonEmpty
+      MergeIntoTable.schemaChanges(targetTable.schema, sourceSchemaForEvolution).nonEmpty
   }

   private def schemaEvolutionEnabled: Boolean = withSchemaEvolution && {
@@ -921,25 +917,6 @@ case class MergeIntoTable(

 object MergeIntoTable {

-  def apply(
-      targetTable: LogicalPlan,
-      sourceTable: LogicalPlan,
-      mergeCondition: Expression,
-      matchedActions: Seq[MergeAction],
-      notMatchedActions: Seq[MergeAction],
-      notMatchedBySourceActions: Seq[MergeAction],
-      withSchemaEvolution: Boolean): MergeIntoTable = {
-    MergeIntoTable(
-      targetTable,
-      sourceTable,
-      mergeCondition,
-      matchedActions,
-      notMatchedActions,
-      notMatchedBySourceActions,
-      withSchemaEvolution,
-      matchedActions ++ notMatchedActions)
-  }
-
   def getWritePrivileges(
       matchedActions: Iterable[MergeAction],
       notMatchedActions: Iterable[MergeAction],
@@ -1020,16 +997,18 @@ object MergeIntoTable {
     }
   }

-  // Filter the source schema to retain only fields that are referenced
-  // by at least one merge action
-  def referencedSourceSchema(merge: MergeIntoTable): StructType = {
+  // A pruned version of source schema that only contains columns/nested fields
+  // explicitly and directly assigned to a target counterpart in MERGE INTO actions.
+  // New columns/nested fields not existing in target will be added for schema evolution.
+  def sourceSchemaForSchemaEvolution(merge: MergeIntoTable): StructType = {

-    val assignments = merge.originalSourceActions.collect {
-      case a: UpdateAction => a.assignments.map(_.key)
-      case a: InsertAction => a.assignments.map(_.key)
+    val actions = merge.matchedActions ++ merge.notMatchedActions
+    val assignments = actions.collect {
+      case a: UpdateAction => a.assignments
+      case a: InsertAction => a.assignments
     }.flatten

-    val containsStarAction = merge.originalSourceActions.exists {
+    val containsStarAction = actions.exists {
       case _: UpdateStarAction => true
       case _: InsertStarAction => true
       case _ => false
@@ -1046,7 +1025,7 @@ object MergeIntoTable {
       // If this is a struct and one of the children is being assigned to in a merge clause,
       // keep it and continue filtering children.
       case struct: StructType if assignments.exists(assign =>
-          isPrefix(fieldPath, extractFieldPath(assign))) =>
+          isPrefix(fieldPath, extractFieldPath(assign.key))) =>
         Some(field.copy(dataType = filterSchema(struct, fieldPath)))
       // The field isn't assigned to directly or indirectly (i.e. its children) in any non-*
       // clause. Check if it should be kept with any * action.
@@ -1058,8 +1037,7 @@ object MergeIntoTable {
       }
     })

-    val res = filterSchema(merge.sourceTable.schema, Seq.empty)
-    res
+    filterSchema(merge.sourceTable.schema, Seq.empty)
   }

   // Helper method to extract field path from an Expression.
@@ -1071,18 +1049,28 @@ object MergeIntoTable {
     case _ => Seq.empty
   }

-  // Helper method to check if a given field path is a prefix of another path. Delegates
-  // equality to conf.resolver to correctly handle case sensitivity.
+  // Helper method to check if a given field path is a prefix of another path.
   private def isPrefix(prefix: Seq[String], path: Seq[String]): Boolean =
     prefix.length <= path.length && prefix.zip(path).forall {
       case (prefixNamePart, pathNamePart) =>
         SQLConf.get.resolver(prefixNamePart, pathNamePart)
     }

-  // Helper method to check if an assignment Expression's field path is equal to a path.
-  def isEqual(assignmentExpr: Expression, path: Seq[String]): Boolean = {
-    val exprPath = extractFieldPath(assignmentExpr)
-    exprPath.length == path.length && isPrefix(exprPath, path)
+  // Helper method to check if a given field path is a suffix of another path.
+  private def isSuffix(prefix: Seq[String], path: Seq[String]): Boolean =
+    prefix.length <= path.length && prefix.reverse.zip(path.reverse).forall {
+      case (prefixNamePart, pathNamePart) =>
+        SQLConf.get.resolver(prefixNamePart, pathNamePart)
+    }
+
+  // Helper method to check if an assignment key is equal to a source column
+  // and if the assignment value is the corresponding source column directly
+  private def isEqual(assignment: Assignment, path: Seq[String]): Boolean = {
+    val assignmenKeyExpr = extractFieldPath(assignment.key)
+    val assignmentValueExpr = extractFieldPath(assignment.value)
+    // Valid assignments are: col = s.col or col.nestedField = s.col.nestedField
+    assignmenKeyExpr.length == path.length && isPrefix(assignmenKeyExpr, path) &&
+      isSuffix(path, assignmentValueExpr)
   }
 }
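To make the new matching rule concrete, below is a self-contained sketch of the prefix/suffix check that the rewritten isEqual helper encodes. A case-insensitive comparison stands in for SQLConf.get.resolver, and paths are plain Seq[String] rather than Catalyst expressions; the object and method names are illustrative, not part of this commit:

  object PathMatchSketch {
    // Stand-in for SQLConf.get.resolver: case-insensitive name comparison.
    private val resolver: (String, String) => Boolean = _.equalsIgnoreCase(_)

    private def isPrefix(prefix: Seq[String], path: Seq[String]): Boolean =
      prefix.length <= path.length && prefix.zip(path).forall(resolver.tupled)

    private def isSuffix(suffix: Seq[String], path: Seq[String]): Boolean =
      suffix.length <= path.length &&
        suffix.reverse.zip(path.reverse).forall(resolver.tupled)

    // An assignment qualifies a field path for schema evolution only when the
    // assignment key equals the path and the assignment value ends with that
    // same path, i.e. col = s.col or col.nested = s.col.nested.
    def isEqual(key: Seq[String], value: Seq[String], path: Seq[String]): Boolean =
      key.length == path.length && isPrefix(key, path) && isSuffix(path, value)

    def main(args: Array[String]): Unit = {
      assert(isEqual(Seq("col"), Seq("s", "col"), Seq("col")))                // col = s.col
      assert(isEqual(Seq("col", "n"), Seq("s", "col", "n"), Seq("col", "n"))) // col.n = s.col.n
      assert(!isEqual(Seq("col"), Seq("s", "other"), Seq("col")))             // col = s.other
    }
  }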

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala

Lines changed: 1 addition & 1 deletion
@@ -167,7 +167,7 @@ class PullupCorrelatedPredicatesSuite extends PlanTest
     assert(optimized.resolved)

     optimized match {
-      case MergeIntoTable(_, _, s: InSubquery, _, _, _, _, _) =>
+      case MergeIntoTable(_, _, s: InSubquery, _, _, _, _) =>
         val outerRefs = SubExprUtils.getOuterReferences(s.query.plan)
         assert(outerRefs.isEmpty, "should be no outer refs")
       case other =>
