Skip to content

Commit 1a2b17e

Browse files
authored
First/Last reduction and cleanup of agg APIs (#839)
Signed-off-by: Robert (Bobby) Evans <[email protected]>
1 parent ee8a778 commit 1a2b17e

File tree

5 files changed

+95
-97
lines changed

5 files changed

+95
-97
lines changed

integration_tests/src/main/python/hash_aggregate_test.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,3 +348,51 @@ def test_count_distinct_with_nan_floats(data_gen):
348348

349349
# TODO: Literal tests
350350
# TODO: First and Last tests
351+
352+
# REDUCTIONS
353+
354+
non_nan_all_basic_gens = [byte_gen, short_gen, int_gen, long_gen,
355+
# NaNs and -0.0 cannot be used here because of NaN support in min/max, -0.0 == 0.0 in cudf for distinct, and
356+
# https://github.com/NVIDIA/spark-rapids/issues/84 in the ordering
357+
FloatGen(no_nans=True, special_cases=[]), DoubleGen(no_nans=True, special_cases=[]),
358+
string_gen, boolean_gen, date_gen, timestamp_gen]
359+
360+
361+
@pytest.mark.parametrize('data_gen', non_nan_all_basic_gens, ids=idfn)
362+
def test_generic_reductions(data_gen):
363+
assert_gpu_and_cpu_are_equal_collect(
364+
# Coalesce and sort are to make sure that first and last, which are non-deterministic
365+
# become deterministic
366+
lambda spark : binary_op_df(spark, data_gen)\
367+
.coalesce(1)\
368+
.sortWithinPartitions('b').selectExpr(
369+
'min(a)',
370+
'max(a)',
371+
'first(a)',
372+
'last(a)',
373+
'count(a)',
374+
'count(1)'),
375+
conf = _no_nans_float_conf)
376+
377+
@pytest.mark.parametrize('data_gen', non_nan_all_basic_gens, ids=idfn)
378+
def test_distinct_count_reductions(data_gen):
379+
assert_gpu_and_cpu_are_equal_collect(
380+
lambda spark : binary_op_df(spark, data_gen).selectExpr(
381+
'count(DISTINCT a)'))
382+
383+
@pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/837')
384+
@pytest.mark.parametrize('data_gen', [float_gen, double_gen], ids=idfn)
385+
def test_distinct_float_count_reductions(data_gen):
386+
assert_gpu_and_cpu_are_equal_collect(
387+
lambda spark : binary_op_df(spark, data_gen).selectExpr(
388+
'count(DISTINCT a)'))
389+
390+
@approximate_float
391+
@pytest.mark.parametrize('data_gen', numeric_gens, ids=idfn)
392+
def test_arithmetic_reductions(data_gen):
393+
assert_gpu_and_cpu_are_equal_collect(
394+
lambda spark : unary_op_df(spark, data_gen).selectExpr(
395+
'sum(a)',
396+
'avg(a)'),
397+
conf = _no_nans_float_conf)
398+

sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExpression.scala

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
package com.nvidia.spark.rapids
1818

19-
import ai.rapids.cudf.{DType, Table, WindowAggregate, WindowOptions}
19+
import ai.rapids.cudf.{Aggregation, AggregationOverWindow, DType, Table, WindowOptions}
2020
import com.nvidia.spark.rapids.GpuOverrides.wrapExpr
2121

2222
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
@@ -243,11 +243,11 @@ object GpuWindowExpression {
243243
def getRowBasedWindowFrame(columnIndex : Int,
244244
aggExpression : Expression,
245245
windowSpec : GpuSpecifiedWindowFrame)
246-
: WindowAggregate = {
246+
: AggregationOverWindow = {
247247

248248
// FIXME: Currently, only negative or 0 values are supported for the lower bound.
249249
var lower = getBoundaryValue(windowSpec.lower)
250-
if(lower > 0) {
250+
if (lower > 0) {
251251
throw new IllegalStateException(
252252
s"Lower-bounds ahead of current row is not supported. Found $lower")
253253
}
@@ -272,32 +272,33 @@ object GpuWindowExpression {
272272
val windowOption = WindowOptions.builder().minPeriods(1)
273273
.window(lower, upper).build()
274274

275-
aggExpression match {
275+
val agg: Aggregation = aggExpression match {
276276
case gpuAggregateExpression : GpuAggregateExpression =>
277277
gpuAggregateExpression.aggregateFunction match {
278-
case GpuCount(_) => WindowAggregate.count(columnIndex, windowOption)
279-
case GpuSum(_) => WindowAggregate.sum(columnIndex, windowOption)
280-
case GpuMin(_) => WindowAggregate.min(columnIndex, windowOption)
281-
case GpuMax(_) => WindowAggregate.max(columnIndex, windowOption)
278+
case GpuCount(_) => Aggregation.count()
279+
case GpuSum(_) => Aggregation.sum()
280+
case GpuMin(_) => Aggregation.min()
281+
case GpuMax(_) => Aggregation.max()
282282
case anythingElse =>
283283
throw new UnsupportedOperationException(
284284
s"Unsupported aggregation: ${anythingElse.prettyName}")
285285
}
286286
case _: GpuRowNumber =>
287-
// ROW_NUMBER does not depend on input column values.
288-
WindowAggregate.row_number(0, windowOption)
287+
// ROW_NUMBER ignores the input column values, so onColumn(columnIndex) below should still be fine
288+
Aggregation.rowNumber()
289289
case anythingElse =>
290290
throw new UnsupportedOperationException(
291291
s"Unsupported window aggregation: ${anythingElse.prettyName}")
292292
}
293+
agg.onColumn(columnIndex).overWindow(windowOption)
293294
}
294295

295296
def getRangeBasedWindowFrame(aggColumnIndex : Int,
296297
timeColumnIndex : Int,
297298
aggExpression : Expression,
298299
windowSpec : GpuSpecifiedWindowFrame,
299300
timestampIsAscending : Boolean)
300-
: WindowAggregate = {
301+
: AggregationOverWindow = {
301302

302303
// FIXME: Currently, only negative or 0 values are supported.
303304
var lower = getBoundaryValue(windowSpec.lower)
@@ -332,12 +333,12 @@ object GpuWindowExpression {
332333

333334
val windowOption = windowOptionBuilder.build()
334335

335-
aggExpression match {
336+
val agg: Aggregation = aggExpression match {
336337
case gpuAggExpression : GpuAggregateExpression => gpuAggExpression.aggregateFunction match {
337-
case GpuCount(_) => WindowAggregate.count(aggColumnIndex, windowOption)
338-
case GpuSum(_) => WindowAggregate.sum(aggColumnIndex, windowOption)
339-
case GpuMin(_) => WindowAggregate.min(aggColumnIndex, windowOption)
340-
case GpuMax(_) => WindowAggregate.max(aggColumnIndex, windowOption)
338+
case GpuCount(_) => Aggregation.count()
339+
case GpuSum(_) => Aggregation.sum()
340+
case GpuMin(_) => Aggregation.min()
341+
case GpuMax(_) => Aggregation.max()
341342
case anythingElse =>
342343
throw new UnsupportedOperationException(
343344
s"Unsupported aggregation: ${anythingElse.prettyName}")
@@ -346,6 +347,7 @@ object GpuWindowExpression {
346347
throw new UnsupportedOperationException(
347348
s"Unsupported window aggregation: ${anythingElse.prettyName}")
348349
}
350+
agg.onColumn(aggColumnIndex).overWindow(windowOption)
349351
}
350352

351353
def getBoundaryValue(boundary : Expression) : Int = boundary match {

sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -98,13 +98,6 @@ class GpuHashAggregateMeta(
9898
resultExpressions
9999

100100
override def tagPlanForGpu(): Unit = {
101-
if (agg.groupingExpressions.isEmpty) {
102-
// first/last reductions not supported yet
103-
if (agg.aggregateExpressions.exists(e => e.aggregateFunction.isInstanceOf[First] ||
104-
e.aggregateFunction.isInstanceOf[Last])) {
105-
willNotWorkOnGpu("First/Last reductions are not supported on GPU")
106-
}
107-
}
108101
if (agg.resultExpressions.isEmpty) {
109102
willNotWorkOnGpu("result expressions is empty")
110103
}
@@ -192,13 +185,6 @@ class GpuSortAggregateMeta(
192185
resultExpressions
193186

194187
override def tagPlanForGpu(): Unit = {
195-
if (agg.groupingExpressions.isEmpty) {
196-
// first/last reductions not supported yet
197-
if (agg.aggregateExpressions.exists(e => e.aggregateFunction.isInstanceOf[First] ||
198-
e.aggregateFunction.isInstanceOf[Last])) {
199-
willNotWorkOnGpu("First/Last reductions are not supported on GPU")
200-
}
201-
}
202188
if (GpuOverrides.isAnyStringLit(agg.groupingExpressions)) {
203189
willNotWorkOnGpu("string literal values are not supported in a hash aggregate")
204190
}
@@ -842,9 +828,9 @@ case class GpuHashAggregateExec(
842828
val aggregates = aggModeCudfAggregates.flatMap(_._2)
843829
val cudfAggregates = aggModeCudfAggregates.flatMap { case (mode, aggregates) =>
844830
if ((mode == Partial || mode == Complete) && !merge) {
845-
aggregates.map(a => a.updateAggregate)
831+
aggregates.map(a => a.updateAggregate.onColumn(a.getOrdinal(a.ref)))
846832
} else {
847-
aggregates.map(a => a.mergeAggregate)
833+
aggregates.map(a => a.mergeAggregate.onColumn(a.getOrdinal(a.ref)))
848834
}
849835
}
850836
tbl = new cudf.Table(toAggregateCvs.map(_.getBase): _*)

sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala

Lines changed: 27 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package org.apache.spark.sql.rapids
1818

1919
import ai.rapids.cudf
20+
import ai.rapids.cudf.Aggregation
2021
import com.nvidia.spark.rapids._
2122

2223
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
@@ -111,7 +112,7 @@ case class GpuAggregateExpression(origAggregateFunction: GpuAggregateFunction,
111112
resultId: ExprId)
112113
extends GpuExpression with GpuUnevaluable {
113114

114-
val aggregateFunction = if (filter.isDefined) {
115+
val aggregateFunction: GpuAggregateFunction = if (filter.isDefined) {
115116
WrappedAggFunction(origAggregateFunction, filter.get)
116117
} else {
117118
origAggregateFunction
@@ -170,8 +171,8 @@ abstract case class CudfAggregate(ref: Expression) extends GpuUnevaluable {
170171
def getOrdinal(ref: Expression): Int = ref.asInstanceOf[GpuBoundReference].ordinal
171172
val updateReductionAggregate: cudf.ColumnVector => cudf.Scalar
172173
val mergeReductionAggregate: cudf.ColumnVector => cudf.Scalar
173-
val updateAggregate: cudf.Aggregate
174-
val mergeAggregate: cudf.Aggregate
174+
val updateAggregate: Aggregation
175+
val mergeAggregate: Aggregation
175176

176177
def dataType: DataType = ref.dataType
177178
def nullable: Boolean = ref.nullable
@@ -185,9 +186,8 @@ class CudfCount(ref: Expression) extends CudfAggregate(ref) {
185186
(col: cudf.ColumnVector) => cudf.Scalar.fromLong(col.getRowCount - col.getNullCount)
186187
override val mergeReductionAggregate: cudf.ColumnVector => cudf.Scalar =
187188
(col: cudf.ColumnVector) => col.sum
188-
override lazy val updateAggregate: cudf.Aggregate =
189-
cudf.Table.count(getOrdinal(ref), includeNulls)
190-
override lazy val mergeAggregate: cudf.Aggregate = cudf.Table.sum(getOrdinal(ref))
189+
override lazy val updateAggregate: Aggregation = Aggregation.count(includeNulls)
190+
override lazy val mergeAggregate: Aggregation = Aggregation.sum()
191191
override def toString(): String = "CudfCount"
192192
}
193193

@@ -196,8 +196,8 @@ class CudfSum(ref: Expression) extends CudfAggregate(ref) {
196196
(col: cudf.ColumnVector) => col.sum
197197
override val mergeReductionAggregate: cudf.ColumnVector => cudf.Scalar =
198198
(col: cudf.ColumnVector) => col.sum
199-
override lazy val updateAggregate: cudf.Aggregate = cudf.Table.sum(getOrdinal(ref))
200-
override lazy val mergeAggregate: cudf.Aggregate = cudf.Table.sum(getOrdinal(ref))
199+
override lazy val updateAggregate: Aggregation = Aggregation.sum()
200+
override lazy val mergeAggregate: Aggregation = Aggregation.sum()
201201
override def toString(): String = "CudfSum"
202202
}
203203

@@ -206,8 +206,8 @@ class CudfMax(ref: Expression) extends CudfAggregate(ref) {
206206
(col: cudf.ColumnVector) => col.max
207207
override val mergeReductionAggregate: cudf.ColumnVector => cudf.Scalar =
208208
(col: cudf.ColumnVector) => col.max
209-
override lazy val updateAggregate: cudf.Aggregate = cudf.Table.max(getOrdinal(ref))
210-
override lazy val mergeAggregate: cudf.Aggregate = cudf.Table.max(getOrdinal(ref))
209+
override lazy val updateAggregate: Aggregation = Aggregation.max()
210+
override lazy val mergeAggregate: Aggregation = Aggregation.max()
211211
override def toString(): String = "CudfMax"
212212
}
213213

@@ -216,48 +216,41 @@ class CudfMin(ref: Expression) extends CudfAggregate(ref) {
216216
(col: cudf.ColumnVector) => col.min
217217
override val mergeReductionAggregate: cudf.ColumnVector => cudf.Scalar =
218218
(col: cudf.ColumnVector) => col.min
219-
override lazy val updateAggregate: cudf.Aggregate = cudf.Table.min(getOrdinal(ref))
220-
override lazy val mergeAggregate: cudf.Aggregate = cudf.Table.min(getOrdinal(ref))
219+
override lazy val updateAggregate: Aggregation = Aggregation.min()
220+
override lazy val mergeAggregate: Aggregation = Aggregation.min()
221221
override def toString(): String = "CudfMin"
222222
}
223223

224224
abstract class CudfFirstLastBase(ref: Expression) extends CudfAggregate(ref) {
225+
val includeNulls: Boolean
226+
val offset: Int
227+
225228
override val updateReductionAggregate: cudf.ColumnVector => cudf.Scalar =
226-
(col: cudf.ColumnVector) =>
227-
throw new UnsupportedOperationException("first/last reduction not supported on GPU")
229+
(col: cudf.ColumnVector) => col.reduce(Aggregation.nth(offset, includeNulls))
228230
override val mergeReductionAggregate: cudf.ColumnVector => cudf.Scalar =
229-
(col: cudf.ColumnVector) =>
230-
throw new UnsupportedOperationException("first/last reduction not supported on GPU")
231+
(col: cudf.ColumnVector) => col.reduce(Aggregation.nth(offset, includeNulls))
232+
override lazy val updateAggregate: Aggregation = Aggregation.nth(offset, includeNulls)
233+
override lazy val mergeAggregate: Aggregation = Aggregation.nth(offset, includeNulls)
231234
}
232235

233236
class CudfFirstIncludeNulls(ref: Expression) extends CudfFirstLastBase(ref) {
234-
val includeNulls = true
235-
override lazy val updateAggregate: cudf.Aggregate =
236-
cudf.Table.first(getOrdinal(ref), includeNulls)
237-
override lazy val mergeAggregate: cudf.Aggregate =
238-
cudf.Table.first(getOrdinal(ref), includeNulls)
237+
override val includeNulls: Boolean = true
238+
override val offset: Int = 0
239239
}
240240

241241
class CudfFirstExcludeNulls(ref: Expression) extends CudfFirstLastBase(ref) {
242-
val includeNulls = false
243-
override lazy val updateAggregate: cudf.Aggregate =
244-
cudf.Table.first(getOrdinal(ref), includeNulls)
245-
override lazy val mergeAggregate: cudf.Aggregate =
246-
cudf.Table.first(getOrdinal(ref), includeNulls)
242+
override val includeNulls: Boolean = false
243+
override val offset: Int = 0
247244
}
248245

249246
class CudfLastIncludeNulls(ref: Expression) extends CudfFirstLastBase(ref) {
250-
val includeNulls = true
251-
override lazy val updateAggregate: cudf.Aggregate =
252-
cudf.Table.last(getOrdinal(ref), includeNulls)
253-
override lazy val mergeAggregate: cudf.Aggregate = cudf.Table.last(getOrdinal(ref), includeNulls)
247+
override val includeNulls: Boolean = true
248+
override val offset: Int = -1
254249
}
255250

256251
class CudfLastExcludeNulls(ref: Expression) extends CudfFirstLastBase(ref) {
257-
val includeNulls = false
258-
override lazy val updateAggregate: cudf.Aggregate =
259-
cudf.Table.last(getOrdinal(ref), includeNulls)
260-
override lazy val mergeAggregate: cudf.Aggregate = cudf.Table.last(getOrdinal(ref), includeNulls)
252+
override val includeNulls: Boolean = false
253+
override val offset: Int = -1
261254
}
262255

263256
abstract class GpuDeclarativeAggregate extends GpuAggregateFunction with GpuUnevaluable {

tests/src/test/scala/com/nvidia/spark/rapids/HashAggregatesSuite.scala

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -90,23 +90,6 @@ class HashAggregatesSuite extends SparkQueryCompareTestSuite {
9090
.agg(first(col("c0"), ignoreNulls = true), last(col("c0"), ignoreNulls = true))
9191
}
9292

93-
testExpectedExceptionStartsWith("test unsorted agg with first and last no grouping",
94-
classOf[IllegalArgumentException],
95-
"Part of the plan is not columnar", firstDf, repart = 2) {
96-
frame => frame
97-
.coalesce(1)
98-
.agg(first(col("c0"), ignoreNulls = true), last(col("c0"), ignoreNulls = true))
99-
}
100-
101-
testExpectedExceptionStartsWith("test sorted agg with first and last no grouping",
102-
classOf[IllegalArgumentException],
103-
"Part of the plan is not columnar", firstDf, repart = 2) {
104-
frame => frame
105-
.coalesce(1)
106-
.sort(col("c2").asc, col("c0").asc) // force deterministic use case
107-
.agg(first(col("c0"), ignoreNulls = true), last(col("c0"), ignoreNulls = true))
108-
}
109-
11093
IGNORE_ORDER_testSparkResultsAreEqualWithCapture(
11194
"nullable aggregate with not null filter",
11295
firstDf,
@@ -733,20 +716,6 @@ class HashAggregatesSuite extends SparkQueryCompareTestSuite {
733716
frame => frame.groupBy(col("more_longs") + col("longs")).agg(min("longs"))
734717
}
735718

736-
testExpectedExceptionStartsWith("first without grouping",
737-
classOf[IllegalArgumentException],
738-
"Part of the plan is not columnar",
739-
intCsvDf) {
740-
frame => frame.agg(first("ints", false))
741-
}
742-
743-
testExpectedExceptionStartsWith("last without grouping",
744-
classOf[IllegalArgumentException],
745-
"Part of the plan is not columnar",
746-
intCsvDf) {
747-
frame => frame.agg(first("ints", false))
748-
}
749-
750719
IGNORE_ORDER_testSparkResultsAreEqual("first ignoreNulls=false", intCsvDf) {
751720
frame => frame.groupBy(col("more_ints")).agg(first("ints", false))
752721
}

0 commit comments

Comments
 (0)