Skip to content

Commit 2240154

Browse files
committed
[GR-22670] Rewrite to use VectorDataLibrary: Quantifier, one more case in CachedExtractVectorNode
(cherry picked from commit 24b600a)
1 parent 7f05e8a commit 2240154

File tree

3 files changed

+56
-29
lines changed

3 files changed

+56
-29
lines changed

com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Quantifier.java

Lines changed: 53 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,16 @@
3333
import com.oracle.truffle.api.CompilerDirectives;
3434
import com.oracle.truffle.api.dsl.Cached;
3535
import com.oracle.truffle.api.dsl.Specialization;
36+
import com.oracle.truffle.api.library.CachedLibrary;
3637
import com.oracle.truffle.api.nodes.ExplodeLoop;
3738
import com.oracle.truffle.api.profiles.BranchProfile;
3839
import com.oracle.truffle.api.profiles.ConditionProfile;
3940
import com.oracle.truffle.api.profiles.ValueProfile;
4041
import com.oracle.truffle.r.nodes.builtin.RBuiltinNode;
42+
import com.oracle.truffle.r.nodes.builtin.base.QuantifierNodeGen.ProcessArgumentNodeGen;
43+
import com.oracle.truffle.r.runtime.data.VectorDataLibrary;
44+
import com.oracle.truffle.r.runtime.data.VectorDataLibrary.SeqIterator;
45+
import com.oracle.truffle.r.runtime.nodes.RBaseNode;
4146
import com.oracle.truffle.r.runtime.nodes.unary.CastNode;
4247
import com.oracle.truffle.r.runtime.DSLConfig;
4348
import com.oracle.truffle.r.runtime.RError;
@@ -48,16 +53,13 @@
4853
import com.oracle.truffle.r.runtime.data.RNull;
4954
import com.oracle.truffle.r.runtime.data.RLogicalVector;
5055
import com.oracle.truffle.r.runtime.data.model.RAbstractVector;
51-
import com.oracle.truffle.r.runtime.ops.na.NACheck;
5256

5357
public abstract class Quantifier extends RBuiltinNode.Arg2 {
5458
protected static final int MAX_CACHED_LENGTH = 10;
5559

56-
private final NACheck naCheck = NACheck.create();
57-
private final BranchProfile trueBranch = BranchProfile.create();
58-
private final BranchProfile falseBranch = BranchProfile.create();
59-
6060
@Children private final CastNode[] argCastNodes = new CastNode[Math.max(1, DSLConfig.getCacheSize(MAX_CACHED_LENGTH))];
61+
@Children private final ProcessArgumentNode[] processArgumentNodes = new ProcessArgumentNode[Math.max(1, DSLConfig.getCacheSize(MAX_CACHED_LENGTH))];
62+
private final BranchProfile nullBranch = BranchProfile.create();
6163

6264
private static final class ProfileCastNode extends CastNode {
6365

@@ -141,37 +143,61 @@ protected byte op(RArgsValuesAndNames args, boolean naRm,
141143
}
142144

143145
private byte processArgument(Object argValue, int index, boolean naRm) {
144-
byte result = RRuntime.asLogical(emptyVectorResult());
145146
if (argValue != RNull.instance) {
146147
if (argCastNodes[index] == null) {
147148
CompilerDirectives.transferToInterpreterAndInvalidate();
148149
createArgCast(index);
149150
}
150151
Object castValue = argCastNodes[index].doCast(argValue);
151-
if (castValue instanceof RLogicalVector) {
152-
RLogicalVector vector = (RLogicalVector) castValue;
153-
naCheck.enable(vector);
154-
for (int i = 0; i < vector.getLength(); i++) {
155-
byte b = vector.getDataAt(i);
156-
if (!naRm && naCheck.check(b)) {
157-
result = RRuntime.LOGICAL_NA;
158-
} else if (b == RRuntime.asLogical(!emptyVectorResult())) {
159-
trueBranch.enter();
160-
return RRuntime.asLogical(!emptyVectorResult());
161-
}
162-
}
163-
} else {
164-
byte b = (byte) castValue;
165-
naCheck.enable(true);
166-
if (!naRm && naCheck.check(b)) {
152+
if (processArgumentNodes[index] == null) {
153+
CompilerDirectives.transferToInterpreterAndInvalidate();
154+
processArgumentNodes[index] = insert(ProcessArgumentNodeGen.create(this));
155+
}
156+
return processArgumentNodes[index].execute(castValue, naRm);
157+
}
158+
nullBranch.enter();
159+
return RRuntime.asLogical(emptyVectorResult());
160+
}
161+
162+
abstract static class ProcessArgumentNode extends RBaseNode {
163+
private final Quantifier parent;
164+
165+
protected ProcessArgumentNode(Quantifier parent) {
166+
this.parent = parent;
167+
}
168+
169+
public abstract byte execute(Object vector, boolean naRm);
170+
171+
@Specialization(limit = "getTypedVectorDataLibraryCacheSize()")
172+
byte doVector(RLogicalVector vector, boolean naRm,
173+
@CachedLibrary("vector.getData()") VectorDataLibrary argDataLib) {
174+
Object data = vector.getData();
175+
SeqIterator it = argDataLib.iterator(data);
176+
byte result = RRuntime.asLogical(parent.emptyVectorResult());
177+
while (argDataLib.next(data, it)) {
178+
byte b = argDataLib.getNextLogical(data, it);
179+
if (!naRm && argDataLib.getNACheck(data).check(b)) {
167180
result = RRuntime.LOGICAL_NA;
168-
} else if (b == RRuntime.asLogical(!emptyVectorResult())) {
169-
trueBranch.enter();
170-
return RRuntime.asLogical(!emptyVectorResult());
181+
} else if (b == RRuntime.asLogical(!parent.emptyVectorResult())) {
182+
return RRuntime.asLogical(!parent.emptyVectorResult());
171183
}
172184
}
185+
return result;
186+
}
187+
188+
@Specialization
189+
byte doSingleByte(byte value, boolean naRm,
190+
@Cached BranchProfile isNAProfile,
191+
@Cached BranchProfile trueBranchProfile) {
192+
if (!naRm && RRuntime.isNA(value)) {
193+
isNAProfile.enter();
194+
return RRuntime.LOGICAL_NA;
195+
} else if (value == RRuntime.asLogical(!parent.emptyVectorResult())) {
196+
return RRuntime.asLogical(!parent.emptyVectorResult());
197+
} else {
198+
trueBranchProfile.enter();
199+
return RRuntime.asLogical(parent.emptyVectorResult());
200+
}
173201
}
174-
falseBranch.enter();
175-
return result;
176202
}
177203
}

com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/access/vector/CachedExtractVectorNode.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,8 @@ public Object apply(RAbstractContainer originalVector, Object[] originalPosition
216216
writeVectorNode.execute(extractedVector, positions, vector, dimensions);
217217
RBaseNode.reportWork(this, 1);
218218
assert extractedVectorLength == 1;
219-
return extractedVector.getDataAtAsObject(0);
219+
final Object extractedVecData = extractedVector.getData();
220+
return getExtractedVectorDataLib(extractedVecData).getDataAtAsObject(extractedVecData, 0);
220221
}
221222
}
222223

com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/binary/BinaryBooleanScalarNode.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2013, 2019, Oracle and/or its affiliates. All rights reserved.
2+
* Copyright (c) 2013, 2020, Oracle and/or its affiliates. All rights reserved.
33
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
44
*
55
* This code is free software; you can redistribute it and/or modify it

0 commit comments

Comments
 (0)