Skip to content

Commit 2959340

Browse files
committed
[GR-22670] Rewrite ColMeans to use VectorDataLibrary
(cherry picked from commit 2d9c45f)
1 parent fcbc5ee commit 2959340

File tree

1 file changed

+51
-31
lines changed
  • com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base

1 file changed

+51
-31
lines changed

com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/ColMeans.java

Lines changed: 51 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -24,37 +24,42 @@
2424
import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.INTERNAL;
2525

2626
import com.oracle.truffle.api.dsl.Specialization;
27+
import com.oracle.truffle.api.library.CachedLibrary;
2728
import com.oracle.truffle.r.runtime.RRuntime;
2829
import com.oracle.truffle.r.runtime.builtins.RBuiltin;
2930
import com.oracle.truffle.r.runtime.data.RDataFactory;
3031
import com.oracle.truffle.r.runtime.data.RDoubleVector;
3132
import com.oracle.truffle.r.runtime.data.RIntVector;
3233
import com.oracle.truffle.r.runtime.data.RLogicalVector;
34+
import com.oracle.truffle.r.runtime.data.VectorDataLibrary;
35+
import com.oracle.truffle.r.runtime.data.VectorDataLibrary.RandomAccessIterator;
3336
import com.oracle.truffle.r.runtime.ops.BinaryArithmetic;
3437
import com.oracle.truffle.r.runtime.ops.na.NACheck;
3538

3639
//Implements .colMeans
3740
@RBuiltin(name = "colMeans", kind = INTERNAL, parameterNames = {"X", "m", "n", "na.rm"}, behavior = PURE)
3841
public abstract class ColMeans extends ColSumsBase {
3942

40-
protected final NACheck na = NACheck.create();
4143
@Child private BinaryArithmetic add = BinaryArithmetic.ADD.createOperation();
4244

4345
static {
4446
createCasts(ColMeans.class);
4547
}
4648

47-
@Specialization(guards = "!naRm")
48-
protected RDoubleVector colMeansNaRmFalse(RDoubleVector x, int rowNum, int colNum, @SuppressWarnings("unused") boolean naRm) {
49-
checkVectorLength(x, rowNum, colNum);
49+
@Specialization(guards = "!naRm", limit = "getTypedVectorDataLibraryCacheSize()")
50+
protected RDoubleVector colMeansNaRmFalse(RDoubleVector x, int rowNum, int colNum, @SuppressWarnings("unused") boolean naRm,
51+
@CachedLibrary("x.getData()") VectorDataLibrary dataLib) {
52+
checkVectorLength(dataLib, x, rowNum, colNum);
5053

5154
double[] result = new double[colNum];
5255
boolean isComplete = true;
53-
na.enable(x);
56+
Object xData = x.getData();
57+
RandomAccessIterator it = dataLib.randomAccessIterator(xData);
58+
NACheck na = dataLib.getNACheck(xData);
5459
nextCol: for (int c = 0; c < colNum; c++) {
5560
double sum = 0;
5661
for (int i = 0; i < rowNum; i++) {
57-
double el = x.getDataAt(c * rowNum + i);
62+
double el = dataLib.getDouble(xData, it, c * rowNum + i);
5863
if (na.check(el)) {
5964
result[c] = RRuntime.DOUBLE_NA;
6065
continue nextCol;
@@ -71,18 +76,21 @@ protected RDoubleVector colMeansNaRmFalse(RDoubleVector x, int rowNum, int colNu
7176
return RDataFactory.createDoubleVector(result, na.neverSeenNA() && isComplete);
7277
}
7378

74-
@Specialization(guards = "naRm")
75-
protected RDoubleVector colMeansNaRmTrue(RDoubleVector x, int rowNum, int colNum, @SuppressWarnings("unused") boolean naRm) {
76-
checkVectorLength(x, rowNum, colNum);
79+
@Specialization(guards = "naRm", limit = "getTypedVectorDataLibraryCacheSize()")
80+
protected RDoubleVector colMeansNaRmTrue(RDoubleVector x, int rowNum, int colNum, @SuppressWarnings("unused") boolean naRm,
81+
@CachedLibrary("x.getData()") VectorDataLibrary dataLib) {
82+
checkVectorLength(dataLib, x, rowNum, colNum);
7783

7884
double[] result = new double[colNum];
7985
boolean isComplete = true;
80-
na.enable(x);
86+
Object xData = x.getData();
87+
RandomAccessIterator it = dataLib.randomAccessIterator(xData);
88+
NACheck na = dataLib.getNACheck(xData);
8189
for (int c = 0; c < colNum; c++) {
8290
double sum = 0;
8391
int nonNaNumCount = 0;
8492
for (int i = 0; i < rowNum; i++) {
85-
double el = x.getDataAt(c * rowNum + i);
93+
double el = dataLib.getDouble(xData, it, c * rowNum + i);
8694
if (!na.check(el) && !Double.isNaN(el)) {
8795
sum = add.op(sum, el);
8896
nonNaNumCount++;
@@ -98,16 +106,19 @@ protected RDoubleVector colMeansNaRmTrue(RDoubleVector x, int rowNum, int colNum
98106
return RDataFactory.createDoubleVector(result, isComplete);
99107
}
100108

101-
@Specialization(guards = "!naRm")
102-
protected RDoubleVector colMeansNaRmFalse(RLogicalVector x, int rowNum, int colNum, @SuppressWarnings("unused") boolean naRm) {
103-
checkVectorLength(x, rowNum, colNum);
109+
@Specialization(guards = "!naRm", limit = "getTypedVectorDataLibraryCacheSize()")
110+
protected RDoubleVector colMeansNaRmFalse(RLogicalVector x, int rowNum, int colNum, @SuppressWarnings("unused") boolean naRm,
111+
@CachedLibrary("x.getData()") VectorDataLibrary dataLib) {
112+
checkVectorLength(dataLib, x, rowNum, colNum);
104113

105114
double[] result = new double[colNum];
106-
na.enable(x);
115+
Object xData = x.getData();
116+
RandomAccessIterator it = dataLib.randomAccessIterator(xData);
117+
NACheck na = dataLib.getNACheck(xData);
107118
nextCol: for (int c = 0; c < colNum; c++) {
108119
double sum = 0;
109120
for (int i = 0; i < rowNum; i++) {
110-
byte el = x.getDataAt(c * rowNum + i);
121+
byte el = dataLib.getLogical(xData, it, c * rowNum + i);
111122
if (na.check(el)) {
112123
result[c] = RRuntime.DOUBLE_NA;
113124
continue nextCol;
@@ -119,18 +130,21 @@ protected RDoubleVector colMeansNaRmFalse(RLogicalVector x, int rowNum, int colN
119130
return RDataFactory.createDoubleVector(result, na.neverSeenNA());
120131
}
121132

122-
@Specialization(guards = "naRm")
123-
protected RDoubleVector colMeansNaRmTrue(RLogicalVector x, int rowNum, int colNum, @SuppressWarnings("unused") boolean naRm) {
124-
checkVectorLength(x, rowNum, colNum);
133+
@Specialization(guards = "naRm", limit = "getTypedVectorDataLibraryCacheSize()")
134+
protected RDoubleVector colMeansNaRmTrue(RLogicalVector x, int rowNum, int colNum, @SuppressWarnings("unused") boolean naRm,
135+
@CachedLibrary("x.getData()") VectorDataLibrary dataLib) {
136+
checkVectorLength(dataLib, x, rowNum, colNum);
125137

126138
double[] result = new double[colNum];
127139
boolean isComplete = true;
128-
na.enable(x);
140+
Object xData = x.getData();
141+
RandomAccessIterator it = dataLib.randomAccessIterator(xData);
142+
NACheck na = dataLib.getNACheck(xData);
129143
for (int c = 0; c < colNum; c++) {
130144
double sum = 0;
131145
int nonNaNumCount = 0;
132146
for (int i = 0; i < rowNum; i++) {
133-
byte el = x.getDataAt(c * rowNum + i);
147+
byte el = dataLib.getLogical(xData, it, c * rowNum + i);
134148
if (!na.check(el)) {
135149
sum = add.op(sum, el);
136150
nonNaNumCount++;
@@ -146,16 +160,19 @@ protected RDoubleVector colMeansNaRmTrue(RLogicalVector x, int rowNum, int colNu
146160
return RDataFactory.createDoubleVector(result, isComplete);
147161
}
148162

149-
@Specialization(guards = "!naRm")
150-
protected RDoubleVector colMeansNaRmFalse(RIntVector x, int rowNum, int colNum, @SuppressWarnings("unused") boolean naRm) {
151-
checkVectorLength(x, rowNum, colNum);
163+
@Specialization(guards = "!naRm", limit = "getTypedVectorDataLibraryCacheSize()")
164+
protected RDoubleVector colMeansNaRmFalse(RIntVector x, int rowNum, int colNum, @SuppressWarnings("unused") boolean naRm,
165+
@CachedLibrary("x.getData()") VectorDataLibrary dataLib) {
166+
checkVectorLength(dataLib, x, rowNum, colNum);
152167

153168
double[] result = new double[colNum];
154-
na.enable(x);
169+
Object xData = x.getData();
170+
RandomAccessIterator it = dataLib.randomAccessIterator(xData);
171+
NACheck na = dataLib.getNACheck(xData);
155172
nextCol: for (int c = 0; c < colNum; c++) {
156173
double sum = 0;
157174
for (int i = 0; i < rowNum; i++) {
158-
int el = x.getDataAt(c * rowNum + i);
175+
int el = dataLib.getInt(xData, it, c * rowNum + i);
159176
if (na.check(el)) {
160177
result[c] = RRuntime.DOUBLE_NA;
161178
continue nextCol;
@@ -167,18 +184,21 @@ protected RDoubleVector colMeansNaRmFalse(RIntVector x, int rowNum, int colNum,
167184
return RDataFactory.createDoubleVector(result, na.neverSeenNA());
168185
}
169186

170-
@Specialization(guards = "naRm")
171-
protected RDoubleVector colMeansNaRmTrue(RIntVector x, int rowNum, int colNum, @SuppressWarnings("unused") boolean naRm) {
172-
checkVectorLength(x, rowNum, colNum);
187+
@Specialization(guards = "naRm", limit = "getTypedVectorDataLibraryCacheSize()")
188+
protected RDoubleVector colMeansNaRmTrue(RIntVector x, int rowNum, int colNum, @SuppressWarnings("unused") boolean naRm,
189+
@CachedLibrary("x.getData()") VectorDataLibrary dataLib) {
190+
checkVectorLength(dataLib, x, rowNum, colNum);
173191

174192
double[] result = new double[colNum];
175193
boolean isComplete = true;
176-
na.enable(x);
194+
Object xData = x.getData();
195+
RandomAccessIterator it = dataLib.randomAccessIterator(xData);
196+
NACheck na = dataLib.getNACheck(xData);
177197
for (int c = 0; c < colNum; c++) {
178198
double sum = 0;
179199
int nonNaNumCount = 0;
180200
for (int i = 0; i < rowNum; i++) {
181-
int el = x.getDataAt(c * rowNum + i);
201+
int el = dataLib.getInt(xData, it, c * rowNum + i);
182202
if (!na.check(el)) {
183203
sum = add.op(sum, el);
184204
nonNaNumCount++;

0 commit comments

Comments
 (0)