Skip to content

Commit 55e8ece

Browse files
Palezsteve-s
authored andcommitted
Implemented the rapply builtin function, and added some tests for it. Updated InheritsCheckNode to be able to create ClassHierarchyNode that contains implicit classes.
1 parent bed01e3 commit 55e8ece

File tree

12 files changed

+600
-10
lines changed

12 files changed

+600
-10
lines changed

com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/BasePackage.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -629,6 +629,7 @@ public BasePackage() {
629629
add(LaFunctions.LaLibrary.class, LaFunctionsFactory.LaLibraryNodeGen::create);
630630
add(LaFunctions.Backsolve.class, LaFunctionsFactory.BacksolveNodeGen::create);
631631
add(Lapply.class, LapplyNodeGen::create);
632+
add(Rapply.class, RapplyNodeGen::create);
632633
add(Length.class, LengthNodeGen::create);
633634
add(Lengths.class, LengthsNodeGen::create);
634635
add(License.class, LicenseNodeGen::create);

com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/IsTypeFunctions.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,7 @@ public abstract static class IsNumeric extends RBuiltinNode.Arg1 {
429429
createCasts(IsNumeric.class);
430430
}
431431

432-
@Child private InheritsCheckNode inheritsCheck = new InheritsCheckNode(RRuntime.CLASS_FACTOR);
432+
@Child private InheritsCheckNode inheritsCheck = InheritsCheckNode.create(RRuntime.CLASS_FACTOR);
433433

434434
protected boolean isFactor(Object o) {
435435
return inheritsCheck.execute(o);

com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Match.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ protected RIntVector match(RAbstractVector x, RNull table, int nomatch, Object i
110110
return RDataFactory.createIntVector(data, na.profile(!RRuntime.isNA(nomatch)));
111111
}
112112

113-
@Child private InheritsCheckNode factorInheritsCheck = new InheritsCheckNode(RRuntime.CLASS_FACTOR);
113+
@Child private InheritsCheckNode factorInheritsCheck = InheritsCheckNode.create(RRuntime.CLASS_FACTOR);
114114

115115
protected boolean isFactor(Object o) {
116116
return factorInheritsCheck.execute(o);
Lines changed: 260 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,260 @@
1+
/*
2+
* Copyright (c) 1995-2015, The R Core Team
3+
* Copyright (c) 2016, 2018, Oracle and/or its affiliates
4+
*
5+
* This program is free software; you can redistribute it and/or modify
6+
* it under the terms of the GNU General Public License as published by
7+
* the Free Software Foundation; either version 2 of the License, or
8+
* (at your option) any later version.
9+
*
10+
* This program is distributed in the hope that it will be useful,
11+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
12+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13+
* GNU General Public License for more details.
14+
*
15+
* You should have received a copy of the GNU General Public License
16+
* along with this program; if not, a copy is available at
17+
* https://www.R-project.org/Licenses/
18+
*/
19+
package com.oracle.truffle.r.nodes.builtin.base;
20+
21+
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.anyValue;
22+
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.constant;
23+
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.nullConstant;
24+
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.stringValue;
25+
import static com.oracle.truffle.r.nodes.builtin.base.Lapply.createCallSourceSection;
26+
import static com.oracle.truffle.r.runtime.builtins.RBehavior.COMPLEX;
27+
import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.INTERNAL;
28+
29+
import com.oracle.truffle.api.CompilerAsserts;
30+
import com.oracle.truffle.api.CompilerDirectives;
31+
import com.oracle.truffle.api.dsl.Cached;
32+
import com.oracle.truffle.api.dsl.Specialization;
33+
import com.oracle.truffle.api.frame.Frame;
34+
import com.oracle.truffle.api.frame.FrameSlot;
35+
import com.oracle.truffle.api.frame.FrameSlotKind;
36+
import com.oracle.truffle.api.frame.FrameSlotTypeException;
37+
import com.oracle.truffle.api.frame.VirtualFrame;
38+
import com.oracle.truffle.api.profiles.LoopConditionProfile;
39+
import com.oracle.truffle.r.nodes.access.variables.ReadVariableNode;
40+
import com.oracle.truffle.r.nodes.access.vector.ElementAccessMode;
41+
import com.oracle.truffle.r.nodes.access.vector.ExtractVectorNode;
42+
import com.oracle.truffle.r.nodes.access.vector.ExtractVectorNodeGen;
43+
import com.oracle.truffle.r.nodes.attributes.UnaryCopyAttributesNode;
44+
import com.oracle.truffle.r.nodes.builtin.RBuiltinNode;
45+
import com.oracle.truffle.r.nodes.builtin.base.RapplyNodeGen.RapplyInternalNodeGen;
46+
import com.oracle.truffle.r.nodes.control.RLengthNode;
47+
import com.oracle.truffle.r.nodes.function.RCallBaseNode;
48+
import com.oracle.truffle.r.nodes.function.RCallNode;
49+
import com.oracle.truffle.r.nodes.unary.InheritsNode;
50+
import com.oracle.truffle.r.nodes.unary.InheritsNodeGen;
51+
import com.oracle.truffle.r.runtime.ArgumentsSignature;
52+
import com.oracle.truffle.r.runtime.RArguments;
53+
import com.oracle.truffle.r.runtime.RError.Message;
54+
import com.oracle.truffle.r.runtime.RInternalError;
55+
import com.oracle.truffle.r.runtime.RRuntime;
56+
import com.oracle.truffle.r.runtime.RType;
57+
import com.oracle.truffle.r.runtime.builtins.RBuiltin;
58+
import com.oracle.truffle.r.runtime.context.RContext;
59+
import com.oracle.truffle.r.runtime.data.RDataFactory;
60+
import com.oracle.truffle.r.runtime.data.RFunction;
61+
import com.oracle.truffle.r.runtime.data.RList;
62+
import com.oracle.truffle.r.runtime.data.RListBase;
63+
import com.oracle.truffle.r.runtime.data.model.RAbstractListVector;
64+
import com.oracle.truffle.r.runtime.env.frame.FrameSlotChangeMonitor;
65+
import com.oracle.truffle.r.runtime.nodes.InternalRSyntaxNodeChildren;
66+
import com.oracle.truffle.r.runtime.nodes.RBaseNode;
67+
import com.oracle.truffle.r.runtime.nodes.RNode;
68+
import com.oracle.truffle.r.runtime.nodes.RSourceSectionNode;
69+
import com.oracle.truffle.r.runtime.nodes.RSyntaxCall;
70+
import com.oracle.truffle.r.runtime.nodes.RSyntaxElement;
71+
import com.oracle.truffle.r.runtime.nodes.RSyntaxLookup;
72+
import com.oracle.truffle.r.runtime.nodes.RSyntaxNode;
73+
74+
@RBuiltin(name = "rapply", kind = INTERNAL, parameterNames = {"object", "f", "classes", "deflt", "how"}, splitCaller = true, behavior = COMPLEX)
75+
public abstract class Rapply extends RBuiltinNode.Arg5 {
76+
77+
@Child private RapplyInternalNode rapply = RapplyInternalNode.create();
78+
79+
static {
80+
Casts casts = new Casts(Rapply.class);
81+
casts.arg("object").mustBe(RAbstractListVector.class, Message.GENERIC, "'object' must be a list");
82+
casts.arg("f").mustBe(RFunction.class);
83+
casts.arg("classes").mapNull(constant("ANY")).mapMissing(constant("ANY")).mustBe(stringValue()).asStringVector().findFirst().mustNotBeNA();
84+
casts.arg("deflt").allowNull().mapMissing(nullConstant()).mustBe(anyValue());
85+
casts.arg("how").mapNull(constant("unlist")).mapMissing(constant("unlist")).mustBe(stringValue()).asStringVector().findFirst().mustNotBeNA();
86+
}
87+
88+
@Specialization(guards = "!isReplace(how)")
89+
protected Object rapplyReplace(VirtualFrame frame, RAbstractListVector object, RFunction f, String classes, Object deflt, String how, @Cached("create()") UnaryCopyAttributesNode attri) {
90+
91+
return attri.execute(RDataFactory.createList((Object[]) rapply.execute(frame, object, f, classes, deflt, how)), object);
92+
}
93+
94+
@Specialization(guards = "isReplace(how)")
95+
protected Object rapply(VirtualFrame frame, RAbstractListVector object, RFunction f, String classes, Object deflt, String how) {
96+
97+
return rapply.execute(frame, object, f, classes, deflt, how);
98+
}
99+
100+
protected static boolean isReplace(String how) {
101+
return RapplyInternalNode.isReplace(how);
102+
}
103+
104+
private static final class ExtractElementInternal extends RSourceSectionNode implements RSyntaxCall {
105+
106+
@Child private ExtractVectorNode extractElementNode = ExtractVectorNodeGen.create(ElementAccessMode.SUBSCRIPT, false);
107+
private final FrameSlot vectorSlot;
108+
private final FrameSlot indexSlot;
109+
110+
protected ExtractElementInternal(FrameSlot vectorSlot, FrameSlot indexSlot) {
111+
super(RSyntaxNode.LAZY_DEPARSE);
112+
this.vectorSlot = vectorSlot;
113+
this.indexSlot = indexSlot;
114+
}
115+
116+
@Override
117+
public Object execute(VirtualFrame frame) {
118+
RArguments.getCall(frame);
119+
try {
120+
return extractElementNode.apply(FrameSlotChangeMonitor.getObject(vectorSlot, frame), new Object[]{frame.getInt(indexSlot)}, RRuntime.LOGICAL_TRUE, RRuntime.LOGICAL_TRUE);
121+
} catch (FrameSlotTypeException e) {
122+
CompilerDirectives.transferToInterpreter();
123+
throw RInternalError.shouldNotReachHere("frame type mismatch in rapply");
124+
}
125+
}
126+
127+
@Override
128+
public RSyntaxElement getSyntaxLHS() {
129+
return RSyntaxLookup.createDummyLookup(LAZY_DEPARSE, "list", true);
130+
}
131+
132+
@Override
133+
public ArgumentsSignature getSyntaxSignature() {
134+
return ArgumentsSignature.empty(2);
135+
}
136+
137+
@Override
138+
public RSyntaxElement[] getSyntaxArguments() {
139+
return new RSyntaxElement[]{RSyntaxLookup.createDummyLookup(LAZY_DEPARSE, "object", false), RSyntaxLookup.createDummyLookup(LAZY_DEPARSE, "i", false)};
140+
}
141+
}
142+
143+
public abstract static class RapplyInternalNode extends RBaseNode implements InternalRSyntaxNodeChildren {
144+
145+
@Child private InheritsNode inheritsNode = InheritsNodeGen.create();
146+
@Child private RapplyInternalNode rapply;
147+
148+
protected static final String VECTOR_NAME = "object";
149+
protected static final String INDEX_NAME = "i";
150+
151+
public abstract Object execute(VirtualFrame frame, RAbstractListVector object, RFunction f, String classes, Object deflt, String how);
152+
153+
protected static FrameSlot createIndexSlot(Frame frame) {
154+
return FrameSlotChangeMonitor.findOrAddFrameSlot(frame.getFrameDescriptor(), INDEX_NAME, FrameSlotKind.Int);
155+
}
156+
157+
protected static FrameSlot createVectorSlot(Frame frame) {
158+
return FrameSlotChangeMonitor.findOrAddFrameSlot(frame.getFrameDescriptor(), VECTOR_NAME, FrameSlotKind.Object);
159+
}
160+
161+
@Specialization(guards = "isReplace(how)")
162+
protected RListBase cachedLapplyReplace(VirtualFrame frame, RAbstractListVector object, RFunction f, String classes, Object deflt, String how,
163+
@Cached("createIndexSlot(frame)") FrameSlot indexSlot,
164+
@Cached("createVectorSlot(frame)") FrameSlot vectorSlot,
165+
@Cached("create()") RLengthNode lengthNode,
166+
@Cached("createCountingProfile()") LoopConditionProfile loop,
167+
@Cached("createCallNode(vectorSlot, indexSlot)") RCallBaseNode callNode) {
168+
169+
int length = lengthNode.executeInteger(object);
170+
RListBase result = (RListBase) object.copy();
171+
FrameSlotChangeMonitor.setObject(frame, vectorSlot, object);
172+
173+
if (length > 0) {
174+
reportWork(this, length);
175+
loop.profileCounted(length);
176+
for (int i = 0; loop.inject(i < length); i++) {
177+
frame.setInt(indexSlot, i + 1);
178+
Object element = object.getDataAt(i);
179+
if (element instanceof RAbstractListVector) {
180+
result.setDataAt(i, getRapply().execute(frame, (RAbstractListVector) element, f, classes, deflt, how));
181+
FrameSlotChangeMonitor.setObject(frame, vectorSlot, object);
182+
} else if (isRNull(element)) {
183+
result.setDataAt(i, element);
184+
} else if (classes.equals("ANY") || inheritsNode.execute(element, RDataFactory.createStringVector(classes), false).equals(RRuntime.LOGICAL_TRUE)) {
185+
result.setDataAt(i, callNode.execute(frame, f));
186+
} else {
187+
result.setDataAt(i, element);
188+
}
189+
}
190+
}
191+
return result;
192+
}
193+
194+
private RapplyInternalNode getRapply() {
195+
if (rapply == null) {
196+
CompilerDirectives.transferToInterpreterAndInvalidate();
197+
rapply = insert(RapplyInternalNodeGen.create());
198+
}
199+
return rapply;
200+
}
201+
202+
@Specialization(guards = "!isReplace(how)")
203+
protected Object[] cachedLapply(VirtualFrame frame, RAbstractListVector object, RFunction f, String classes, Object deflt, String how,
204+
@Cached("createIndexSlot(frame)") FrameSlot indexSlot,
205+
@Cached("createVectorSlot(frame)") FrameSlot vectorSlot,
206+
@Cached("create()") RLengthNode lengthNode,
207+
@Cached("create()") UnaryCopyAttributesNode attri,
208+
@Cached("createCountingProfile()") LoopConditionProfile loop,
209+
@Cached("createCallNode(vectorSlot, indexSlot)") RCallBaseNode callNode) {
210+
211+
int length = lengthNode.executeInteger(object);
212+
Object[] result = new Object[length];
213+
FrameSlotChangeMonitor.setObject(frame, vectorSlot, object);
214+
215+
if (length > 0) {
216+
reportWork(this, length);
217+
loop.profileCounted(length);
218+
for (int i = 0; loop.inject(i < length); i++) {
219+
frame.setInt(indexSlot, i + 1);
220+
Object element = object.getDataAt(i);
221+
if (element instanceof RAbstractListVector) {
222+
RList newlist = RDataFactory.createList((Object[]) getRapply().execute(frame, (RAbstractListVector) element, f, classes, deflt, how));
223+
attri.execute(newlist, (RAbstractListVector) element);
224+
result[i] = newlist;
225+
FrameSlotChangeMonitor.setObject(frame, vectorSlot, object);
226+
} else if (isRNull(element)) {
227+
result[i] = RDataFactory.createList();
228+
} else if (classes.equals("ANY") || inheritsNode.execute(element, RDataFactory.createStringVector(classes), false).equals(RRuntime.LOGICAL_TRUE)) {
229+
result[i] = callNode.execute(frame, f);
230+
} else {
231+
result[i] = deflt;
232+
}
233+
}
234+
}
235+
return result;
236+
}
237+
238+
protected RCallBaseNode createCallNode(FrameSlot vectorSlot, FrameSlot indexSlot) {
239+
CompilerAsserts.neverPartOfCompilation();
240+
241+
ExtractElementInternal element = new ExtractElementInternal(vectorSlot, indexSlot);
242+
RSyntaxNode readArgs = ReadVariableNode.wrap(RSyntaxNode.LAZY_DEPARSE, ReadVariableNode.createSilent(ArgumentsSignature.VARARG_NAME, RType.Any));
243+
RNode function = RContext.getASTBuilder().lookup(RSyntaxNode.LAZY_DEPARSE, "f", false).asRNode();
244+
245+
return RCallNode.createCall(createCallSourceSection(), function, ArgumentsSignature.get(null, "..."), element, readArgs);
246+
}
247+
248+
protected static boolean isReplace(String how) {
249+
return how.equals("replace");
250+
}
251+
252+
public static RapplyInternalNode create() {
253+
return RapplyInternalNodeGen.create();
254+
}
255+
}
256+
257+
public static Rapply create() {
258+
return RapplyNodeGen.create();
259+
}
260+
}

com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/helpers/InheritsCheckNode.java

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
import com.oracle.truffle.api.nodes.Node;
2929
import com.oracle.truffle.api.profiles.ConditionProfile;
3030
import com.oracle.truffle.r.nodes.function.ClassHierarchyNode;
31-
import com.oracle.truffle.r.nodes.function.ClassHierarchyNodeGen;
3231
import com.oracle.truffle.r.runtime.RRuntime;
3332
import com.oracle.truffle.r.runtime.RType;
3433
import com.oracle.truffle.r.runtime.data.RMissing;
@@ -40,18 +39,38 @@
4039
*/
4140
public final class InheritsCheckNode extends Node {
4241

43-
@Child private ClassHierarchyNode classHierarchy = ClassHierarchyNodeGen.create(false, false);
42+
@Child private ClassHierarchyNode classHierarchy;
4443
private final ConditionProfile nullClassProfile = ConditionProfile.createBinaryProfile();
4544
@CompilationFinal private ConditionProfile exactMatchProfile;
4645
private final String checkedClazz;
4746

48-
public InheritsCheckNode(String checkedClazz) {
47+
private InheritsCheckNode(String checkedClazz) {
4948
this.checkedClazz = checkedClazz;
5049
assert RType.fromMode(checkedClazz) == null : "Class '" + checkedClazz + "' cannot be checked by InheritsCheckNode";
5150
}
5251

52+
private InheritsCheckNode(String checkedClazz, boolean withImplicit) {
53+
if (withImplicit) {
54+
classHierarchy = ClassHierarchyNode.createWithImplicit();
55+
} else {
56+
classHierarchy = ClassHierarchyNode.create();
57+
}
58+
this.checkedClazz = checkedClazz;
59+
if (!withImplicit) {
60+
assert RType.fromMode(checkedClazz) == null : "Class '" + checkedClazz + "' cannot be checked by InheritsCheckNode";
61+
}
62+
}
63+
64+
public static InheritsCheckNode create(String checkedClazz) {
65+
return new InheritsCheckNode(checkedClazz, false);
66+
}
67+
68+
public static InheritsCheckNode createWithImplicit(String checkedClazz) {
69+
return new InheritsCheckNode(checkedClazz, true);
70+
}
71+
5372
public static InheritsCheckNode createFactor() {
54-
return new InheritsCheckNode(RRuntime.CLASS_FACTOR);
73+
return new InheritsCheckNode(RRuntime.CLASS_FACTOR, false);
5574
}
5675

5776
public boolean execute(Object value) {

com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/objects/DispatchGeneric.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ protected boolean equalClasses(RStringVector classes, RStringVector cachedClasse
134134
private InheritsCheckNode getInheritsInternalDispatchCheckNode() {
135135
if (inheritsInternalDispatchCheckNode == null) {
136136
CompilerDirectives.transferToInterpreterAndInvalidate();
137-
inheritsInternalDispatchCheckNode = insert(new InheritsCheckNode("internalDispatchMethod"));
137+
inheritsInternalDispatchCheckNode = insert(InheritsCheckNode.create("internalDispatchMethod"));
138138
}
139139
return inheritsInternalDispatchCheckNode;
140140
}

com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/CastLogicalNode.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ public static CastLogicalNode createNonPreserving() {
211211
protected boolean isFactor(RAbstractIntVector o) {
212212
if (inheritsFactorCheck == null) {
213213
CompilerDirectives.transferToInterpreterAndInvalidate();
214-
inheritsFactorCheck = insert(new InheritsCheckNode(RRuntime.CLASS_FACTOR));
214+
inheritsFactorCheck = insert(InheritsCheckNode.create(RRuntime.CLASS_FACTOR));
215215
}
216216
return inheritsFactorCheck.execute(o);
217217
}

com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/IsFactorNode.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ public boolean executeIsFactor(Object x) {
4343
}
4444
if (inheritsCheck == null) {
4545
CompilerDirectives.transferToInterpreterAndInvalidate();
46-
inheritsCheck = insert(new InheritsCheckNode(RRuntime.CLASS_FACTOR));
46+
inheritsCheck = insert(InheritsCheckNode.create(RRuntime.CLASS_FACTOR));
4747
}
4848

4949
return inheritsCheck.execute(x);

0 commit comments

Comments
 (0)