Skip to content

Commit 4dac084

Browse files
committed
[GR-2737] Minor performance improvements in some regex builtins.
PullRequest: fastr/2056
2 parents b90a06e + 81b5780 commit 4dac084

File tree

3 files changed

+62
-22
lines changed

3 files changed

+62
-22
lines changed

com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/GrepFunctions.java

Lines changed: 35 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,16 @@
3535
import java.util.regex.Pattern;
3636
import java.util.regex.PatternSyntaxException;
3737

38+
import com.oracle.truffle.api.CompilerDirectives;
3839
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
3940
import com.oracle.truffle.api.dsl.Cached;
4041
import com.oracle.truffle.api.dsl.Fallback;
4142
import com.oracle.truffle.api.dsl.ImportStatic;
4243
import com.oracle.truffle.api.dsl.Specialization;
4344
import com.oracle.truffle.api.nodes.NodeCost;
4445
import com.oracle.truffle.api.nodes.NodeInfo;
46+
import com.oracle.truffle.api.profiles.LoopConditionProfile;
47+
import com.oracle.truffle.api.profiles.ValueProfile;
4548
import com.oracle.truffle.r.nodes.attributes.SetFixedAttributeNode;
4649
import com.oracle.truffle.r.nodes.builtin.NodeWithArgumentCasts.Casts;
4750
import com.oracle.truffle.r.nodes.builtin.RBuiltinNode;
@@ -177,6 +180,7 @@ protected void valueCheck(boolean value) {
177180

178181
protected void checkNotImplemented(boolean condition, String arg, boolean b) {
179182
if (condition) {
183+
CompilerDirectives.transferToInterpreter();
180184
throw RError.nyi(this, arg + " == " + b);
181185
}
182186
}
@@ -1063,10 +1067,16 @@ public Info(int index, int size, int[] captureStart, int[] captureLength, String
10631067
}
10641068

10651069
@Specialization
1066-
@TruffleBoundary
1067-
protected Object regexp(RAbstractStringVector patternArg, RAbstractStringVector vector, boolean ignoreCase, boolean fixed, boolean useBytesL,
1070+
protected Object regexp(RAbstractStringVector patternArgIn, RAbstractStringVector vector, boolean ignoreCaseIn, boolean fixedIn, boolean useBytesL,
1071+
@Cached("createClassProfile()") ValueProfile patternProfile,
1072+
@Cached("createIdentityProfile()") ValueProfile fixedProfile,
1073+
@Cached("createIdentityProfile()") ValueProfile ignoreCaseProfile,
1074+
@Cached("createCountingProfile()") LoopConditionProfile loopConditionProfile,
10681075
@Cached("createCommon()") CommonCodeNode common) {
10691076
try {
1077+
RAbstractStringVector patternArg = patternProfile.profile(patternArgIn);
1078+
boolean fixed = fixedProfile.profile(fixedIn);
1079+
boolean ignoreCase = ignoreCaseProfile.profile(ignoreCaseIn);
10701080
common.checkExtraArgs(false, false, false, useBytesL, false);
10711081
if (patternArg.getLength() > 1) {
10721082
throw RInternalError.unimplemented("multi-element patterns in regexpr not implemented yet");
@@ -1079,20 +1089,21 @@ protected Object regexp(RAbstractStringVector patternArg, RAbstractStringVector
10791089
boolean useBytes = true;
10801090
String indexType = "chars"; // TODO: normally should be: useBytes ? "bytes" :
10811091
// "chars";
1082-
for (int i = 0; i < vector.getLength(); i++) {
1092+
loopConditionProfile.profileCounted(vector.getLength());
1093+
for (int i = 0; loopConditionProfile.inject(i < vector.getLength()); i++) {
10831094
int[] matchPos;
10841095
int[] matchLength;
10851096
if (pattern.length() == 0) {
10861097
// emtpy pattern
10871098
matchPos = new int[]{1};
10881099
matchLength = new int[]{0};
10891100
} else {
1090-
List<Info> res = getInfo(pattern, vector.getDataAt(i), ignoreCase, fixed);
1091-
matchPos = new int[res.size()];
1092-
matchLength = new int[res.size()];
1093-
for (int j = 0; j < res.size(); j++) {
1094-
matchPos[j] = res.get(j).index;
1095-
matchLength[j] = res.get(j).size;
1101+
Info[] res = getInfo(pattern, vector.getDataAt(i), ignoreCase, fixed);
1102+
matchPos = new int[res.length];
1103+
matchLength = new int[res.length];
1104+
for (int j = 0; j < res.length; j++) {
1105+
matchPos[j] = res[j].index;
1106+
matchLength[j] = res[j].size;
10961107
}
10971108
}
10981109
RIntVector matches = RDataFactory.createIntVector(matchPos, RDataFactory.COMPLETE_VECTOR);
@@ -1105,12 +1116,12 @@ protected Object regexp(RAbstractStringVector patternArg, RAbstractStringVector
11051116
}
11061117
return ret;
11071118
} catch (PatternSyntaxException e) {
1108-
throw error(Message.INVALID_REGEXP_REASON, patternArg, e.getMessage());
1119+
throw error(Message.INVALID_REGEXP_REASON, patternArgIn, e.getMessage());
11091120
}
11101121
}
11111122

1112-
protected List<Info> getInfo(String pattern, String text, boolean ignoreCase, boolean fixed) {
1113-
List<Info> list = new ArrayList<>();
1123+
protected Info[] getInfo(String pattern, String text, boolean ignoreCase, boolean fixed) {
1124+
Info[] result = null;
11141125
if (fixed) {
11151126
int index;
11161127
if (ignoreCase) {
@@ -1119,21 +1130,26 @@ protected List<Info> getInfo(String pattern, String text, boolean ignoreCase, bo
11191130
index = text.indexOf(pattern);
11201131
}
11211132
if (index != -1) {
1122-
list.add(new Info(index + 1, pattern.length(), null, null, null));
1133+
result = new Info[]{new Info(index + 1, pattern.length(), null, null, null)};
11231134
}
11241135
} else {
11251136
Matcher m = getPatternMatcher(pattern, text, ignoreCase);
1126-
if (m.find()) {
1137+
if (find(m)) {
1138+
result = new Info[m.groupCount() + 1];
11271139
for (int i = 0; i <= m.groupCount(); i++) {
1128-
list.add(new Info(m.start(i) + 1, m.end(i) - m.start(i), null, null, null));
1140+
result[i] = new Info(m.start(i) + 1, m.end(i) - m.start(i), null, null, null);
11291141
}
11301142
}
11311143
}
1132-
if (list.size() > 0) {
1133-
return list;
1144+
if (result != null) {
1145+
return result;
11341146
}
1135-
list.add(new Info(-1, -1, null, null, null));
1136-
return list;
1147+
return new Info[]{new Info(-1, -1, null, null, null)};
1148+
}
1149+
1150+
@TruffleBoundary
1151+
private static boolean find(Matcher m) {
1152+
return m.find();
11371153
}
11381154

11391155
@TruffleBoundary

com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/ffi/PCRERFFI.java

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@
2222
*/
2323
package com.oracle.truffle.r.runtime.ffi;
2424

25+
import java.lang.ref.SoftReference;
2526
import java.nio.charset.StandardCharsets;
27+
import java.util.concurrent.atomic.AtomicReference;
2628

2729
import com.oracle.truffle.api.CompilerDirectives;
2830
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
@@ -163,6 +165,18 @@ public static StudyNode create() {
163165
}
164166

165167
public static final class ExecNode extends NativeCallNode {
168+
private static final class BytesCache {
169+
private final String input;
170+
private final byte[] bytes;
171+
172+
BytesCache(String input, byte[] bytes) {
173+
this.input = input;
174+
this.bytes = bytes;
175+
}
176+
}
177+
178+
private AtomicReference<SoftReference<BytesCache>> cachedBytes = new AtomicReference<>();
179+
166180
private ExecNode(DownCallNodeFactory factory) {
167181
super(factory.createDownCallNode(NativeFunction.exec));
168182
}
@@ -174,8 +188,18 @@ public int execute(long code, long extra, String subject, int offset, int option
174188
}
175189

176190
@TruffleBoundary
177-
private static byte[] getBytes(String subject) {
178-
return subject.getBytes(StandardCharsets.UTF_8);
191+
private byte[] getBytes(String subject) {
192+
if (subject.length() <= 32) {
193+
return subject.getBytes(StandardCharsets.UTF_8);
194+
}
195+
SoftReference<BytesCache> cacheRef = cachedBytes.get();
196+
BytesCache cache = cacheRef == null ? null : cacheRef.get();
197+
if (cache != null && cache.input == subject) {
198+
return cache.bytes;
199+
}
200+
byte[] result = subject.getBytes(StandardCharsets.UTF_8);
201+
cachedBytes.set(new SoftReference<>(new BytesCache(subject, result)));
202+
return result;
179203
}
180204

181205
public static ExecNode create() {

mx.fastr/mx_fastr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def do_run_r(args, command, extraVmArgs=None, jdk=None, **kwargs):
9797
vmArgs += _sulong_options()
9898
args = _sulong_args() + args
9999

100-
if extraVmArgs is None or not '-da' in extraVmArgs:
100+
if not "FASTR_NO_ASSERTS" in os.environ and (extraVmArgs is None or not '-da' in extraVmArgs):
101101
# unless explicitly disabled we enable assertion checking
102102
vmArgs += ['-ea', '-esa']
103103

0 commit comments

Comments
 (0)