2424import static com .oracle .truffle .r .runtime .builtins .RBuiltinKind .INTERNAL ;
2525
2626import com .oracle .truffle .api .dsl .Specialization ;
27+ import com .oracle .truffle .api .library .CachedLibrary ;
2728import com .oracle .truffle .r .runtime .RRuntime ;
2829import com .oracle .truffle .r .runtime .builtins .RBuiltin ;
2930import com .oracle .truffle .r .runtime .data .RDataFactory ;
3031import com .oracle .truffle .r .runtime .data .RDoubleVector ;
3132import com .oracle .truffle .r .runtime .data .RIntVector ;
3233import com .oracle .truffle .r .runtime .data .RLogicalVector ;
34+ import com .oracle .truffle .r .runtime .data .VectorDataLibrary ;
35+ import com .oracle .truffle .r .runtime .data .VectorDataLibrary .RandomAccessIterator ;
3336import com .oracle .truffle .r .runtime .ops .BinaryArithmetic ;
37+ import com .oracle .truffle .r .runtime .ops .na .NACheck ;
3438
3539//Implements .colMeans
3640@ RBuiltin (name = "colMeans" , kind = INTERNAL , parameterNames = {"X" , "m" , "n" , "na.rm" }, behavior = PURE )
@@ -42,17 +46,20 @@ public abstract class ColMeans extends ColSumsBase {
4246 createCasts (ColMeans .class );
4347 }
4448
45- @ Specialization (guards = "!naRm" )
46- protected RDoubleVector colMeansNaRmFalse (RDoubleVector x , int rowNum , int colNum , @ SuppressWarnings ("unused" ) boolean naRm ) {
47- checkVectorLength (x , rowNum , colNum );
49+ @ Specialization (guards = "!naRm" , limit = "getTypedVectorDataLibraryCacheSize()" )
50+ protected RDoubleVector colMeansNaRmFalse (RDoubleVector x , int rowNum , int colNum , @ SuppressWarnings ("unused" ) boolean naRm ,
51+ @ CachedLibrary ("x.getData()" ) VectorDataLibrary dataLib ) {
52+ checkVectorLength (dataLib , x , rowNum , colNum );
4853
4954 double [] result = new double [colNum ];
5055 boolean isComplete = true ;
51- na .enable (x );
56+ Object xData = x .getData ();
57+ RandomAccessIterator it = dataLib .randomAccessIterator (xData );
58+ NACheck na = dataLib .getNACheck (xData );
5259 nextCol : for (int c = 0 ; c < colNum ; c ++) {
5360 double sum = 0 ;
5461 for (int i = 0 ; i < rowNum ; i ++) {
55- double el = x . getDataAt ( c * rowNum + i );
62+ double el = dataLib . getDouble ( xData , it , c * rowNum + i );
5663 if (na .check (el )) {
5764 result [c ] = RRuntime .DOUBLE_NA ;
5865 continue nextCol ;
@@ -69,18 +76,21 @@ protected RDoubleVector colMeansNaRmFalse(RDoubleVector x, int rowNum, int colNu
6976 return RDataFactory .createDoubleVector (result , na .neverSeenNA () && isComplete );
7077 }
7178
72- @ Specialization (guards = "naRm" )
73- protected RDoubleVector colMeansNaRmTrue (RDoubleVector x , int rowNum , int colNum , @ SuppressWarnings ("unused" ) boolean naRm ) {
74- checkVectorLength (x , rowNum , colNum );
79+ @ Specialization (guards = "naRm" , limit = "getTypedVectorDataLibraryCacheSize()" )
80+ protected RDoubleVector colMeansNaRmTrue (RDoubleVector x , int rowNum , int colNum , @ SuppressWarnings ("unused" ) boolean naRm ,
81+ @ CachedLibrary ("x.getData()" ) VectorDataLibrary dataLib ) {
82+ checkVectorLength (dataLib , x , rowNum , colNum );
7583
7684 double [] result = new double [colNum ];
7785 boolean isComplete = true ;
78- na .enable (x );
86+ Object xData = x .getData ();
87+ RandomAccessIterator it = dataLib .randomAccessIterator (xData );
88+ NACheck na = dataLib .getNACheck (xData );
7989 for (int c = 0 ; c < colNum ; c ++) {
8090 double sum = 0 ;
8191 int nonNaNumCount = 0 ;
8292 for (int i = 0 ; i < rowNum ; i ++) {
83- double el = x . getDataAt ( c * rowNum + i );
93+ double el = dataLib . getDouble ( xData , it , c * rowNum + i );
8494 if (!na .check (el ) && !Double .isNaN (el )) {
8595 sum = add .op (sum , el );
8696 nonNaNumCount ++;
@@ -96,16 +106,19 @@ protected RDoubleVector colMeansNaRmTrue(RDoubleVector x, int rowNum, int colNum
96106 return RDataFactory .createDoubleVector (result , isComplete );
97107 }
98108
99- @ Specialization (guards = "!naRm" )
100- protected RDoubleVector colMeansNaRmFalse (RLogicalVector x , int rowNum , int colNum , @ SuppressWarnings ("unused" ) boolean naRm ) {
101- checkVectorLength (x , rowNum , colNum );
109+ @ Specialization (guards = "!naRm" , limit = "getTypedVectorDataLibraryCacheSize()" )
110+ protected RDoubleVector colMeansNaRmFalse (RLogicalVector x , int rowNum , int colNum , @ SuppressWarnings ("unused" ) boolean naRm ,
111+ @ CachedLibrary ("x.getData()" ) VectorDataLibrary dataLib ) {
112+ checkVectorLength (dataLib , x , rowNum , colNum );
102113
103114 double [] result = new double [colNum ];
104- na .enable (x );
115+ Object xData = x .getData ();
116+ RandomAccessIterator it = dataLib .randomAccessIterator (xData );
117+ NACheck na = dataLib .getNACheck (xData );
105118 nextCol : for (int c = 0 ; c < colNum ; c ++) {
106119 double sum = 0 ;
107120 for (int i = 0 ; i < rowNum ; i ++) {
108- byte el = x . getDataAt ( c * rowNum + i );
121+ byte el = dataLib . getLogical ( xData , it , c * rowNum + i );
109122 if (na .check (el )) {
110123 result [c ] = RRuntime .DOUBLE_NA ;
111124 continue nextCol ;
@@ -117,18 +130,21 @@ protected RDoubleVector colMeansNaRmFalse(RLogicalVector x, int rowNum, int colN
117130 return RDataFactory .createDoubleVector (result , na .neverSeenNA ());
118131 }
119132
120- @ Specialization (guards = "naRm" )
121- protected RDoubleVector colMeansNaRmTrue (RLogicalVector x , int rowNum , int colNum , @ SuppressWarnings ("unused" ) boolean naRm ) {
122- checkVectorLength (x , rowNum , colNum );
133+ @ Specialization (guards = "naRm" , limit = "getTypedVectorDataLibraryCacheSize()" )
134+ protected RDoubleVector colMeansNaRmTrue (RLogicalVector x , int rowNum , int colNum , @ SuppressWarnings ("unused" ) boolean naRm ,
135+ @ CachedLibrary ("x.getData()" ) VectorDataLibrary dataLib ) {
136+ checkVectorLength (dataLib , x , rowNum , colNum );
123137
124138 double [] result = new double [colNum ];
125139 boolean isComplete = true ;
126- na .enable (x );
140+ Object xData = x .getData ();
141+ RandomAccessIterator it = dataLib .randomAccessIterator (xData );
142+ NACheck na = dataLib .getNACheck (xData );
127143 for (int c = 0 ; c < colNum ; c ++) {
128144 double sum = 0 ;
129145 int nonNaNumCount = 0 ;
130146 for (int i = 0 ; i < rowNum ; i ++) {
131- byte el = x . getDataAt ( c * rowNum + i );
147+ byte el = dataLib . getLogical ( xData , it , c * rowNum + i );
132148 if (!na .check (el )) {
133149 sum = add .op (sum , el );
134150 nonNaNumCount ++;
@@ -144,16 +160,19 @@ protected RDoubleVector colMeansNaRmTrue(RLogicalVector x, int rowNum, int colNu
144160 return RDataFactory .createDoubleVector (result , isComplete );
145161 }
146162
147- @ Specialization (guards = "!naRm" )
148- protected RDoubleVector colMeansNaRmFalse (RIntVector x , int rowNum , int colNum , @ SuppressWarnings ("unused" ) boolean naRm ) {
149- checkVectorLength (x , rowNum , colNum );
163+ @ Specialization (guards = "!naRm" , limit = "getTypedVectorDataLibraryCacheSize()" )
164+ protected RDoubleVector colMeansNaRmFalse (RIntVector x , int rowNum , int colNum , @ SuppressWarnings ("unused" ) boolean naRm ,
165+ @ CachedLibrary ("x.getData()" ) VectorDataLibrary dataLib ) {
166+ checkVectorLength (dataLib , x , rowNum , colNum );
150167
151168 double [] result = new double [colNum ];
152- na .enable (x );
169+ Object xData = x .getData ();
170+ RandomAccessIterator it = dataLib .randomAccessIterator (xData );
171+ NACheck na = dataLib .getNACheck (xData );
153172 nextCol : for (int c = 0 ; c < colNum ; c ++) {
154173 double sum = 0 ;
155174 for (int i = 0 ; i < rowNum ; i ++) {
156- int el = x . getDataAt ( c * rowNum + i );
175+ int el = dataLib . getInt ( xData , it , c * rowNum + i );
157176 if (na .check (el )) {
158177 result [c ] = RRuntime .DOUBLE_NA ;
159178 continue nextCol ;
@@ -165,18 +184,21 @@ protected RDoubleVector colMeansNaRmFalse(RIntVector x, int rowNum, int colNum,
165184 return RDataFactory .createDoubleVector (result , na .neverSeenNA ());
166185 }
167186
168- @ Specialization (guards = "naRm" )
169- protected RDoubleVector colMeansNaRmTrue (RIntVector x , int rowNum , int colNum , @ SuppressWarnings ("unused" ) boolean naRm ) {
170- checkVectorLength (x , rowNum , colNum );
187+ @ Specialization (guards = "naRm" , limit = "getTypedVectorDataLibraryCacheSize()" )
188+ protected RDoubleVector colMeansNaRmTrue (RIntVector x , int rowNum , int colNum , @ SuppressWarnings ("unused" ) boolean naRm ,
189+ @ CachedLibrary ("x.getData()" ) VectorDataLibrary dataLib ) {
190+ checkVectorLength (dataLib , x , rowNum , colNum );
171191
172192 double [] result = new double [colNum ];
173193 boolean isComplete = true ;
174- na .enable (x );
194+ Object xData = x .getData ();
195+ RandomAccessIterator it = dataLib .randomAccessIterator (xData );
196+ NACheck na = dataLib .getNACheck (xData );
175197 for (int c = 0 ; c < colNum ; c ++) {
176198 double sum = 0 ;
177199 int nonNaNumCount = 0 ;
178200 for (int i = 0 ; i < rowNum ; i ++) {
179- int el = x . getDataAt ( c * rowNum + i );
201+ int el = dataLib . getInt ( xData , it , c * rowNum + i );
180202 if (!na .check (el )) {
181203 sum = add .op (sum , el );
182204 nonNaNumCount ++;
0 commit comments