Commit 3daaf68

[SPARK-54172][SQL] Merge Into Schema Evolution should only add referenced columns
1 parent 6d5f111 commit 3daaf68

8 files changed: +649 −93 lines changed
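Before this change, enabling automatic schema evolution on a MERGE added every new source column to the target table; this commit restricts evolution to the columns actually referenced by a merge clause. A hypothetical before/after illustration (the table and column names are assumptions, not from the commit):

    // Suppose target has columns (id, a) and source has (id, a, b, c).
    spark.sql("""
      MERGE WITH SCHEMA EVOLUTION INTO target t
      USING source s
      ON t.id = s.id
      WHEN MATCHED THEN UPDATE SET t.b = s.b
      WHEN NOT MATCHED THEN INSERT (id, b) VALUES (s.id, s.b)
    """)
    // Before: schema evolution added both `b` and `c` to the target.
    // After this commit: only `b` is added, since no clause references `c`.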

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 1 addition & 1 deletion
@@ -1669,7 +1669,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor

       case u: UpdateTable => resolveReferencesInUpdate(u)

-      case m @ MergeIntoTable(targetTable, sourceTable, _, _, _, _, _)
+      case m @ MergeIntoTable(targetTable, sourceTable, _, _, _, _, _, _)
           if !m.resolved && targetTable.resolved && sourceTable.resolved && !m.needSchemaEvolution =>
         EliminateSubqueryAliases(targetTable) match {
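This and the similar one-character pattern changes further below follow mechanically from the new constructor parameter on MergeIntoTable: adding a field to a case class changes the arity of its generated unapply, so every extractor pattern needs an extra `_`, even though the field has a default value. A minimal standalone illustration (Box is a made-up class, not from the commit):

    case class Box(a: Int, b: String = "x")
    // `case Box(a) => ...` no longer compiles once `b` exists;
    // the pattern must become `case Box(a, _) => ...`, which is exactly
    // the `_, _` change applied to MergeIntoTable patterns in this commit.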

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveMergeIntoSchemaEvolution.scala

Lines changed: 7 additions & 5 deletions
@@ -34,24 +34,26 @@ import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
 object ResolveMergeIntoSchemaEvolution extends Rule[LogicalPlan] {

   override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
-    case m @ MergeIntoTable(_, _, _, _, _, _, _)
+    case m @ MergeIntoTable(_, _, _, _, _, _, _, _)
         if m.needSchemaEvolution =>
       val newTarget = m.targetTable.transform {
-        case r: DataSourceV2Relation => performSchemaEvolution(r, m.sourceTable)
+        case r: DataSourceV2Relation => performSchemaEvolution(r, m)
       }
       m.copy(targetTable = newTarget)
   }

-  private def performSchemaEvolution(relation: DataSourceV2Relation, source: LogicalPlan)
+  private def performSchemaEvolution(relation: DataSourceV2Relation, m: MergeIntoTable)
      : DataSourceV2Relation = {
     (relation.catalog, relation.identifier) match {
       case (Some(c: TableCatalog), Some(i)) =>
-        val changes = MergeIntoTable.schemaChanges(relation.schema, source.schema)
+        val referencedSourceSchema = MergeIntoTable.referencedSourceSchema(m)
+
+        val changes = MergeIntoTable.schemaChanges(relation.schema, referencedSourceSchema)
         c.alterTable(i, changes: _*)
         val newTable = c.loadTable(i)
         val newSchema = CatalogV2Util.v2ColumnsToStructType(newTable.columns())
         // Check if there are any remaining changes not applied.
-        val remainingChanges = MergeIntoTable.schemaChanges(newSchema, source.schema)
+        val remainingChanges = MergeIntoTable.schemaChanges(newSchema, referencedSourceSchema)
         if (remainingChanges.nonEmpty) {
           throw QueryCompilationErrors.unsupportedTableChangesInAutoSchemaEvolutionError(
             remainingChanges, i.toQualifiedNameParts(c))
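The rule now derives table changes from the referenced subset of the source schema rather than the whole source schema. A self-contained sketch of the top-level "adds" computation over assumed schemas (the real schemaChanges also recurses into structs and emits type updates):

    import org.apache.spark.sql.types._

    val targetSchema = StructType(Seq(
      StructField("id", IntegerType), StructField("a", StringType)))
    val referenced = StructType(Seq(
      StructField("id", IntegerType), StructField("b", LongType)))

    // New top-level fields to add: present in the referenced source schema
    // but absent from the target.
    val adds = referenced.fields.filterNot(f => targetSchema.fieldNames.contains(f.name))
    println(adds.map(_.name).mkString(", "))  // prints: b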

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveRowLevelCommandAssignments.scala

Lines changed: 53 additions & 28 deletions
@@ -60,11 +60,18 @@ object ResolveRowLevelCommandAssignments extends Rule[LogicalPlan] {
         notMatchedActions = alignActions(m.targetTable.output, m.notMatchedActions,
           coerceNestedTypes),
         notMatchedBySourceActions = alignActions(m.targetTable.output, m.notMatchedBySourceActions,
-          coerceNestedTypes))
+          coerceNestedTypes),
+        preservedSourceActions = Some(m.matchedActions ++ m.notMatchedActions))

     case m: MergeIntoTable if !m.skipSchemaResolution && m.resolved && !m.aligned
         && !m.needSchemaEvolution =>
-      resolveAssignments(m)
+      m.copy(
+        matchedActions = m.matchedActions.map(resolveMergeAction),
+        notMatchedActions = m.notMatchedActions.map(resolveMergeAction),
+        notMatchedBySourceActions = m.notMatchedBySourceActions.map(resolveMergeAction),
+        preservedSourceActions = Some(m.matchedActions ++ m.notMatchedActions))
   }

   private def validateStoreAssignmentPolicy(): Unit = {

@@ -83,33 +90,51 @@ object ResolveRowLevelCommandAssignments extends Rule[LogicalPlan] {

   private def resolveAssignments(p: LogicalPlan): LogicalPlan = {
     p.transformExpressions {
-      case assignment: Assignment =>
-        val nullHandled = if (!assignment.key.nullable && assignment.value.nullable) {
-          AssertNotNull(assignment.value)
-        } else {
-          assignment.value
-        }
-        val casted = if (assignment.key.dataType != nullHandled.dataType) {
-          val cast = Cast(nullHandled, assignment.key.dataType, ansiEnabled = true)
-          cast.setTagValue(Cast.BY_TABLE_INSERTION, ())
-          cast
-        } else {
-          nullHandled
-        }
-        val rawKeyType = assignment.key.transform {
-          case a: AttributeReference =>
-            CharVarcharUtils.getRawType(a.metadata).map(a.withDataType).getOrElse(a)
-        }.dataType
-        val finalValue = if (CharVarcharUtils.hasCharVarchar(rawKeyType)) {
-          CharVarcharUtils.stringLengthCheck(casted, rawKeyType)
-        } else {
-          casted
-        }
-        val cleanedKey = assignment.key.transform {
-          case a: AttributeReference => CharVarcharUtils.cleanAttrMetadata(a)
-        }
-        Assignment(cleanedKey, finalValue)
+      case assignment: Assignment => resolveAssignment(assignment)
     }
   }

+  private def resolveMergeAction(mergeAction: MergeAction) = {
+    mergeAction match {
+      case u @ UpdateAction(_, assignments) =>
+        u.copy(assignments = assignments.map(resolveAssignment))
+      case i @ InsertAction(_, assignments) =>
+        i.copy(assignments = assignments.map(resolveAssignment))
+      case d: DeleteAction =>
+        d
+      case other =>
+        throw new AnalysisException(
+          errorClass = "_LEGACY_ERROR_TEMP_3053",
+          messageParameters = Map("other" -> other.toString))
+    }
+  }
+
+  private def resolveAssignment(assignment: Assignment) = {
+    val nullHandled = if (!assignment.key.nullable && assignment.value.nullable) {
+      AssertNotNull(assignment.value)
+    } else {
+      assignment.value
+    }
+    val casted = if (assignment.key.dataType != nullHandled.dataType) {
+      val cast = Cast(nullHandled, assignment.key.dataType, ansiEnabled = true)
+      cast.setTagValue(Cast.BY_TABLE_INSERTION, ())
+      cast
+    } else {
+      nullHandled
+    }
+    val rawKeyType = assignment.key.transform {
+      case a: AttributeReference =>
+        CharVarcharUtils.getRawType(a.metadata).map(a.withDataType).getOrElse(a)
+    }.dataType
+    val finalValue = if (CharVarcharUtils.hasCharVarchar(rawKeyType)) {
+      CharVarcharUtils.stringLengthCheck(casted, rawKeyType)
+    } else {
+      casted
+    }
+    val cleanedKey = assignment.key.transform {
+      case a: AttributeReference => CharVarcharUtils.cleanAttrMetadata(a)
     }
+    Assignment(cleanedKey, finalValue)
   }

   private def alignActions(
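The new preservedSourceActions field exists because alignment is lossy: alignActions expands each UPDATE/INSERT so that every target column gets an assignment, after which the actions no longer reveal which source columns the user actually referenced. A minimal sketch with made-up simplified shapes (not the real MergeAction classes):

    case class Assign(keyPath: Seq[String])
    case class Update(assigns: Seq[Assign])

    // User wrote: WHEN MATCHED THEN UPDATE SET t.b = s.b
    val original = Update(Seq(Assign(Seq("b"))))
    // After alignment, every target column (id, a, b) gets an assignment:
    val aligned = Update(Seq(Assign(Seq("id")), Assign(Seq("a")), Assign(Seq("b"))))
    // Schema evolution must consult `original`, not `aligned`, to see that only
    // `b` was referenced -- hence the actions are captured before alignment.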

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala

Lines changed: 3 additions & 3 deletions
@@ -45,7 +45,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper

   override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
     case m @ MergeIntoTable(aliasedTable, source, cond, matchedActions, notMatchedActions,
-        notMatchedBySourceActions, _) if m.resolved && m.rewritable && m.aligned &&
+        notMatchedBySourceActions, _, _) if m.resolved && m.rewritable && m.aligned &&
         !m.needSchemaEvolution && matchedActions.isEmpty && notMatchedActions.size == 1 &&
         notMatchedBySourceActions.isEmpty =>

@@ -79,7 +79,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper
     }

     case m @ MergeIntoTable(aliasedTable, source, cond, matchedActions, notMatchedActions,
-        notMatchedBySourceActions, _)
+        notMatchedBySourceActions, _, _)
         if m.resolved && m.rewritable && m.aligned && !m.needSchemaEvolution &&
           matchedActions.isEmpty && notMatchedBySourceActions.isEmpty =>

@@ -121,7 +121,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper
     }

     case m @ MergeIntoTable(aliasedTable, source, cond, matchedActions, notMatchedActions,
-        notMatchedBySourceActions, _)
+        notMatchedBySourceActions, _, _)
         if m.resolved && m.rewritable && m.aligned && !m.needSchemaEvolution =>

       EliminateSubqueryAliases(aliasedTable) match {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala

Lines changed: 87 additions & 6 deletions
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans.logical

 import org.apache.spark.{SparkIllegalArgumentException, SparkUnsupportedOperationException}
 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, AssignmentUtils, EliminateSubqueryAliases, FieldName, NamedRelation, PartitionSpec, ResolvedIdentifier, ResolvedProcedure, TypeCheckResult, UnresolvedException, UnresolvedProcedure, ViewSchemaMode}
+import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, AssignmentUtils, EliminateSubqueryAliases, FieldName, NamedRelation, PartitionSpec, ResolvedIdentifier, ResolvedProcedure, TypeCheckResult, UnresolvedAttribute, UnresolvedException, UnresolvedProcedure, ViewSchemaMode}
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess}
 import org.apache.spark.sql.catalyst.catalog.{FunctionResource, RoutineLanguage}
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec

@@ -860,7 +860,10 @@ case class MergeIntoTable(
     matchedActions: Seq[MergeAction],
     notMatchedActions: Seq[MergeAction],
     notMatchedBySourceActions: Seq[MergeAction],
-    withSchemaEvolution: Boolean) extends BinaryCommand with SupportsSubquery {
+    withSchemaEvolution: Boolean,
+    // Preserves the original pre-aligned actions for source matches.
+    preservedSourceActions: Option[Seq[MergeAction]] = None)
+  extends BinaryCommand with SupportsSubquery {

   lazy val aligned: Boolean = {
     val actions = matchedActions ++ notMatchedActions ++ notMatchedBySourceActions

@@ -892,9 +895,13 @@ case class MergeIntoTable(
     case _ => false
   }

-  lazy val needSchemaEvolution: Boolean =
+  private lazy val migrationSchema: StructType =
+    MergeIntoTable.referencedSourceSchema(this)
+
+  lazy val needSchemaEvolution: Boolean = {
     schemaEvolutionEnabled &&
-      MergeIntoTable.schemaChanges(targetTable.schema, sourceTable.schema).nonEmpty
+      MergeIntoTable.schemaChanges(targetTable.schema, migrationSchema).nonEmpty
+  }

   private def schemaEvolutionEnabled: Boolean = withSchemaEvolution && {
     EliminateSubqueryAliases(targetTable) match {

@@ -948,11 +955,12 @@ object MergeIntoTable {
       case currentField: StructField if newFieldMap.contains(currentField.name) =>
         schemaChanges(currentField.dataType, newFieldMap(currentField.name).dataType,
           originalTarget, originalSource, fieldPath ++ Seq(currentField.name))
-    }}.flatten
+      }
+    }.flatten

     // Identify the newly added fields and append to the end
     val currentFieldMap = toFieldMap(currentFields)
-    val adds = newFields.filterNot (f => currentFieldMap.contains (f.name))
+    val adds = newFields.filterNot(f => currentFieldMap.contains(f.name))
       .map(f => TableChange.addColumn(fieldPath ++ Set(f.name), f.dataType))

     updates ++ adds

@@ -990,8 +998,81 @@ object MergeIntoTable {
     CaseInsensitiveMap(fieldMap)
   }
+
+  // Filter the source schema to retain only fields that are referenced
+  // by at least one merge action.
+  def referencedSourceSchema(merge: MergeIntoTable): StructType = {
+    val actions = merge.preservedSourceActions match {
+      case Some(preserved) => preserved
+      case None => merge.matchedActions ++ merge.notMatchedActions
+    }
+
+    val assignments = actions.collect {
+      case a: UpdateAction => a.assignments.map(_.key)
+      case a: InsertAction => a.assignments.map(_.key)
+    }.flatten
+
+    val containsStarAction = actions.exists {
+      case _: UpdateStarAction => true
+      case _: InsertStarAction => true
+      case _ => false
+    }
+
+    def filterSchema(sourceSchema: StructType, basePath: Seq[String]): StructType =
+      StructType(sourceSchema.flatMap { field =>
+        val fieldPath = basePath :+ field.name
+
+        field.dataType match {
+          // Specifically assigned to in one clause:
+          // always keep, including all nested attributes.
+          case _ if assignments.exists(isEqual(_, fieldPath)) => Some(field)
+          // If this is a struct and one of the children is being assigned to in a merge clause,
+          // keep it and continue filtering children.
+          case struct: StructType if assignments.exists(assign =>
+              isPrefix(fieldPath, extractFieldPath(assign))) =>
+            Some(field.copy(dataType = filterSchema(struct, fieldPath)))
+          // The field isn't assigned to directly or indirectly (i.e. via its children) in any
+          // non-* clause. Check if it should be kept due to a * action.
+          case struct: StructType if containsStarAction =>
+            Some(field.copy(dataType = filterSchema(struct, fieldPath)))
+          case _ if containsStarAction => Some(field)
+          // The field and its children are not assigned to in any * or non-* action, drop it.
+          case _ => None
+        }
+      })
+
+    filterSchema(merge.sourceTable.schema, Seq.empty)
+  }
+
+  // Helper method to extract a field path from an Expression.
+  private def extractFieldPath(expr: Expression): Seq[String] = expr match {
+    case UnresolvedAttribute(nameParts) => nameParts
+    case a: AttributeReference => Seq(a.name)
+    case GetStructField(child, ordinal, nameOpt) =>
+      extractFieldPath(child) :+ nameOpt.getOrElse(s"col$ordinal")
+    case _ => Seq.empty
+  }
+
+  // Helper method to check if a given field path is a prefix of another path. Delegates
+  // equality to conf.resolver to correctly handle case sensitivity.
+  private def isPrefix(prefix: Seq[String], path: Seq[String]): Boolean =
+    prefix.length <= path.length && prefix.zip(path).forall {
+      case (prefixNamePart, pathNamePart) =>
+        SQLConf.get.resolver(prefixNamePart, pathNamePart)
+    }
+
+  // Helper method to check if an assignment Expression's field path equals the given path.
+  def isEqual(assignmentExpr: Expression, path: Seq[String]): Boolean = {
+    val exprPath = extractFieldPath(assignmentExpr)
+    exprPath.length == path.length && isPrefix(exprPath, path)
+  }
 }

 sealed abstract class MergeAction extends Expression with Unevaluable {
   def condition: Option[Expression]
   override def nullable: Boolean = false
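To make the filtering rules above concrete, a short walk-through over assumed schemas (none of these tables come from the commit):

    import org.apache.spark.sql.types._

    val source = StructType(Seq(
      StructField("id", IntegerType),
      StructField("s", StructType(Seq(
        StructField("x", IntegerType),
        StructField("y", IntegerType)))),
      StructField("c", StringType)))

    // With a single clause UPDATE SET t.s.x = src.s.x and no * actions,
    // referencedSourceSchema keeps only s.x:
    //   - s.x is directly assigned            -> kept
    //   - s.y is never referenced             -> dropped (s becomes struct<x: int>)
    //   - id appears only in the ON condition -> dropped (only assignment keys count)
    //   - c is never referenced               -> dropped
    // Adding WHEN NOT MATCHED THEN INSERT * would keep every source field.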

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala

Lines changed: 1 addition & 1 deletion
@@ -167,7 +167,7 @@ class PullupCorrelatedPredicatesSuite extends PlanTest {
     assert(optimized.resolved)

     optimized match {
-      case MergeIntoTable(_, _, s: InSubquery, _, _, _, _) =>
+      case MergeIntoTable(_, _, s: InSubquery, _, _, _, _, _) =>
         val outerRefs = SubExprUtils.getOuterReferences(s.query.plan)
         assert(outerRefs.isEmpty, "should be no outer refs")
       case other =>
