2424import static com .oracle .truffle .r .runtime .builtins .RBuiltinKind .INTERNAL ;
2525
2626import com .oracle .truffle .api .dsl .Specialization ;
27+ import com .oracle .truffle .api .library .CachedLibrary ;
2728import com .oracle .truffle .r .runtime .RRuntime ;
2829import com .oracle .truffle .r .runtime .builtins .RBuiltin ;
2930import com .oracle .truffle .r .runtime .data .RDataFactory ;
3031import com .oracle .truffle .r .runtime .data .RDoubleVector ;
3132import com .oracle .truffle .r .runtime .data .RIntVector ;
3233import com .oracle .truffle .r .runtime .data .RLogicalVector ;
34+ import com .oracle .truffle .r .runtime .data .VectorDataLibrary ;
35+ import com .oracle .truffle .r .runtime .data .VectorDataLibrary .RandomAccessIterator ;
3336import com .oracle .truffle .r .runtime .ops .BinaryArithmetic ;
3437import com .oracle .truffle .r .runtime .ops .na .NACheck ;
3538
3639//Implements .colMeans
3740@ RBuiltin (name = "colMeans" , kind = INTERNAL , parameterNames = {"X" , "m" , "n" , "na.rm" }, behavior = PURE )
3841public abstract class ColMeans extends ColSumsBase {
3942
40- protected final NACheck na = NACheck .create ();
4143 @ Child private BinaryArithmetic add = BinaryArithmetic .ADD .createOperation ();
4244
4345 static {
4446 createCasts (ColMeans .class );
4547 }
4648
47- @ Specialization (guards = "!naRm" )
48- protected RDoubleVector colMeansNaRmFalse (RDoubleVector x , int rowNum , int colNum , @ SuppressWarnings ("unused" ) boolean naRm ) {
49- checkVectorLength (x , rowNum , colNum );
49+ @ Specialization (guards = "!naRm" , limit = "getTypedVectorDataLibraryCacheSize()" )
50+ protected RDoubleVector colMeansNaRmFalse (RDoubleVector x , int rowNum , int colNum , @ SuppressWarnings ("unused" ) boolean naRm ,
51+ @ CachedLibrary ("x.getData()" ) VectorDataLibrary dataLib ) {
52+ checkVectorLength (dataLib , x , rowNum , colNum );
5053
5154 double [] result = new double [colNum ];
5255 boolean isComplete = true ;
53- na .enable (x );
56+ Object xData = x .getData ();
57+ RandomAccessIterator it = dataLib .randomAccessIterator (xData );
58+ NACheck na = dataLib .getNACheck (xData );
5459 nextCol : for (int c = 0 ; c < colNum ; c ++) {
5560 double sum = 0 ;
5661 for (int i = 0 ; i < rowNum ; i ++) {
57- double el = x . getDataAt ( c * rowNum + i );
62+ double el = dataLib . getDouble ( xData , it , c * rowNum + i );
5863 if (na .check (el )) {
5964 result [c ] = RRuntime .DOUBLE_NA ;
6065 continue nextCol ;
@@ -71,18 +76,21 @@ protected RDoubleVector colMeansNaRmFalse(RDoubleVector x, int rowNum, int colNu
7176 return RDataFactory .createDoubleVector (result , na .neverSeenNA () && isComplete );
7277 }
7378
74- @ Specialization (guards = "naRm" )
75- protected RDoubleVector colMeansNaRmTrue (RDoubleVector x , int rowNum , int colNum , @ SuppressWarnings ("unused" ) boolean naRm ) {
76- checkVectorLength (x , rowNum , colNum );
79+ @ Specialization (guards = "naRm" , limit = "getTypedVectorDataLibraryCacheSize()" )
80+ protected RDoubleVector colMeansNaRmTrue (RDoubleVector x , int rowNum , int colNum , @ SuppressWarnings ("unused" ) boolean naRm ,
81+ @ CachedLibrary ("x.getData()" ) VectorDataLibrary dataLib ) {
82+ checkVectorLength (dataLib , x , rowNum , colNum );
7783
7884 double [] result = new double [colNum ];
7985 boolean isComplete = true ;
80- na .enable (x );
86+ Object xData = x .getData ();
87+ RandomAccessIterator it = dataLib .randomAccessIterator (xData );
88+ NACheck na = dataLib .getNACheck (xData );
8189 for (int c = 0 ; c < colNum ; c ++) {
8290 double sum = 0 ;
8391 int nonNaNumCount = 0 ;
8492 for (int i = 0 ; i < rowNum ; i ++) {
85- double el = x . getDataAt ( c * rowNum + i );
93+ double el = dataLib . getDouble ( xData , it , c * rowNum + i );
8694 if (!na .check (el ) && !Double .isNaN (el )) {
8795 sum = add .op (sum , el );
8896 nonNaNumCount ++;
@@ -98,16 +106,19 @@ protected RDoubleVector colMeansNaRmTrue(RDoubleVector x, int rowNum, int colNum
98106 return RDataFactory .createDoubleVector (result , isComplete );
99107 }
100108
101- @ Specialization (guards = "!naRm" )
102- protected RDoubleVector colMeansNaRmFalse (RLogicalVector x , int rowNum , int colNum , @ SuppressWarnings ("unused" ) boolean naRm ) {
103- checkVectorLength (x , rowNum , colNum );
109+ @ Specialization (guards = "!naRm" , limit = "getTypedVectorDataLibraryCacheSize()" )
110+ protected RDoubleVector colMeansNaRmFalse (RLogicalVector x , int rowNum , int colNum , @ SuppressWarnings ("unused" ) boolean naRm ,
111+ @ CachedLibrary ("x.getData()" ) VectorDataLibrary dataLib ) {
112+ checkVectorLength (dataLib , x , rowNum , colNum );
104113
105114 double [] result = new double [colNum ];
106- na .enable (x );
115+ Object xData = x .getData ();
116+ RandomAccessIterator it = dataLib .randomAccessIterator (xData );
117+ NACheck na = dataLib .getNACheck (xData );
107118 nextCol : for (int c = 0 ; c < colNum ; c ++) {
108119 double sum = 0 ;
109120 for (int i = 0 ; i < rowNum ; i ++) {
110- byte el = x . getDataAt ( c * rowNum + i );
121+ byte el = dataLib . getLogical ( xData , it , c * rowNum + i );
111122 if (na .check (el )) {
112123 result [c ] = RRuntime .DOUBLE_NA ;
113124 continue nextCol ;
@@ -119,18 +130,21 @@ protected RDoubleVector colMeansNaRmFalse(RLogicalVector x, int rowNum, int colN
119130 return RDataFactory .createDoubleVector (result , na .neverSeenNA ());
120131 }
121132
122- @ Specialization (guards = "naRm" )
123- protected RDoubleVector colMeansNaRmTrue (RLogicalVector x , int rowNum , int colNum , @ SuppressWarnings ("unused" ) boolean naRm ) {
124- checkVectorLength (x , rowNum , colNum );
133+ @ Specialization (guards = "naRm" , limit = "getTypedVectorDataLibraryCacheSize()" )
134+ protected RDoubleVector colMeansNaRmTrue (RLogicalVector x , int rowNum , int colNum , @ SuppressWarnings ("unused" ) boolean naRm ,
135+ @ CachedLibrary ("x.getData()" ) VectorDataLibrary dataLib ) {
136+ checkVectorLength (dataLib , x , rowNum , colNum );
125137
126138 double [] result = new double [colNum ];
127139 boolean isComplete = true ;
128- na .enable (x );
140+ Object xData = x .getData ();
141+ RandomAccessIterator it = dataLib .randomAccessIterator (xData );
142+ NACheck na = dataLib .getNACheck (xData );
129143 for (int c = 0 ; c < colNum ; c ++) {
130144 double sum = 0 ;
131145 int nonNaNumCount = 0 ;
132146 for (int i = 0 ; i < rowNum ; i ++) {
133- byte el = x . getDataAt ( c * rowNum + i );
147+ byte el = dataLib . getLogical ( xData , it , c * rowNum + i );
134148 if (!na .check (el )) {
135149 sum = add .op (sum , el );
136150 nonNaNumCount ++;
@@ -146,16 +160,19 @@ protected RDoubleVector colMeansNaRmTrue(RLogicalVector x, int rowNum, int colNu
146160 return RDataFactory .createDoubleVector (result , isComplete );
147161 }
148162
149- @ Specialization (guards = "!naRm" )
150- protected RDoubleVector colMeansNaRmFalse (RIntVector x , int rowNum , int colNum , @ SuppressWarnings ("unused" ) boolean naRm ) {
151- checkVectorLength (x , rowNum , colNum );
163+ @ Specialization (guards = "!naRm" , limit = "getTypedVectorDataLibraryCacheSize()" )
164+ protected RDoubleVector colMeansNaRmFalse (RIntVector x , int rowNum , int colNum , @ SuppressWarnings ("unused" ) boolean naRm ,
165+ @ CachedLibrary ("x.getData()" ) VectorDataLibrary dataLib ) {
166+ checkVectorLength (dataLib , x , rowNum , colNum );
152167
153168 double [] result = new double [colNum ];
154- na .enable (x );
169+ Object xData = x .getData ();
170+ RandomAccessIterator it = dataLib .randomAccessIterator (xData );
171+ NACheck na = dataLib .getNACheck (xData );
155172 nextCol : for (int c = 0 ; c < colNum ; c ++) {
156173 double sum = 0 ;
157174 for (int i = 0 ; i < rowNum ; i ++) {
158- int el = x . getDataAt ( c * rowNum + i );
175+ int el = dataLib . getInt ( xData , it , c * rowNum + i );
159176 if (na .check (el )) {
160177 result [c ] = RRuntime .DOUBLE_NA ;
161178 continue nextCol ;
@@ -167,18 +184,21 @@ protected RDoubleVector colMeansNaRmFalse(RIntVector x, int rowNum, int colNum,
167184 return RDataFactory .createDoubleVector (result , na .neverSeenNA ());
168185 }
169186
170- @ Specialization (guards = "naRm" )
171- protected RDoubleVector colMeansNaRmTrue (RIntVector x , int rowNum , int colNum , @ SuppressWarnings ("unused" ) boolean naRm ) {
172- checkVectorLength (x , rowNum , colNum );
187+ @ Specialization (guards = "naRm" , limit = "getTypedVectorDataLibraryCacheSize()" )
188+ protected RDoubleVector colMeansNaRmTrue (RIntVector x , int rowNum , int colNum , @ SuppressWarnings ("unused" ) boolean naRm ,
189+ @ CachedLibrary ("x.getData()" ) VectorDataLibrary dataLib ) {
190+ checkVectorLength (dataLib , x , rowNum , colNum );
173191
174192 double [] result = new double [colNum ];
175193 boolean isComplete = true ;
176- na .enable (x );
194+ Object xData = x .getData ();
195+ RandomAccessIterator it = dataLib .randomAccessIterator (xData );
196+ NACheck na = dataLib .getNACheck (xData );
177197 for (int c = 0 ; c < colNum ; c ++) {
178198 double sum = 0 ;
179199 int nonNaNumCount = 0 ;
180200 for (int i = 0 ; i < rowNum ; i ++) {
181- int el = x . getDataAt ( c * rowNum + i );
201+ int el = dataLib . getInt ( xData , it , c * rowNum + i );
182202 if (!na .check (el )) {
183203 sum = add .op (sum , el );
184204 nonNaNumCount ++;
0 commit comments