Skip to content

Commit 7f05e8a

Browse files
committed
[GR-22670] Rewrite to use vector data lib: Colon, logical cast in CachedVectorNode, BinaryBooleanScalarNode
(cherry picked from commit 1af8dca)
1 parent 8880fe5 commit 7f05e8a

File tree

4 files changed

+63
-69
lines changed

4 files changed

+63
-69
lines changed

com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Colon.java

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import com.oracle.truffle.api.dsl.Fallback;
3030
import com.oracle.truffle.api.dsl.NodeChild;
3131
import com.oracle.truffle.api.dsl.Specialization;
32+
import com.oracle.truffle.api.library.CachedLibrary;
3233
import com.oracle.truffle.api.nodes.NodeCost;
3334
import com.oracle.truffle.api.nodes.NodeInfo;
3435
import com.oracle.truffle.api.profiles.ConditionProfile;
@@ -37,6 +38,7 @@
3738
import com.oracle.truffle.r.nodes.builtin.base.Colon.ColonInternal;
3839
import com.oracle.truffle.r.nodes.builtin.base.ColonNodeGen.ColonCastNodeGen;
3940
import com.oracle.truffle.r.nodes.builtin.base.ColonNodeGen.ColonInternalNodeGen;
41+
import com.oracle.truffle.r.runtime.data.VectorDataLibrary;
4042
import com.oracle.truffle.r.runtime.nodes.unary.CastNode;
4143
import com.oracle.truffle.r.runtime.ArgumentsSignature;
4244
import com.oracle.truffle.r.runtime.RError;
@@ -244,22 +246,28 @@ private void checkLength(int length) {
244246
}
245247
}
246248

247-
@Specialization
248-
protected int doSequence(RIntVector vector) {
249-
checkLength(vector.getLength());
250-
return vector.getDataAt(0);
249+
@Specialization(limit = "getTypedVectorDataLibraryCacheSize()")
250+
protected int doSequence(RIntVector vector,
251+
@CachedLibrary("vector.getData()") VectorDataLibrary dataLib) {
252+
Object data = vector.getData();
253+
checkLength(dataLib.getLength(data));
254+
return dataLib.getIntAt(data, 0);
251255
}
252256

253-
@Specialization(guards = "isFirstIntValue(vector)")
254-
protected int doDoubleVectorFirstIntValue(RDoubleVector vector) {
255-
checkLength(vector.getLength());
256-
return (int) vector.getDataAt(0);
257+
@Specialization(guards = "isFirstIntValue(dataLib, vector.getData())", limit = "getTypedVectorDataLibraryCacheSize()")
258+
protected int doDoubleVectorFirstIntValue(RDoubleVector vector,
259+
@CachedLibrary("vector.getData()") VectorDataLibrary dataLib) {
260+
Object data = vector.getData();
261+
checkLength(dataLib.getLength(data));
262+
return (int) dataLib.getDoubleAt(data, 0);
257263
}
258264

259-
@Specialization(guards = "!isFirstIntValue(vector)")
260-
protected double doDoubleVector(RDoubleVector vector) {
261-
checkLength(vector.getLength());
262-
return vector.getDataAt(0);
265+
@Specialization(guards = "!isFirstIntValue(dataLib, vector.getData())", limit = "getTypedVectorDataLibraryCacheSize()")
266+
protected double doDoubleVector(RDoubleVector vector,
267+
@CachedLibrary("vector.getData()") VectorDataLibrary dataLib) {
268+
Object data = vector.getData();
269+
checkLength(dataLib.getLength(data));
270+
return dataLib.getDoubleAt(data, 0);
263271
}
264272

265273
@Specialization
@@ -292,8 +300,8 @@ protected static boolean isIntValue(double d) {
292300
return (((int) d)) == d && !RRuntime.isNA((int) d);
293301
}
294302

295-
protected static boolean isFirstIntValue(RDoubleVector d) {
296-
return d.getLength() > 0 && isIntValue(d.getDataAt(0));
303+
protected static boolean isFirstIntValue(VectorDataLibrary dataLib, Object data) {
304+
return dataLib.getLength(data) > 0 && isIntValue(dataLib.getDoubleAt(data, 0));
297305
}
298306
}
299307
}

com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/access/vector/CachedExtractVectorNode.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,8 @@ final class CachedExtractVectorNode extends CachedVectorNode {
119119
Object[] convertedPositions = filterPositions(positions);
120120
this.extractNames = new CachedExtractVectorNode[convertedPositions.length];
121121
this.extractNamesAlternative = new CachedExtractVectorNode[convertedPositions.length];
122-
this.exact = logicalAsBoolean(exact, DEFAULT_EXACT);
123-
this.dropDimensions = logicalAsBoolean(dropDimensions, DEFAULT_DROP_DIMENSION);
122+
this.exact = logicalAsBoolean(VectorDataLibrary.getFactory().getUncached(), exact, DEFAULT_EXACT);
123+
this.dropDimensions = logicalAsBoolean(VectorDataLibrary.getFactory().getUncached(), dropDimensions, DEFAULT_DROP_DIMENSION);
124124
this.positionsCheckNode = new PositionsCheckNode(mode, vectorType, convertedPositions, this.exact, false, recursive);
125125
this.writeVectorNode = WriteIndexedVectorNode.create(vectorType, convertedPositions.length, false, false);
126126
this.droppedDimensionProfile = this.dropDimensions ? ConditionProfile.createBinaryProfile() : null;
@@ -130,8 +130,8 @@ final class CachedExtractVectorNode extends CachedVectorNode {
130130

131131
public boolean isSupported(Object target, Object[] positions, Object exactValue, Object dropDimensionsValue) {
132132
if (targetClass == target.getClass() && getDataClass(target) == targetDataClass && exactValue.getClass() == this.exactClass && dropDimensionsValue.getClass() == dropDimensionsClass //
133-
&& logicalAsBoolean(dropDimensionsClass.cast(dropDimensionsValue), DEFAULT_DROP_DIMENSION) == this.dropDimensions //
134-
&& logicalAsBoolean(exactClass.cast(exactValue), DEFAULT_EXACT) == this.exact) {
133+
&& logicalAsBoolean(getAsLogicalVectorDataLib(), dropDimensionsClass.cast(dropDimensionsValue), DEFAULT_DROP_DIMENSION) == this.dropDimensions //
134+
&& logicalAsBoolean(getAsLogicalVectorDataLib(), exactClass.cast(exactValue), DEFAULT_EXACT) == this.exact) {
135135
return positionsCheckNode.isSupported(positions);
136136
}
137137
return false;

com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/access/vector/CachedVectorNode.java

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@
2222
*/
2323
package com.oracle.truffle.r.nodes.access.vector;
2424

25+
import com.oracle.truffle.api.CompilerDirectives;
26+
import com.oracle.truffle.r.runtime.DSLConfig;
27+
import com.oracle.truffle.r.runtime.data.VectorDataLibrary;
2528
import com.oracle.truffle.r.runtime.data.nodes.attributes.SpecialAttributesFunctions.GetDimAttributeNode;
2629
import com.oracle.truffle.r.runtime.RRuntime;
2730
import com.oracle.truffle.r.runtime.RType;
@@ -54,6 +57,7 @@ abstract class CachedVectorNode extends RBaseNodeWithWarnings {
5457
private final int filteredPositionsLength;
5558

5659
@Child private GetDimAttributeNode getDimNode = GetDimAttributeNode.create();
60+
@Child private VectorDataLibrary asLogicalVectorDataLib;
5761

5862
CachedVectorNode(ElementAccessMode mode, RBaseObject vector, Object[] positions, boolean recursive) {
5963
this.mode = mode;
@@ -113,19 +117,28 @@ private static boolean isRemovePosition(Object position) {
113117
return position instanceof RMissing;
114118
}
115119

116-
protected static boolean logicalAsBoolean(RBaseObject cast, boolean defaultValue) {
120+
protected static boolean logicalAsBoolean(VectorDataLibrary dataLib, RBaseObject cast, boolean defaultValue) {
117121
if (cast instanceof RMissing) {
118122
return defaultValue;
119123
} else {
120124
RLogicalVector logical = (RLogicalVector) cast;
121-
if (logical.getLength() == 0) {
125+
Object data = logical.getData();
126+
if (dataLib.getLength(data) == 0) {
122127
return defaultValue;
123128
} else {
124-
return RRuntime.fromLogical(logical.getDataAt(0));
129+
return RRuntime.fromLogical(dataLib.getLogicalAt(data, 0));
125130
}
126131
}
127132
}
128133

134+
protected VectorDataLibrary getAsLogicalVectorDataLib() {
135+
if (asLogicalVectorDataLib == null) {
136+
CompilerDirectives.transferToInterpreterAndInvalidate();
137+
asLogicalVectorDataLib = insert(VectorDataLibrary.getFactory().createDispatched(DSLConfig.getTypedVectorDataLibraryCacheSize()));
138+
}
139+
return asLogicalVectorDataLib;
140+
}
141+
129142
protected final int[] loadVectorDimensions(RAbstractContainer vector) {
130143
// N.B. (stepan) this method used to be instance method and have special handling for
131144
// RDataFrame, which was removed and any test case, which would require this special

com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/binary/BinaryBooleanScalarNode.java

Lines changed: 21 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@
2222
*/
2323
package com.oracle.truffle.r.nodes.binary;
2424

25-
import com.oracle.truffle.api.CompilerAsserts;
26-
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
2725
import com.oracle.truffle.api.dsl.Cached;
2826
import com.oracle.truffle.api.dsl.Fallback;
2927
import com.oracle.truffle.api.dsl.ImportStatic;
@@ -41,8 +39,7 @@
4139
import com.oracle.truffle.r.runtime.RError;
4240
import com.oracle.truffle.r.runtime.RInternalError;
4341
import com.oracle.truffle.r.runtime.RRuntime;
44-
import com.oracle.truffle.r.runtime.RType;
45-
import com.oracle.truffle.r.runtime.data.RComplex;
42+
import com.oracle.truffle.r.runtime.data.VectorDataLibrary;
4643
import com.oracle.truffle.r.runtime.data.model.RAbstractVector;
4744
import com.oracle.truffle.r.runtime.interop.ConvertForeignObjectNode;
4845
import com.oracle.truffle.r.runtime.nodes.RBaseNode;
@@ -92,65 +89,41 @@ protected byte binary(VirtualFrame frame, Object leftValue, Object rightValue) {
9289
@ImportStatic({RRuntime.class, DSLConfig.class, ConvertForeignObjectNode.class})
9390
protected abstract static class LogicalScalarCastNode extends RBaseNode {
9491

95-
protected static final int CACHE_LIMIT = 3;
96-
9792
public abstract byte executeCast(Object o);
9893

9994
private final String opName;
10095
private final String argumentName;
10196

10297
private final NACheck check;
103-
private final BranchProfile seenEmpty = BranchProfile.create();
10498

10599
LogicalScalarCastNode(String opName, String argumentName, NACheck check) {
106100
this.opName = opName;
107101
this.argumentName = argumentName;
108102
this.check = check;
109103
}
110104

111-
@Specialization(limit = "getCacheSize(CACHE_LIMIT)", guards = {"cachedClass != null", "operand.getClass() == cachedClass"})
112-
protected byte doCached(Object operand,
113-
@Cached("getNumericVectorClass(operand)") Class<? extends RAbstractVector> cachedClass) {
114-
return castImpl(cachedClass.cast(operand));
115-
}
116-
117-
@Specialization(replaces = "doCached", guards = {"operand.getRType().isNumeric()"})
118-
@TruffleBoundary
119-
protected byte doGeneric(RAbstractVector operand) {
120-
return castImpl(operand);
121-
}
122-
123-
private byte castImpl(RAbstractVector vector) {
124-
if (vector.getLength() == 0) {
125-
seenEmpty.enter();
126-
this.check.enable(true);
127-
return RRuntime.LOGICAL_NA;
128-
}
129-
this.check.enable(!vector.isComplete());
130-
RType type = vector.getRType();
131-
CompilerAsserts.compilationConstant(type);
132-
switch (type) {
133-
case Logical:
134-
return (byte) vector.getDataAtAsObject(0);
135-
case Integer:
136-
return check.convertIntToLogical((int) vector.getDataAtAsObject(0));
137-
case Double:
138-
return check.convertDoubleToLogical((double) vector.getDataAtAsObject(0));
139-
case Complex:
140-
return check.convertComplexToLogical((RComplex) vector.getDataAtAsObject(0));
141-
default:
142-
throw RInternalError.shouldNotReachHere();
143-
}
144-
}
145-
146-
protected static Class<? extends RAbstractVector> getNumericVectorClass(Object value) {
147-
if (value instanceof RAbstractVector) {
148-
RAbstractVector castVector = (RAbstractVector) value;
149-
if (castVector.getRType().isNumeric()) {
150-
return castVector.getClass();
105+
@Specialization(limit = "getGenericDataLibraryCacheSize()")
106+
protected byte doCached(RAbstractVector operand,
107+
@Cached BranchProfile isNotNumeric,
108+
@Cached BranchProfile seenEmpty,
109+
@CachedLibrary("operand.getData()") VectorDataLibrary dataLib) {
110+
Object data = operand.getData();
111+
if (!dataLib.getType(data).isNumeric()) {
112+
isNotNumeric.enter();
113+
doFallback(operand);
114+
throw RInternalError.shouldNotReachHere();
115+
} else {
116+
if (dataLib.getLength(data) == 0) {
117+
seenEmpty.enter();
118+
check.enable(true);
119+
check.check(RRuntime.LOGICAL_NA);
120+
return RRuntime.LOGICAL_NA;
151121
}
122+
check.enable(!dataLib.isComplete(data));
123+
byte result = dataLib.getLogicalAt(data, 0);
124+
check.check(result); // we need to update the neverSeenNA flag
125+
return result;
152126
}
153-
return null;
154127
}
155128

156129
@Specialization(guards = {"isForeignArray(operand, interop)"}, limit = "getInteropLibraryCacheSize()")

0 commit comments

Comments
 (0)