Commit 2fae68a
1 parent 5b6e6f4

feat: do not count extra partition for custom queries with limits

2 files changed (+5, -13 lines)

common/src/main/scala/org/neo4j/spark/service/SchemaService.scala (1 addition, 9 deletions)
@@ -568,15 +568,7 @@ class SchemaService(
       Seq(PartitionPagination.EMPTY)
     } else {
       val partitionSize = Math.ceil(count.toDouble / options.partitions).toInt
-      val partitions = options.query.queryType match {
-        case QueryType.QUERY => if (options.queryMetadata.queryCount.nonEmpty) {
-          options.partitions // for custom query count we overfetch
-        } else {
-          options.partitions - 1
-        }
-        case _ => options.partitions - 1
-      }
-      (0 to partitions)
+      (0 until options.partitions)
         .map(index => PartitionPagination(index, index * partitionSize, TopN(partitionSize, Array.empty)))
     }
   }
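
For context, a minimal standalone sketch of the pagination logic after this change, assuming simplified stand-ins for the connector's PartitionPagination and TopN classes (the names mirror the connector's, but this is an illustration, not its API): the requested partition count now maps to exactly that many skip/limit windows, with no extra overfetch partition for custom queries that supply a count.

object PartitionPlanSketch {
  // Simplified stand-ins for the connector's classes, for illustration only.
  case class TopN(limit: Int)
  case class PartitionPagination(index: Int, skip: Int, topN: TopN)

  def plan(count: Long, partitions: Int): Seq[PartitionPagination] = {
    val partitionSize = math.ceil(count.toDouble / partitions).toInt
    // One window per requested partition; no trailing "+1" partition anymore.
    (0 until partitions)
      .map(index => PartitionPagination(index, index * partitionSize, TopN(partitionSize)))
  }

  def main(args: Array[String]): Unit =
    // For count = 50 and partitions = 5: skips 0, 10, 20, 30, 40, limit 10 each.
    plan(50, 5).foreach(println)
}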

spark-3/src/test/scala/org/neo4j/spark/DataSourceReaderTSE.scala (4 additions, 4 deletions)
@@ -998,9 +998,9 @@ class DataSourceReaderTSE extends SparkConnectorScalaBaseTSE {
       )
       .load()

-    assertEquals(6, partitionedQueryCountDf.rdd.getNumPartitions)
+    assertEquals(5, partitionedQueryCountDf.rdd.getNumPartitions)
     assertEquals(50, partitionedQueryCountDf.collect().map(_.getAs[String]("person")).toSet.size)
-    assertEquals(50, partitionedQueryCountDf.collect().map(_.getAs[String]("person")).size)
+    assertEquals(50, partitionedQueryCountDf.collect().map(_.getAs[String]("person")).length)

     val partitionedQueryCountLiteralDf = ss.read.format(classOf[DataSource].getName)
       .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
@@ -1014,9 +1014,9 @@ class DataSourceReaderTSE extends SparkConnectorScalaBaseTSE {
       .option("query.count", "50")
       .load()

-    assertEquals(6, partitionedQueryCountLiteralDf.rdd.getNumPartitions)
+    assertEquals(5, partitionedQueryCountLiteralDf.rdd.getNumPartitions)
     assertEquals(50, partitionedQueryCountLiteralDf.collect().map(_.getAs[String]("person")).toSet.size)
-    assertEquals(50, partitionedQueryCountLiteralDf.collect().map(_.getAs[String]("person")).size)
+    assertEquals(50, partitionedQueryCountLiteralDf.collect().map(_.getAs[String]("person")).length)

   }

   @Test
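
For reference, a hedged sketch of the kind of reader setup these assertions exercise, outside the test harness. It assumes the connector's format name "org.neo4j.spark.DataSource" and its "query", "query.count", and "partitions" options; the Bolt URL, Cypher text, and dataset (50 Person nodes) are placeholders, not taken from the test above, and authentication options are omitted.

import org.apache.spark.sql.SparkSession

object PartitionedQuerySketch {
  def main(args: Array[String]): Unit = {
    // Sketch only: assumes a reachable Neo4j at bolt://localhost:7687
    // containing 50 Person nodes; credentials/auth options are omitted.
    val spark = SparkSession.builder().master("local[*]").getOrCreate()

    val df = spark.read.format("org.neo4j.spark.DataSource")
      .option("url", "bolt://localhost:7687")
      .option("partitions", "5")
      .option("query", "MATCH (p:Person) RETURN p.name AS person")
      .option("query.count", "50")
      .load()

    // After this commit, 5 requested partitions yield exactly 5 RDD
    // partitions; previously a sixth, overfetching partition was created
    // when query.count was supplied for a custom query.
    assert(df.rdd.getNumPartitions == 5)

    spark.stop()
  }
}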
