Skip to content

Commit fcbc5ee

Browse files
committed
[GR-22670] Rewrite row sums and means to use VectorDataLibrary
(cherry picked from commit 874499a)
1 parent a01391b commit fcbc5ee

File tree

6 files changed

+85
-59
lines changed

6 files changed

+85
-59
lines changed

com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/ColMeans.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,13 @@
3131
import com.oracle.truffle.r.runtime.data.RIntVector;
3232
import com.oracle.truffle.r.runtime.data.RLogicalVector;
3333
import com.oracle.truffle.r.runtime.ops.BinaryArithmetic;
34+
import com.oracle.truffle.r.runtime.ops.na.NACheck;
3435

3536
// Implements .colMeans
3637
@RBuiltin(name = "colMeans", kind = INTERNAL, parameterNames = {"X", "m", "n", "na.rm"}, behavior = PURE)
3738
public abstract class ColMeans extends ColSumsBase {
3839

40+
protected final NACheck na = NACheck.create();
3941
@Child private BinaryArithmetic add = BinaryArithmetic.ADD.createOperation();
4042

4143
static {

com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/ColSums.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,6 @@ protected RDoubleVector colSums(RIntVector x, int rowNum, int colNum, boolean rn
134134
final boolean rna = removeNA.profile(rnaParam);
135135
double[] result = new double[colNum];
136136
boolean isComplete = true;
137-
na.enable(x);
138137
int pos = 0;
139138
Object xData = x.getData();
140139
RandomAccessIterator xIt = xDataLib.randomAccessIterator(xData);

com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/ColSumsBase.java

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@
2626
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.toBoolean;
2727
import static com.oracle.truffle.r.runtime.RError.Message.INVALID_ARGUMENT;
2828

29+
import com.oracle.truffle.api.dsl.Cached;
2930
import com.oracle.truffle.api.dsl.Specialization;
31+
import com.oracle.truffle.api.profiles.BranchProfile;
3032
import com.oracle.truffle.api.profiles.ConditionProfile;
3133
import com.oracle.truffle.r.nodes.builtin.RBuiltinNode;
3234
import com.oracle.truffle.r.runtime.RError;
@@ -35,7 +37,6 @@
3537
import com.oracle.truffle.r.runtime.data.RDoubleVector;
3638
import com.oracle.truffle.r.runtime.data.VectorDataLibrary;
3739
import com.oracle.truffle.r.runtime.data.model.RAbstractVector;
38-
import com.oracle.truffle.r.runtime.ops.na.NACheck;
3940

4041
/**
4142
* Base class that provides argument handling and validation helper methods and trivial cases
@@ -44,7 +45,6 @@
4445
*/
4546
public abstract class ColSumsBase extends RBuiltinNode.Arg4 {
4647

47-
protected final NACheck na = NACheck.create();
4848
private final ConditionProfile vectorLengthProfile = ConditionProfile.createBinaryProfile();
4949

5050
protected static Casts createCasts(Class<? extends ColSumsBase> extCls) {
@@ -79,56 +79,61 @@ protected final RDoubleVector doScalarNaRmFalse(double x, int rowNum, int colNum
7979
}
8080

8181
@Specialization(guards = "naRm")
82-
protected final RDoubleVector doScalarNaRmTrue(double x, int rowNum, int colNum, @SuppressWarnings("unused") boolean naRm) {
82+
protected final RDoubleVector doScalarNaRmTrue(double x, int rowNum, int colNum, @SuppressWarnings("unused") boolean naRm,
83+
@Cached BranchProfile naProfile) {
8384
checkLengthOne(rowNum, colNum);
84-
na.enable(x);
85-
if (!na.check(x) && !Double.isNaN(x)) {
85+
if (!RRuntime.isNA(x) && !Double.isNaN(x)) {
8686
return RDataFactory.createDoubleVectorFromScalar(x);
8787
} else {
88+
naProfile.enter();
8889
return RDataFactory.createDoubleVectorFromScalar(Double.NaN);
8990
}
9091
}
9192

9293
@Specialization(guards = "!naRm")
93-
protected final RDoubleVector doScalarNaRmFalse(int x, int rowNum, int colNum, @SuppressWarnings("unused") boolean naRm) {
94+
protected final RDoubleVector doScalarNaRmFalse(int x, int rowNum, int colNum, @SuppressWarnings("unused") boolean naRm,
95+
@Cached BranchProfile naProfile) {
9496
checkLengthOne(rowNum, colNum);
95-
na.enable(x);
96-
if (!na.check(x)) {
97+
if (!RRuntime.isNA(x)) {
9798
return RDataFactory.createDoubleVectorFromScalar(x);
9899
} else {
100+
naProfile.enter();
99101
return RDataFactory.createDoubleVectorFromScalar(RRuntime.DOUBLE_NA);
100102
}
101103
}
102104

103105
@Specialization(guards = "naRm")
104-
protected final RDoubleVector doScalarNaRmTrue(int x, int rowNum, int colNum, @SuppressWarnings("unused") boolean naRm) {
106+
protected final RDoubleVector doScalarNaRmTrue(int x, int rowNum, int colNum, @SuppressWarnings("unused") boolean naRm,
107+
@Cached BranchProfile naProfile) {
105108
checkLengthOne(rowNum, colNum);
106-
na.enable(x);
107-
if (!na.check(x)) {
109+
if (!RRuntime.isNA(x)) {
108110
return RDataFactory.createDoubleVectorFromScalar(x);
109111
} else {
112+
naProfile.enter();
110113
return RDataFactory.createDoubleVectorFromScalar(Double.NaN);
111114
}
112115
}
113116

114117
@Specialization(guards = "!naRm")
115-
protected final RDoubleVector doScalarNaRmFalse(byte x, int rowNum, int colNum, @SuppressWarnings("unused") boolean naRm) {
118+
protected final RDoubleVector doScalarNaRmFalse(byte x, int rowNum, int colNum, @SuppressWarnings("unused") boolean naRm,
119+
@Cached BranchProfile naProfile) {
116120
checkLengthOne(rowNum, colNum);
117-
na.enable(x);
118-
if (!na.check(x)) {
121+
if (!RRuntime.isNA(x)) {
119122
return RDataFactory.createDoubleVectorFromScalar(x);
120123
} else {
124+
naProfile.enter();
121125
return RDataFactory.createDoubleVectorFromScalar(RRuntime.DOUBLE_NA);
122126
}
123127
}
124128

125129
@Specialization(guards = "naRm")
126-
protected final RDoubleVector doScalarNaRmTrue(byte x, int rowNum, int colNum, @SuppressWarnings("unused") boolean naRm) {
130+
protected final RDoubleVector doScalarNaRmTrue(byte x, int rowNum, int colNum, @SuppressWarnings("unused") boolean naRm,
131+
@Cached BranchProfile naProfile) {
127132
checkLengthOne(rowNum, colNum);
128-
na.enable(x);
129-
if (!na.check(x)) {
133+
if (!RRuntime.isNA(x)) {
130134
return RDataFactory.createDoubleVectorFromScalar(x);
131135
} else {
136+
naProfile.enter();
132137
return RDataFactory.createDoubleVectorFromScalar(Double.NaN);
133138
}
134139
}

com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/RowMeans.java

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,12 @@
2424
import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.INTERNAL;
2525

2626
import com.oracle.truffle.api.dsl.Specialization;
27+
import com.oracle.truffle.api.library.CachedLibrary;
2728
import com.oracle.truffle.r.runtime.builtins.RBuiltin;
2829
import com.oracle.truffle.r.runtime.data.RDoubleVector;
2930
import com.oracle.truffle.r.runtime.data.RIntVector;
3031
import com.oracle.truffle.r.runtime.data.RLogicalVector;
32+
import com.oracle.truffle.r.runtime.data.VectorDataLibrary;
3133

3234
// Implements .rowMeans
3335
@RBuiltin(name = "rowMeans", kind = INTERNAL, parameterNames = {"X", "m", "n", "na.rm"}, behavior = PURE)
@@ -37,28 +39,36 @@ public abstract class RowMeans extends RowSumsBase {
3739
createCasts(RowMeans.class);
3840
}
3941

40-
@Specialization
41-
protected RDoubleVector rowMeans(RDoubleVector x, int rowNum, int colNum, boolean naRm) {
42-
return accumulateRows(x, rowNum, colNum, naRm, RowMeans::getMean, (v, nacheck, i) -> v.getDataAt(i));
42+
@Specialization(limit = "getTypedVectorDataLibraryCacheSize()")
43+
protected RDoubleVector rowMeans(RDoubleVector x, int rowNum, int colNum, boolean naRm,
44+
@CachedLibrary("x.getData()") VectorDataLibrary dataLib) {
45+
return accumulateRows(dataLib, x.getData(), rowNum, colNum, naRm, TransformMean.INSTANCE);
4346
}
4447

45-
@Specialization
46-
protected RDoubleVector rowMeans(RIntVector x, int rowNum, int colNum, boolean naRm) {
47-
return accumulateRows(x, rowNum, colNum, naRm, RowMeans::getMean, (v, nacheck, i) -> nacheck.convertIntToDouble(v.getDataAt(i)));
48+
@Specialization(limit = "getTypedVectorDataLibraryCacheSize()")
49+
protected RDoubleVector rowMeans(RIntVector x, int rowNum, int colNum, boolean naRm,
50+
@CachedLibrary("x.getData()") VectorDataLibrary dataLib) {
51+
return accumulateRows(dataLib, x.getData(), rowNum, colNum, naRm, TransformMean.INSTANCE);
4852
}
4953

50-
@Specialization
51-
protected RDoubleVector rowMeans(RLogicalVector x, int rowNum, int colNum, boolean naRm) {
52-
return accumulateRows(x, rowNum, colNum, naRm, RowMeans::getMean, (v, nacheck, i) -> nacheck.convertLogicalToDouble(v.getDataAt(i)));
54+
@Specialization(limit = "getTypedVectorDataLibraryCacheSize()")
55+
protected RDoubleVector rowMeans(RLogicalVector x, int rowNum, int colNum, boolean naRm,
56+
@CachedLibrary("x.getData()") VectorDataLibrary dataLib) {
57+
return accumulateRows(dataLib, x.getData(), rowNum, colNum, naRm, TransformMean.INSTANCE);
5358
}
5459

55-
private static double getMean(double sum, int notNACount) {
56-
if (Double.isNaN(sum)) {
57-
return sum;
58-
} else if (notNACount == 0) {
59-
return Double.NaN;
60-
} else {
61-
return sum / notNACount;
60+
private static final class TransformMean extends FinalTransform {
61+
private static final TransformMean INSTANCE = new TransformMean();
62+
63+
@Override
64+
public double get(double sum, int notNACount) {
65+
if (Double.isNaN(sum)) {
66+
return sum;
67+
} else if (notNACount == 0) {
68+
return Double.NaN;
69+
} else {
70+
return sum / notNACount;
71+
}
6272
}
6373
}
6474
}

com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/RowSums.java

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,12 @@
2424
import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.INTERNAL;
2525

2626
import com.oracle.truffle.api.dsl.Specialization;
27+
import com.oracle.truffle.api.library.CachedLibrary;
2728
import com.oracle.truffle.r.runtime.builtins.RBuiltin;
2829
import com.oracle.truffle.r.runtime.data.RDoubleVector;
2930
import com.oracle.truffle.r.runtime.data.RIntVector;
3031
import com.oracle.truffle.r.runtime.data.RLogicalVector;
32+
import com.oracle.truffle.r.runtime.data.VectorDataLibrary;
3133

3234
@RBuiltin(name = "rowSums", kind = INTERNAL, parameterNames = {"X", "m", "n", "na.rm"}, behavior = PURE)
3335
public abstract class RowSums extends RowSumsBase {
@@ -36,18 +38,30 @@ public abstract class RowSums extends RowSumsBase {
3638
createCasts(RowSums.class);
3739
}
3840

39-
@Specialization
40-
protected RDoubleVector rowSums(RDoubleVector x, int rowNum, int colNum, boolean naRm) {
41-
return accumulateRows(x, rowNum, colNum, naRm, (sum, cnt) -> sum, (v, nacheck, i) -> v.getDataAt(i));
41+
@Specialization(limit = "getTypedVectorDataLibraryCacheSize()")
42+
protected RDoubleVector rowSums(RDoubleVector x, int rowNum, int colNum, boolean naRm,
43+
@CachedLibrary("x.getData()") VectorDataLibrary dataLib) {
44+
return accumulateRows(dataLib, x.getData(), rowNum, colNum, naRm, SelectSum.INSTANCE);
4245
}
4346

44-
@Specialization
45-
protected RDoubleVector rowSums(RIntVector x, int rowNum, int colNum, boolean naRm) {
46-
return accumulateRows(x, rowNum, colNum, naRm, (sum, cnt) -> sum, (v, nacheck, i) -> nacheck.convertIntToDouble(v.getDataAt(i)));
47+
@Specialization(limit = "getTypedVectorDataLibraryCacheSize()")
48+
protected RDoubleVector rowSums(RIntVector x, int rowNum, int colNum, boolean naRm,
49+
@CachedLibrary("x.getData()") VectorDataLibrary dataLib) {
50+
return accumulateRows(dataLib, x.getData(), rowNum, colNum, naRm, SelectSum.INSTANCE);
4751
}
4852

49-
@Specialization
50-
protected RDoubleVector rowSums(RLogicalVector x, int rowNum, int colNum, boolean naRm) {
51-
return accumulateRows(x, rowNum, colNum, naRm, (sum, cnt) -> sum, (v, nacheck, i) -> nacheck.convertLogicalToDouble(v.getDataAt(i)));
53+
@Specialization(limit = "getTypedVectorDataLibraryCacheSize()")
54+
protected RDoubleVector rowSums(RLogicalVector x, int rowNum, int colNum, boolean naRm,
55+
@CachedLibrary("x.getData()") VectorDataLibrary dataLib) {
56+
return accumulateRows(dataLib, x.getData(), rowNum, colNum, naRm, SelectSum.INSTANCE);
57+
}
58+
59+
private static final class SelectSum extends FinalTransform {
60+
private static final SelectSum INSTANCE = new SelectSum();
61+
62+
@Override
63+
public double get(double sum, int notNACount) {
64+
return sum;
65+
}
5266
}
5367
}

com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/RowSumsBase.java

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2013, 2018, Oracle and/or its affiliates. All rights reserved.
2+
* Copyright (c) 2013, 2020, Oracle and/or its affiliates. All rights reserved.
33
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
44
*
55
* This code is free software; you can redistribute it and/or modify it
@@ -27,7 +27,8 @@
2727
import com.oracle.truffle.r.runtime.RRuntime;
2828
import com.oracle.truffle.r.runtime.data.RDataFactory;
2929
import com.oracle.truffle.r.runtime.data.RDoubleVector;
30-
import com.oracle.truffle.r.runtime.data.model.RAbstractVector;
30+
import com.oracle.truffle.r.runtime.data.VectorDataLibrary;
31+
import com.oracle.truffle.r.runtime.data.VectorDataLibrary.RandomAccessIterator;
3132
import com.oracle.truffle.r.runtime.ops.BinaryArithmetic;
3233
import com.oracle.truffle.r.runtime.ops.na.NACheck;
3334

@@ -48,31 +49,26 @@ public abstract class RowSumsBase extends ColSumsBase {
4849
private final LoopConditionProfile outerProfile = LoopConditionProfile.createCountingProfile();
4950
private final LoopConditionProfile innerProfile = LoopConditionProfile.createCountingProfile();
5051

51-
@FunctionalInterface
52-
protected interface GetFunction<T extends RAbstractVector> {
53-
double get(T vector, NACheck na, int index);
52+
protected abstract static class FinalTransform {
53+
abstract double get(double sum, int notNACount);
5454
}
5555

56-
@FunctionalInterface
57-
protected interface FinalTransform {
58-
double get(double sum, int notNACount);
59-
}
60-
61-
protected final <T extends RAbstractVector> RDoubleVector accumulateRows(T x, int rowNum, int colNum, boolean naRm, FinalTransform finalTransform, RowSumsBase.GetFunction<T> get) {
62-
reportWork(x.getLength());
56+
protected final RDoubleVector accumulateRows(VectorDataLibrary dataLib, Object data, int rowNum, int colNum, boolean naRm, FinalTransform finalTransform) {
57+
reportWork(dataLib.getLength(data));
6358
double[] result = new double[rowNum];
64-
na.enable(x);
6559
outerProfile.profileCounted(rowNum / 4);
6660
innerProfile.profileCounted(colNum);
6761
int i = 0;
62+
RandomAccessIterator it = dataLib.randomAccessIterator(data);
6863
// The unrolled loop cannot handle NA values, so it is only entered when the NA check is not enabled.
64+
NACheck na = dataLib.getNACheck(data);
6965
if (!na.isEnabled()) {
7066
while (outerProfile.inject(i <= rowNum - UNROLL)) {
7167
double[] sum = new double[UNROLL];
7268
int pos = i;
7369
for (int c = 0; innerProfile.inject(c < colNum); c++) {
7470
for (int unroll = 0; unroll < UNROLL; unroll++) {
75-
sum[unroll] = add.op(sum[unroll], get.get(x, na, pos + unroll));
71+
sum[unroll] = add.op(sum[unroll], dataLib.getDouble(data, it, pos + unroll));
7672
}
7773
pos += rowNum;
7874
}
@@ -88,7 +84,7 @@ protected final <T extends RAbstractVector> RDoubleVector accumulateRows(T x, in
8884
int pos = i;
8985
int notNACount = 0;
9086
for (int c = 0; innerProfile.inject(c < colNum); c++) {
91-
double el = get.get(x, na, pos);
87+
double el = dataLib.getDouble(data, it, pos);
9288
pos += rowNum;
9389
if (na.check(el)) {
9490
if (!naRm) {
@@ -109,6 +105,6 @@ protected final <T extends RAbstractVector> RDoubleVector accumulateRows(T x, in
109105
i++;
110106
}
111107
}
112-
return RDataFactory.createDoubleVector(result, na.neverSeenNA());
108+
return RDataFactory.createDoubleVector(result, dataLib.getNACheck(data).neverSeenNA());
113109
}
114110
}

0 commit comments

Comments
 (0)