1717package org .apache .spark .sql .rapids
1818
1919import ai .rapids .cudf
20+ import ai .rapids .cudf .Aggregation
2021import com .nvidia .spark .rapids ._
2122
2223import org .apache .spark .sql .catalyst .analysis .TypeCheckResult
@@ -111,7 +112,7 @@ case class GpuAggregateExpression(origAggregateFunction: GpuAggregateFunction,
111112 resultId : ExprId )
112113 extends GpuExpression with GpuUnevaluable {
113114
114- val aggregateFunction = if (filter.isDefined) {
115+ val aggregateFunction : GpuAggregateFunction = if (filter.isDefined) {
115116 WrappedAggFunction (origAggregateFunction, filter.get)
116117 } else {
117118 origAggregateFunction
@@ -170,8 +171,8 @@ abstract case class CudfAggregate(ref: Expression) extends GpuUnevaluable {
170171 def getOrdinal (ref : Expression ): Int = ref.asInstanceOf [GpuBoundReference ].ordinal
171172 val updateReductionAggregate : cudf.ColumnVector => cudf.Scalar
172173 val mergeReductionAggregate : cudf.ColumnVector => cudf.Scalar
173- val updateAggregate : cudf. Aggregate
174- val mergeAggregate : cudf. Aggregate
174+ val updateAggregate : Aggregation
175+ val mergeAggregate : Aggregation
175176
176177 def dataType : DataType = ref.dataType
177178 def nullable : Boolean = ref.nullable
@@ -185,9 +186,8 @@ class CudfCount(ref: Expression) extends CudfAggregate(ref) {
185186 (col : cudf.ColumnVector ) => cudf.Scalar .fromLong(col.getRowCount - col.getNullCount)
186187 override val mergeReductionAggregate : cudf.ColumnVector => cudf.Scalar =
187188 (col : cudf.ColumnVector ) => col.sum
188- override lazy val updateAggregate : cudf.Aggregate =
189- cudf.Table .count(getOrdinal(ref), includeNulls)
190- override lazy val mergeAggregate : cudf.Aggregate = cudf.Table .sum(getOrdinal(ref))
189+ override lazy val updateAggregate : Aggregation = Aggregation .count(includeNulls)
190+ override lazy val mergeAggregate : Aggregation = Aggregation .sum()
191191 override def toString (): String = " CudfCount"
192192}
193193
@@ -196,8 +196,8 @@ class CudfSum(ref: Expression) extends CudfAggregate(ref) {
196196 (col : cudf.ColumnVector ) => col.sum
197197 override val mergeReductionAggregate : cudf.ColumnVector => cudf.Scalar =
198198 (col : cudf.ColumnVector ) => col.sum
199- override lazy val updateAggregate : cudf. Aggregate = cudf. Table . sum(getOrdinal(ref) )
200- override lazy val mergeAggregate : cudf. Aggregate = cudf. Table . sum(getOrdinal(ref) )
199+ override lazy val updateAggregate : Aggregation = Aggregation . sum()
200+ override lazy val mergeAggregate : Aggregation = Aggregation . sum()
201201 override def toString (): String = " CudfSum"
202202}
203203
@@ -206,8 +206,8 @@ class CudfMax(ref: Expression) extends CudfAggregate(ref) {
206206 (col : cudf.ColumnVector ) => col.max
207207 override val mergeReductionAggregate : cudf.ColumnVector => cudf.Scalar =
208208 (col : cudf.ColumnVector ) => col.max
209- override lazy val updateAggregate : cudf. Aggregate = cudf. Table . max(getOrdinal(ref) )
210- override lazy val mergeAggregate : cudf. Aggregate = cudf. Table . max(getOrdinal(ref) )
209+ override lazy val updateAggregate : Aggregation = Aggregation . max()
210+ override lazy val mergeAggregate : Aggregation = Aggregation . max()
211211 override def toString (): String = " CudfMax"
212212}
213213
@@ -216,48 +216,41 @@ class CudfMin(ref: Expression) extends CudfAggregate(ref) {
216216 (col : cudf.ColumnVector ) => col.min
217217 override val mergeReductionAggregate : cudf.ColumnVector => cudf.Scalar =
218218 (col : cudf.ColumnVector ) => col.min
219- override lazy val updateAggregate : cudf. Aggregate = cudf. Table . min(getOrdinal(ref) )
220- override lazy val mergeAggregate : cudf. Aggregate = cudf. Table . min(getOrdinal(ref) )
219+ override lazy val updateAggregate : Aggregation = Aggregation . min()
220+ override lazy val mergeAggregate : Aggregation = Aggregation . min()
221221 override def toString (): String = " CudfMin"
222222}
223223
224224abstract class CudfFirstLastBase (ref : Expression ) extends CudfAggregate (ref) {
225+ val includeNulls : Boolean
226+ val offset : Int
227+
225228 override val updateReductionAggregate : cudf.ColumnVector => cudf.Scalar =
226- (col : cudf.ColumnVector ) =>
227- throw new UnsupportedOperationException (" first/last reduction not supported on GPU" )
229+ (col : cudf.ColumnVector ) => col.reduce(Aggregation .nth(offset, includeNulls))
228230 override val mergeReductionAggregate : cudf.ColumnVector => cudf.Scalar =
229- (col : cudf.ColumnVector ) =>
230- throw new UnsupportedOperationException (" first/last reduction not supported on GPU" )
231+ (col : cudf.ColumnVector ) => col.reduce(Aggregation .nth(offset, includeNulls))
232+ override lazy val updateAggregate : Aggregation = Aggregation .nth(offset, includeNulls)
233+ override lazy val mergeAggregate : Aggregation = Aggregation .nth(offset, includeNulls)
231234}
232235
233236class CudfFirstIncludeNulls (ref : Expression ) extends CudfFirstLastBase (ref) {
234- val includeNulls = true
235- override lazy val updateAggregate : cudf.Aggregate =
236- cudf.Table .first(getOrdinal(ref), includeNulls)
237- override lazy val mergeAggregate : cudf.Aggregate =
238- cudf.Table .first(getOrdinal(ref), includeNulls)
237+ override val includeNulls : Boolean = true
238+ override val offset : Int = 0
239239}
240240
241241class CudfFirstExcludeNulls (ref : Expression ) extends CudfFirstLastBase (ref) {
242- val includeNulls = false
243- override lazy val updateAggregate : cudf.Aggregate =
244- cudf.Table .first(getOrdinal(ref), includeNulls)
245- override lazy val mergeAggregate : cudf.Aggregate =
246- cudf.Table .first(getOrdinal(ref), includeNulls)
242+ override val includeNulls : Boolean = false
243+ override val offset : Int = 0
247244}
248245
249246class CudfLastIncludeNulls (ref : Expression ) extends CudfFirstLastBase (ref) {
250- val includeNulls = true
251- override lazy val updateAggregate : cudf.Aggregate =
252- cudf.Table .last(getOrdinal(ref), includeNulls)
253- override lazy val mergeAggregate : cudf.Aggregate = cudf.Table .last(getOrdinal(ref), includeNulls)
247+ override val includeNulls : Boolean = true
248+ override val offset : Int = - 1
254249}
255250
256251class CudfLastExcludeNulls (ref : Expression ) extends CudfFirstLastBase (ref) {
257- val includeNulls = false
258- override lazy val updateAggregate : cudf.Aggregate =
259- cudf.Table .last(getOrdinal(ref), includeNulls)
260- override lazy val mergeAggregate : cudf.Aggregate = cudf.Table .last(getOrdinal(ref), includeNulls)
252+ override val includeNulls : Boolean = false
253+ override val offset : Int = - 1
261254}
262255
263256abstract class GpuDeclarativeAggregate extends GpuAggregateFunction with GpuUnevaluable {
0 commit comments