Commit ffa7bb7

committed
performance improvements for the SQL plan when running in real time and receiving updates
1 parent 508772d commit ffa7bb7

5 files changed, +210 -40 lines changed
Lines changed: 69 additions & 0 deletions
@@ -0,0 +1,69 @@
+package io.dataflint.example
+
+import org.apache.spark.sql.{DataFrame, SparkSession}
+import org.apache.spark.sql.functions._
+
+object SqlPlanStressTestExample extends App {
+
+  val spark = SparkSession
+    .builder()
+    .appName("SQL Plan Stress Test Example")
+    .config("spark.plugins", "io.dataflint.spark.SparkDataflintPlugin")
+    .config("spark.ui.port", "10000")
+    .config("spark.dataflint.telemetry.enabled", value = false)
+    .config("spark.sql.maxMetadataStringLength", "10000")
+    .config("spark.eventLog.enabled", "true")
+    .master("local[*]")
+    .getOrCreate()
+
+  import spark.implicits._
+
+  println("Starting SQL Plan Stress Test with 100 iterations...")
+
+  // Create a list to store all DataFrames for union
+  val dataFrames = scala.collection.mutable.ListBuffer[DataFrame]()
+
+  // Loop 100 iterations to create datasets with different ranges
+  for (i <- 1 to 100) {
+    println(s"Processing iteration $i...")
+
+    // Create a small dataset with range (different range for each iteration)
+    val rangeStart = i * 1000
+    val rangeEnd = rangeStart + 100
+
+    val df = spark.range(rangeStart, rangeEnd)
+      .toDF("id")
+      .withColumn("iteration", lit(i))
+      .withColumn("value", col("id") * 2)
+      .withColumn("category", when(col("id") % 2 === 0, "even").otherwise("odd"))
+
+    // Apply select and filter operations
+    val processedDf = df
+      .select(
+        col("id"),
+        col("iteration"),
+        col("value"),
+        col("category")
+      )
+      .filter(col("id") % 10 =!= 0) // Filter out multiples of 10
+      .filter(col("value") > rangeStart + 50) // Additional filter condition
+
+    // Add to the list for union
+    dataFrames += processedDf
+  }
+
+  println("Creating large union of all results...")
+
+  // Create a large union of all results
+  val unionedDf = dataFrames.reduce(_.union(_))
+
+  println("Running count on the union...")
+
+  // Run count on the union
+  val totalCount = unionedDf.count()
+
+  println(s"Total count after union: $totalCount")
+
+  scala.io.StdIn.readLine()
+  spark.stop()
+}

spark-ui/src/components/SqlFlow/MetricDisplay.tsx

Lines changed: 5 additions & 2 deletions
@@ -1,5 +1,5 @@
 import { Box, Tooltip, Typography } from "@mui/material";
-import React from "react";
+import React, { memo } from "react";
 import { TransperantTooltip } from "../AlertBadge/AlertBadge";
 import { ConditionalWrapper } from "../InfoBox/InfoBox";
 import styles from "./node-style.module.css";
@@ -16,7 +16,7 @@ interface MetricDisplayProps {
   metrics: MetricWithTooltip[];
 }
 
-const MetricDisplay: React.FC<MetricDisplayProps> = ({ metrics }) => {
+const MetricDisplayComponent: React.FC<MetricDisplayProps> = ({ metrics }) => {
   const handleWheel = (e: React.WheelEvent) => {
     // Always prevent React Flow from handling wheel events in this area
     e.stopPropagation();
@@ -97,4 +97,7 @@ const MetricDisplay: React.FC<MetricDisplayProps> = ({ metrics }) => {
   );
 };
 
+// Simple memoization - prevents unnecessary re-renders when metrics haven't changed
+const MetricDisplay = memo(MetricDisplayComponent);
+
 export default MetricDisplay;
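
The memo wrapper relies on React's shallow prop comparison. A minimal sketch of the same pattern, using a hypothetical Label component (not part of this commit):

import React, { memo } from "react";

const LabelComponent: React.FC<{ text: string }> = ({ text }) => (
  <span>{text}</span>
);

// memo does a shallow comparison of props: the wrapped component re-renders
// only when `text` (or another prop reference) changes, not on every parent
// render.
const Label = memo(LabelComponent);

export default Label;

Note that memo only pays off when the parent passes referentially stable props; a metrics array rebuilt inline on each parent render would defeat the shallow comparison.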

spark-ui/src/components/SqlFlow/SqlFlow.tsx

Lines changed: 5 additions & 0 deletions
@@ -39,6 +39,9 @@ const nodeTypes = { [StageNodeName]: StageNode };
 const SqlFlow: FC<{ sparkSQL: EnrichedSparkSQL }> = ({
   sparkSQL,
 }): JSX.Element => {
+  // Get alerts for passing to nodes
+  const alerts = useAppSelector((state) => state.spark.alerts);
+
   const [instance, setInstance] = useState<ReactFlowInstance | undefined>();
   const [nodes, setNodes, onNodesChange] = useNodesState([]);
   const [edges, setEdges, onEdgesChange] = useEdgesState([]);
@@ -76,6 +79,7 @@ const SqlFlow: FC<{ sparkSQL: EnrichedSparkSQL }> = ({
     const { layoutNodes, layoutEdges } = SqlLayoutService.SqlElementsToLayout(
       sparkSQL,
       graphFilter,
+      alerts,
     );
 
     setNodes(layoutNodes);
@@ -97,6 +101,7 @@ const SqlFlow: FC<{ sparkSQL: EnrichedSparkSQL }> = ({
     const { layoutNodes, layoutEdges } = SqlLayoutService.SqlElementsToLayout(
       sparkSQL,
       graphFilter,
+      alerts,
    );
 
     setNodes(layoutNodes);
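
The point of lifting the selector: previously every StageNode called useAppSelector itself, so each alerts update re-rendered every rendered node. Now SqlFlow subscribes once and threads alerts into node data. A rough sketch of the idea, with hypothetical NodeAlert and FlowNode names:

import React, { memo } from "react";

type NodeAlert = { nodeId: number; title: string };

// Each node receives only the one alert that concerns it.
const FlowNode = memo(({ alert }: { alert?: NodeAlert }) => (
  <div>{alert ? alert.title : "ok"}</div>
));

// The parent matches alerts to nodes once, instead of every node scanning
// the full alerts list on every store update.
const Flow: React.FC<{ alerts: NodeAlert[]; nodeIds: number[] }> = ({
  alerts,
  nodeIds,
}) => (
  <>
    {nodeIds.map((id) => (
      <FlowNode key={id} alert={alerts.find((a) => a.nodeId === id)} />
    ))}
  </>
);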

spark-ui/src/components/SqlFlow/SqlLayoutService.ts

Lines changed: 107 additions & 21 deletions
@@ -2,20 +2,36 @@ import dagre from "dagre";
 import { Edge, Node, Position } from "reactflow";
 import { v4 as uuidv4 } from "uuid";
 import {
+  Alert,
+  AlertsStore,
   EnrichedSparkSQL,
   EnrichedSqlEdge,
-  EnrichedSqlNode,
   GraphFilter,
+  SQLAlertSourceData
 } from "../../interfaces/AppStore";
 import { StageNodeName } from "./StageNode";
 
+// Cache for layout results to avoid expensive recalculations
+const layoutCache = new Map<string, { layoutNodes: Node[]; layoutEdges: Edge[] }>();
+
+// Cache for edge IDs to avoid regenerating UUIDs
+const edgeIdCache = new Map<string, string>();
+
 const nodeWidth = 320;
 const nodeHeight = 320;
 
 const getLayoutedElements = (
   nodes: Node[],
   edges: Edge[],
+  cacheKey: string,
 ): { layoutNodes: Node[]; layoutEdges: Edge[] } => {
+  // Check cache first to avoid expensive Dagre layout calculation
+  const cached = layoutCache.get(cacheKey);
+  if (cached) {
+    return cached;
+  }
+
+  // Create new Dagre graph for layout calculation
   const dagreGraph = new dagre.graphlib.Graph();
   dagreGraph.setDefaultEdgeLabel(() => ({}));
   dagreGraph.setGraph({ rankdir: "LR" });
@@ -28,46 +44,108 @@ const getLayoutedElements = (
     dagreGraph.setEdge(edge.source, edge.target);
   });
 
+  // This is the expensive operation - only do it if not cached
   dagre.layout(dagreGraph);
 
-  nodes.forEach((node) => {
+  // Create new nodes with positions (don't mutate input)
+  const layoutNodes = nodes.map((node) => {
     const nodeWithPosition = dagreGraph.node(node.id);
-    node.targetPosition = Position.Left;
-    node.sourcePosition = Position.Right;
-
-    // We are shifting the dagre node position (anchor=center center) to the top left
-    // so it matches the React Flow node anchor point (top left).
-    node.position = {
-      x: nodeWithPosition.x - nodeWidth / 2,
-      y: nodeWithPosition.y - nodeHeight / 2,
+    return {
+      ...node,
+      targetPosition: Position.Left,
+      sourcePosition: Position.Right,
+      position: {
+        x: nodeWithPosition.x - nodeWidth / 2,
+        y: nodeWithPosition.y - nodeHeight / 2,
+      },
     };
-
-    return node;
   });
 
-  return { layoutNodes: nodes, layoutEdges: edges };
+  const result = { layoutNodes, layoutEdges: edges };
+
+  // Cache the result for future use
+  layoutCache.set(cacheKey, result);
+
+  return result;
 };
 
 class SqlLayoutService {
   static SqlElementsToLayout(
     sql: EnrichedSparkSQL,
     graphFilter: GraphFilter,
+    alerts?: AlertsStore, // Alerts store for node-specific alerts
  ): { layoutNodes: Node[]; layoutEdges: Edge[] } {
+    // Helper function to find alert for a specific node
+    const findNodeAlert = (nodeId: number): Alert | undefined => {
+      return alerts?.alerts.find((alert: Alert) => {
+        // Type guard to ensure we're dealing with SQL alerts
+        if (alert.source.type === "sql") {
+          const sqlSource = alert.source as SQLAlertSourceData;
+          return sqlSource.sqlNodeId === nodeId && sqlSource.sqlId === sql.id;
+        }
+        return false;
+      });
+    };
+
+    // Create cache key based on SQL structure and filter
+    const cacheKey = `${sql.uniqueId}-${graphFilter}`;
+
+    // Check if we have a cached result for this exact configuration
+    const cached = layoutCache.get(cacheKey);
+    if (cached) {
+      // Update node data with current metrics and alerts while keeping cached layout
+      const updatedNodes = cached.layoutNodes.map(node => ({
+        ...node,
+        data: {
+          ...node.data,
+          node: sql.nodes.find(n => n.nodeId.toString() === node.id) || node.data.node,
+          sqlUniqueId: sql.uniqueId,
+          sqlMetricUpdateId: sql.metricUpdateId,
+          alert: findNodeAlert(parseInt(node.id)), // Add alert for this node
+        }
+      }));
+
+      return { layoutNodes: updatedNodes, layoutEdges: cached.layoutEdges };
+    }
+
     const { nodesIds, edges } = sql.filters[graphFilter];
 
-    const flowNodes: Node[] = sql.nodes
-      .filter((node) => nodesIds.includes(node.nodeId))
-      .map((node: EnrichedSqlNode) => {
+    // Optimize node filtering and mapping
+    const nodeMap = new Map(sql.nodes.map(node => [node.nodeId, node]));
+    const flowNodes: Node[] = nodesIds
+      .map((nodeId) => {
+        const node = nodeMap.get(nodeId);
+        if (!node) return null;
+
         return {
           id: node.nodeId.toString(),
-          data: { sqlId: sql.id, node: node },
+          data: {
+            sqlId: sql.id,
+            node: node,
+            sqlUniqueId: sql.uniqueId,
+            sqlMetricUpdateId: sql.metricUpdateId,
+            alert: findNodeAlert(node.nodeId), // Add alert for this node
+          },
           type: StageNodeName,
           position: { x: 0, y: 0 },
-        };
-      });
+        } as Node;
+      })
+      .filter((node): node is Node => node !== null);
 
+    // Optimize edge creation with cached IDs
     const flowEdges: Edge[] = edges.map((edge: EnrichedSqlEdge) => {
+      const edgeKey = `${edge.fromId}-${edge.toId}`;
+      let edgeId = edgeIdCache.get(edgeKey);
+
+      if (!edgeId) {
+        edgeId = uuidv4();
+        edgeIdCache.set(edgeKey, edgeId);
+      }
+
       return {
-        id: uuidv4(),
+        id: edgeId,
         source: edge.fromId.toString(),
         animated: true,
         target: edge.toId.toString(),
@@ -77,8 +155,16 @@ class SqlLayoutService {
     const { layoutNodes, layoutEdges } = getLayoutedElements(
       flowNodes,
       flowEdges,
+      cacheKey,
     );
-    return { layoutNodes: layoutNodes, layoutEdges: layoutEdges };
+
+    return { layoutNodes, layoutEdges };
+  }
+
+  // Method to clear cache when needed (e.g., on app restart)
+  static clearCache(): void {
+    layoutCache.clear();
+    edgeIdCache.clear();
   }
 }
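
The caching strategy splits the work in two: Dagre positions depend only on graph structure, so they are keyed by `${sql.uniqueId}-${graphFilter}` and computed once, while node data (metrics, alerts) is re-applied fresh on every call. A generic sketch of that split, under the same assumptions:

type Point = { x: number; y: number };

const positionCache = new Map<string, Map<string, Point>>();

function layoutWithCache<T>(
  cacheKey: string, // e.g. `${sql.uniqueId}-${graphFilter}`
  computePositions: () => Map<string, Point>, // the expensive step (Dagre above)
  freshData: T, // metrics/alerts, always taken from the latest snapshot
): { positions: Map<string, Point>; data: T } {
  let positions = positionCache.get(cacheKey);
  if (!positions) {
    positions = computePositions(); // runs only on a cache miss
    positionCache.set(cacheKey, positions);
  }
  // Positions come from the cache; data is never cached.
  return { positions, data: freshData };
}

The edge-ID cache serves the same goal from another angle: a stable ID per `${fromId}-${toId}` pair keeps React Flow from treating unchanged edges as brand-new elements on every refresh.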

spark-ui/src/components/SqlFlow/StageNode.tsx

Lines changed: 24 additions & 17 deletions
@@ -1,11 +1,10 @@
 import ErrorIcon from "@mui/icons-material/Error";
 import WarningIcon from "@mui/icons-material/Warning";
 import { Alert, AlertTitle, Box, Typography } from "@mui/material";
-import React, { FC, useMemo } from "react";
+import React, { FC, memo, useMemo } from "react";
 import { useSearchParams } from "react-router-dom";
 import { Handle, Position } from "reactflow";
-import { useAppSelector } from "../../Hooks";
-import { EnrichedSqlNode } from "../../interfaces/AppStore";
+import { Alert as AppAlert, EnrichedSqlNode } from "../../interfaces/AppStore";
 import { TransperantTooltip } from "../AlertBadge/AlertBadge";
 import MetricDisplay, { MetricWithTooltip } from "./MetricDisplay";
 import {
@@ -26,15 +25,20 @@ import PlanMetricsProcessor from "./PlanMetricsProcessor";
 export const StageNodeName: string = "stageNode";
 
 interface StageNodeProps {
-  data: { sqlId: string; node: EnrichedSqlNode };
+  data: {
+    sqlId: string;
+    node: EnrichedSqlNode;
+    sqlUniqueId?: string;
+    sqlMetricUpdateId?: string;
+    alert?: AppAlert; // Alert for this specific node
+  };
 }
 
-const StageNode: FC<StageNodeProps> = ({ data }) => {
+const StageNodeComponent: FC<StageNodeProps> = ({ data }) => {
   const [searchParams] = useSearchParams();
-  const alerts = useAppSelector((state) => state.spark.alerts);
 
   // Memoized computations for better performance
-  const { isHighlighted, sqlNodeAlert, allMetrics } = useMemo(() => {
+  const { isHighlighted, allMetrics } = useMemo(() => {
     // Parse nodeIds from URL parameters
     const nodeIdsParam = searchParams.get('nodeids');
     const highlightedNodeIds = nodeIdsParam
@@ -44,14 +48,6 @@ const StageNode: FC<StageNodeProps> = ({ data }) => {
     // Check if current node should be highlighted
     const highlighted = highlightedNodeIds.includes(data.node.nodeId);
 
-    // Find any alerts for this node
-    const alert = alerts?.alerts.find(
-      (alert) =>
-        alert.source.type === "sql" &&
-        alert.source.sqlNodeId === data.node.nodeId &&
-        alert.source.sqlId === data.sqlId,
-    );
-
     // Process all metrics
     const metrics: MetricWithTooltip[] = [
       ...processBaseMetrics(data.node),
@@ -70,10 +66,18 @@ const StageNode: FC<StageNodeProps> = ({ data }) => {
 
     return {
       isHighlighted: highlighted,
-      sqlNodeAlert: alert,
       allMetrics: metrics,
     };
-  }, [data.node, data.sqlId, searchParams, alerts]);
+  }, [
+    // Use SQL identifiers for optimal memoization when available
+    data.sqlUniqueId || data.node.nodeId,
+    data.sqlMetricUpdateId || data.node.metrics,
+    data.sqlId,
+    searchParams
+  ]);
+
+  // Use the alert passed in data prop
+  const sqlNodeAlert = data.alert;
 
   const nodeClass = isHighlighted ? styles.nodeHighlighted : styles.node;
 
@@ -169,4 +173,7 @@ const StageNode: FC<StageNodeProps> = ({ data }) => {
   );
 };
 
+// Simple memoization - prevents unnecessary re-renders
+const StageNode = memo(StageNodeComponent);
+
 export { StageNode };
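
The reworked useMemo dependencies are the subtle part: instead of depending on data.node (whose object identity changes on every poll), the memo keys off sqlMetricUpdateId, which the store is assumed to bump only when metrics actually change. A small sketch of the trick, with a hypothetical useNodeMetrics hook:

import { useMemo } from "react";

function useNodeMetrics(
  node: { metrics: number[] },
  metricUpdateId?: string,
) {
  return useMemo(
    () => node.metrics.reduce((sum, m) => sum + m, 0),
    // Comparing one string is cheaper than deep-comparing metrics, and it
    // stays stable across polls that change object identity but not values.
    // Fall back to the metrics object when no update id is available.
    [metricUpdateId ?? node.metrics],
  );
}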
