@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans.logical
1919
2020import org .apache .spark .{SparkIllegalArgumentException , SparkUnsupportedOperationException }
2121import org .apache .spark .sql .AnalysisException
22- import org .apache .spark .sql .catalyst .analysis .{AnalysisContext , AssignmentUtils , EliminateSubqueryAliases , FieldName , NamedRelation , PartitionSpec , ResolvedIdentifier , ResolvedProcedure , TypeCheckResult , UnresolvedException , UnresolvedProcedure , ViewSchemaMode }
22+ import org .apache .spark .sql .catalyst .analysis .{AnalysisContext , AssignmentUtils , EliminateSubqueryAliases , FieldName , NamedRelation , PartitionSpec , ResolvedIdentifier , ResolvedProcedure , TypeCheckResult , UnresolvedAttribute , UnresolvedException , UnresolvedProcedure , ViewSchemaMode }
2323import org .apache .spark .sql .catalyst .analysis .TypeCheckResult .{DataTypeMismatch , TypeCheckSuccess }
2424import org .apache .spark .sql .catalyst .catalog .{FunctionResource , RoutineLanguage }
2525import org .apache .spark .sql .catalyst .catalog .CatalogTypes .TablePartitionSpec
@@ -860,7 +860,10 @@ case class MergeIntoTable(
860860 matchedActions : Seq [MergeAction ],
861861 notMatchedActions : Seq [MergeAction ],
862862 notMatchedBySourceActions : Seq [MergeAction ],
863- withSchemaEvolution : Boolean ) extends BinaryCommand with SupportsSubquery {
863+ withSchemaEvolution : Boolean ,
864+ // Preserves original pre-aligned actions for source matches
865+ preservedSourceActions : Option [Seq [MergeAction ]] = None )
866+ extends BinaryCommand with SupportsSubquery {
864867
865868 lazy val aligned : Boolean = {
866869 val actions = matchedActions ++ notMatchedActions ++ notMatchedBySourceActions
@@ -892,9 +895,13 @@ case class MergeIntoTable(
892895 case _ => false
893896 }
894897
/** Source schema narrowed by `MergeIntoTable.referencedSourceSchema` for this command. */
private lazy val migrationSchema: StructType =
  MergeIntoTable.referencedSourceSchema(this)

/**
 * True when schema evolution is enabled for this command and comparing the target table's
 * schema against `migrationSchema` yields at least one schema change.
 */
lazy val needSchemaEvolution: Boolean =
  if (schemaEvolutionEnabled) {
    MergeIntoTable.schemaChanges(targetTable.schema, migrationSchema).nonEmpty
  } else {
    false
  }
898905
899906 private def schemaEvolutionEnabled : Boolean = withSchemaEvolution && {
900907 EliminateSubqueryAliases (targetTable) match {
@@ -948,11 +955,12 @@ object MergeIntoTable {
948955 case currentField : StructField if newFieldMap.contains(currentField.name) =>
949956 schemaChanges(currentField.dataType, newFieldMap(currentField.name).dataType,
950957 originalTarget, originalSource, fieldPath ++ Seq (currentField.name))
951- }}.flatten
958+ }
959+ }.flatten
952960
953961 // Identify the newly added fields and append to the end
954962 val currentFieldMap = toFieldMap(currentFields)
955- val adds = newFields.filterNot (f => currentFieldMap.contains (f.name))
963+ val adds = newFields.filterNot(f => currentFieldMap.contains(f.name))
956964 .map(f => TableChange .addColumn(fieldPath ++ Set (f.name), f.dataType))
957965
958966 updates ++ adds
@@ -990,8 +998,81 @@ object MergeIntoTable {
990998 CaseInsensitiveMap (fieldMap)
991999 }
9921000 }
1001+
/**
 * Narrows the source table's schema to the fields that the merge actions reference.
 *
 * The actions considered are `merge.preservedSourceActions` when it is defined, otherwise
 * `merge.matchedActions ++ merge.notMatchedActions`. A source field is retained when:
 *  - the path of some UPDATE/INSERT assignment key equals the field's path (the field is kept
 *    as-is, including all of its nested fields);
 *  - it is a struct and its path is a prefix of some assignment key's path (the struct is kept
 *    and its children are filtered recursively);
 *  - any considered action is an `UpdateStarAction` or `InsertStarAction` (the field is kept;
 *    struct fields are still filtered recursively).
 * Every other field is dropped.
 *
 * @param merge the MERGE command whose source schema is filtered
 * @return the filtered source schema
 */
def referencedSourceSchema(merge: MergeIntoTable): StructType = {
  val actions = merge.preservedSourceActions
    .getOrElse(merge.matchedActions ++ merge.notMatchedActions)

  // Keys (assignment targets) of every explicit UPDATE / INSERT assignment.
  val assignmentKeys: Seq[Expression] = actions.flatMap {
    case update: UpdateAction => update.assignments.map(_.key)
    case insert: InsertAction => insert.assignments.map(_.key)
    case _ => Nil
  }

  val containsStarAction = actions.exists {
    case _: UpdateStarAction | _: InsertStarAction => true
    case _ => false
  }

  def filterSchema(schema: StructType, basePath: Seq[String]): StructType =
    StructType(schema.flatMap { field =>
      val fieldPath = basePath :+ field.name

      // Some assignment key names exactly this field.
      def assignedDirectly: Boolean = assignmentKeys.exists(isEqual(_, fieldPath))

      // Some assignment key names this field or something nested below it.
      def childAssigned: Boolean =
        assignmentKeys.exists(key => isPrefix(fieldPath, extractFieldPath(key)))

      field.dataType match {
        // Directly assigned in some clause: keep the field with all nested attributes.
        case _ if assignedDirectly => Some(field)
        // A struct that is reached through a nested assignment or through a * action:
        // keep it and continue filtering its children.
        case struct: StructType if childAssigned || containsStarAction =>
          Some(field.copy(dataType = filterSchema(struct, fieldPath)))
        // A non-struct field that is only referenced through a * action.
        case _ if containsStarAction => Some(field)
        // Not referenced by any * or non-* action: drop it.
        case _ => None
      }
    })

  filterSchema(merge.sourceTable.schema, Seq.empty)
}
1050+
/**
 * Extracts the field path (sequence of name parts) that an expression refers to.
 *
 *  - `UnresolvedAttribute`: its name parts.
 *  - `AttributeReference`: a single-element path holding the attribute name.
 *  - `GetStructField`: the path of its child followed by the field name, or `col<ordinal>`
 *    when no name is present.
 *  - anything else: an empty path.
 */
private def extractFieldPath(expr: Expression): Seq[String] = expr match {
  case UnresolvedAttribute(parts) =>
    parts
  case attribute: AttributeReference =>
    Seq(attribute.name)
  case GetStructField(parent, ordinal, maybeName) =>
    val fieldName = maybeName.getOrElse(s"col$ordinal")
    extractFieldPath(parent) :+ fieldName
  case _ =>
    Seq.empty
}
1059+
/**
 * Checks whether `prefix` is a prefix of `path`. Name parts are compared pairwise with
 * `SQLConf.get.resolver`, so case sensitivity follows the active configuration.
 *
 * @param prefix the candidate prefix path
 * @param path the path that is tested against the prefix
 * @return true when `prefix` is no longer than `path` and every part of `prefix` matches the
 *         part of `path` at the same position
 */
private def isPrefix(prefix: Seq[String], path: Seq[String]): Boolean = {
  if (prefix.length > path.length) {
    // A longer path can never be a prefix of a shorter one.
    false
  } else {
    prefix.iterator.zip(path.iterator).forall { case (prefixPart, pathPart) =>
      SQLConf.get.resolver(prefixPart, pathPart)
    }
  }
}
1067+
/**
 * Checks whether the field path extracted from an assignment expression is equal to `path`.
 *
 * @param assignmentExpr the expression whose field path is extracted
 * @param path the field path to compare against
 * @return true when both paths have the same length and match part by part (see `isPrefix`)
 */
def isEqual(assignmentExpr: Expression, path: Seq[String]): Boolean = {
  val assignedPath = extractFieldPath(assignmentExpr)
  // Equal length plus a prefix match means the two paths are the same.
  if (assignedPath.length != path.length) false else isPrefix(assignedPath, path)
}
9931073}
9941074
1075+
9951076sealed abstract class MergeAction extends Expression with Unevaluable {
9961077 def condition : Option [Expression ]
9971078 override def nullable : Boolean = false
0 commit comments