Skip to content

Commit e64ffef

Browse files
authored
SmallMatrix: Support 1-based indexing (#4188)
To avoid confusion, operations involving both 0 and 1-base indexing are not allowed.
1 parent fcc5bd2 commit e64ffef

File tree

2 files changed

+224
-62
lines changed

2 files changed

+224
-62
lines changed

Src/Base/AMReX_SmallMatrix.H

Lines changed: 93 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -20,23 +20,24 @@ namespace amrex {
2020
/**
2121
* \brief Matrix class with compile-time size
2222
*
23-
* The starting index for both rows and columns is always zero. Also
24-
* note that column vectors and row vectors are special cases of a
23+
* Note that column vectors and row vectors are special cases of a
2524
* Matrix.
2625
*
2726
* \tparam T Matrix element data type.
2827
* \tparam NRows Number of rows.
2928
* \tparam NCols Number of columns.
3029
* \tparam ORDER Memory layout order. Order::F (i.e., column-major) by default.
30+
* \tparam StartIndex Starting index. Either 0 or 1.
3131
*/
32-
template <class T, int NRows, int NCols, Order ORDER = Order::F>
32+
template <class T, int NRows, int NCols, Order ORDER = Order::F, int StartIndex = 0>
3333
struct SmallMatrix
3434
{
3535
using value_type = T;
3636
using reference_type = T&;
3737
static constexpr int row_size = NRows;
3838
static constexpr int column_size = NCols;
3939
static constexpr Order ordering = ORDER;
40+
static constexpr int starting_index = StartIndex;
4041

4142
/**
4243
* \brief Default constructor
@@ -78,10 +79,10 @@ namespace amrex {
7879
explicit SmallMatrix (std::initializer_list<std::initializer_list<T>> const& init)
7980
{
8081
AMREX_ASSERT(NRows == init.size());
81-
int i = 0;
82+
int i = StartIndex;
8283
for (auto const& row : init) {
8384
AMREX_ASSERT(NCols == row.size());
84-
int j = 0;
85+
int j = StartIndex;
8586
for (auto const& x : row) {
8687
(*this)(i,j) = x;
8788
++j;
@@ -93,6 +94,11 @@ namespace amrex {
9394
//! Returns a const reference to the element at row i and column j.
9495
[[nodiscard]] AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
9596
const T& operator() (int i, int j) const noexcept {
97+
static_assert(StartIndex == 0 || StartIndex == 1);
98+
if constexpr (StartIndex == 1) {
99+
--i;
100+
--j;
101+
}
96102
AMREX_ASSERT(i < NRows && j < NCols);
97103
if constexpr (ORDER == Order::F) {
98104
return m_mat[i+j*NRows];
@@ -104,6 +110,11 @@ namespace amrex {
104110
//! Returns a reference to the element at row i and column j.
105111
[[nodiscard]] AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
106112
T& operator() (int i, int j) noexcept {
113+
static_assert(StartIndex == 0 || StartIndex == 1);
114+
if constexpr (StartIndex == 1) {
115+
--i;
116+
--j;
117+
}
107118
AMREX_ASSERT(i < NRows && j < NCols);
108119
if constexpr (ORDER == Order::F) {
109120
return m_mat[i+j*NRows];
@@ -116,6 +127,10 @@ namespace amrex {
116127
template <int MM=NRows, int NN=NCols, std::enable_if_t<(MM==1 || NN==1), int> = 0>
117128
[[nodiscard]] AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
118129
const T& operator() (int i) const noexcept {
130+
static_assert(StartIndex == 0 || StartIndex == 1);
131+
if constexpr (StartIndex == 1) {
132+
--i;
133+
}
119134
AMREX_ASSERT(i < NRows*NCols);
120135
return m_mat[i];
121136
}
@@ -124,6 +139,10 @@ namespace amrex {
124139
template <int MM=NRows, int NN=NCols, std::enable_if_t<(MM==1 || NN==1), int> = 0>
125140
[[nodiscard]] AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
126141
T& operator() (int i) noexcept {
142+
static_assert(StartIndex == 0 || StartIndex == 1);
143+
if constexpr (StartIndex == 1) {
144+
--i;
145+
}
127146
AMREX_ASSERT(i < NRows*NCols);
128147
return m_mat[i];
129148
}
@@ -132,6 +151,10 @@ namespace amrex {
132151
template <int MM=NRows, int NN=NCols, std::enable_if_t<(MM==1 || NN==1), int> = 0>
133152
[[nodiscard]] AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
134153
const T& operator[] (int i) const noexcept {
154+
static_assert(StartIndex == 0 || StartIndex == 1);
155+
if constexpr (StartIndex == 1) {
156+
--i;
157+
}
135158
AMREX_ASSERT(i < NRows*NCols);
136159
return m_mat[i];
137160
}
@@ -140,6 +163,10 @@ namespace amrex {
140163
template <int MM=NRows, int NN=NCols, std::enable_if_t<(MM==1 || NN==1), int> = 0>
141164
[[nodiscard]] AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
142165
T& operator[] (int i) noexcept {
166+
static_assert(StartIndex == 0 || StartIndex == 1);
167+
if constexpr (StartIndex == 1) {
168+
--i;
169+
}
143170
AMREX_ASSERT(i < NRows*NCols);
144171
return m_mat[i];
145172
}
@@ -174,7 +201,7 @@ namespace amrex {
174201

175202
//! Set all elements in the matrix to the given value
176203
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
177-
SmallMatrix<T,NRows,NCols,ORDER>&
204+
SmallMatrix<T,NRows,NCols,ORDER,StartIndex>&
178205
setVal (T val)
179206
{
180207
for (auto& x : m_mat) { x = val; }
@@ -185,30 +212,32 @@ namespace amrex {
185212
template <int MM=NRows, int NN=NCols, std::enable_if_t<MM==NN, int> = 0>
186213
static constexpr
187214
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
188-
SmallMatrix<T,NRows,NCols,ORDER>
215+
SmallMatrix<T,NRows,NCols,ORDER,StartIndex>
189216
Identity () noexcept {
190-
SmallMatrix<T,NRows,NCols,ORDER> I{};
191-
constexpr_for<0,NRows>([&] (int i) { I(i,i) = T(1); });
217+
static_assert(StartIndex == 0 || StartIndex == 1);
218+
SmallMatrix<T,NRows,NCols,ORDER,StartIndex> I{};
219+
constexpr_for<StartIndex,NRows+StartIndex>(
220+
[&] (int i) { I(i,i) = T(1); });
192221
return I;
193222
}
194223

195224
//! Returns a matrix initialized with zeros
196225
static constexpr
197226
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
198-
SmallMatrix<T,NRows,NCols,ORDER>
227+
SmallMatrix<T,NRows,NCols,ORDER,StartIndex>
199228
Zero () noexcept {
200-
SmallMatrix<T,NRows,NCols,ORDER> Z{};
229+
SmallMatrix<T,NRows,NCols,ORDER,StartIndex> Z{};
201230
return Z;
202231
}
203232

204233
//! Returns transposed matrix
205234
[[nodiscard]] AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
206-
SmallMatrix<T,NCols,NRows,ORDER>
235+
SmallMatrix<T,NCols,NRows,ORDER,StartIndex>
207236
transpose () const
208237
{
209-
SmallMatrix<T,NCols,NRows,ORDER> r;
210-
for (int j = 0; j < NRows; ++j) {
211-
for (int i = 0; i < NCols; ++i) {
238+
SmallMatrix<T,NCols,NRows,ORDER,StartIndex> r;
239+
for (int j = StartIndex; j < NRows+StartIndex; ++j) {
240+
for (int i = StartIndex; i < NCols+StartIndex; ++i) {
212241
r(i,j) = (*this)(j,i);
213242
}
214243
}
@@ -218,11 +247,12 @@ namespace amrex {
218247
//! Transposes a square matrix in-place.
219248
template <int MM=NRows, int NN=NCols, std::enable_if_t<MM==NN,int> = 0>
220249
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
221-
SmallMatrix<T,NRows,NCols,ORDER>&
250+
SmallMatrix<T,NRows,NCols,ORDER,StartIndex>&
222251
transposeInPlace ()
223252
{
224-
for (int j = 1; j < NCols; ++j) {
225-
for (int i = 0; i < j; ++i) {
253+
static_assert(StartIndex == 0 || StartIndex == 1);
254+
for (int j = 1+StartIndex; j < NCols+StartIndex; ++j) {
255+
for (int i = StartIndex; i < j; ++i) {
226256
amrex::Swap((*this)(i,j), (*this)(j,i));
227257
}
228258
}
@@ -257,14 +287,14 @@ namespace amrex {
257287
T trace () const
258288
{
259289
T t = 0;
260-
constexpr_for<0,MM>([&] (int i) { t += (*this)(i,i); });
290+
constexpr_for<StartIndex,MM+StartIndex>([&] (int i) { t += (*this)(i,i); });
261291
return t;
262292
}
263293

264294
//! Operator += performing matrix addition as in (*this) += rhs
265295
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
266-
SmallMatrix<T,NRows,NCols,ORDER>&
267-
operator += (SmallMatrix<T,NRows,NCols,ORDER> const& rhs)
296+
SmallMatrix<T,NRows,NCols,ORDER,StartIndex>&
297+
operator += (SmallMatrix<T,NRows,NCols,ORDER,StartIndex> const& rhs)
268298
{
269299
for (int n = 0; n < NRows*NCols; ++n) {
270300
m_mat[n] += rhs.m_mat[n];
@@ -274,18 +304,18 @@ namespace amrex {
274304

275305
//! Binary operator + returning the result of maxtrix addition, lhs+rhs
276306
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
277-
friend SmallMatrix<T,NRows,NCols,ORDER>
278-
operator+ (SmallMatrix<T,NRows,NCols,ORDER> lhs,
279-
SmallMatrix<T,NRows,NCols,ORDER> const& rhs)
307+
friend SmallMatrix<T,NRows,NCols,ORDER,StartIndex>
308+
operator+ (SmallMatrix<T,NRows,NCols,ORDER,StartIndex> lhs,
309+
SmallMatrix<T,NRows,NCols,ORDER,StartIndex> const& rhs)
280310
{
281311
lhs += rhs;
282312
return lhs;
283313
}
284314

285315
//! Operator -= performing matrix subtraction as in (*this) -= rhs
286316
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
287-
SmallMatrix<T,NRows,NCols,ORDER>&
288-
operator -= (SmallMatrix<T,NRows,NCols,ORDER> const& rhs)
317+
SmallMatrix<T,NRows,NCols,ORDER,StartIndex>&
318+
operator -= (SmallMatrix<T,NRows,NCols,ORDER,StartIndex> const& rhs)
289319
{
290320
for (int n = 0; n < NRows*NCols; ++n) {
291321
m_mat[n] -= rhs.m_mat[n];
@@ -295,25 +325,25 @@ namespace amrex {
295325

296326
//! Binary operator - returning the result of maxtrix subtraction, lhs-rhs
297327
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
298-
friend SmallMatrix<T,NRows,NCols,ORDER>
299-
operator- (SmallMatrix<T,NRows,NCols,ORDER> lhs,
300-
SmallMatrix<T,NRows,NCols,ORDER> const& rhs)
328+
friend SmallMatrix<T,NRows,NCols,ORDER,StartIndex>
329+
operator- (SmallMatrix<T,NRows,NCols,ORDER,StartIndex> lhs,
330+
SmallMatrix<T,NRows,NCols,ORDER,StartIndex> const& rhs)
301331
{
302332
lhs -= rhs;
303333
return lhs;
304334
}
305335

306336
//! Unary minus operator
307337
[[nodiscard]] AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
308-
SmallMatrix<T,NRows,NCols,ORDER>
338+
SmallMatrix<T,NRows,NCols,ORDER,StartIndex>
309339
operator- () const
310340
{
311341
return (*this) * T(-1);
312342
}
313343

314344
//! Operator *= that scales this matrix in place by a scalar.
315345
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
316-
SmallMatrix<T,NRows,NCols,ORDER>&
346+
SmallMatrix<T,NRows,NCols,ORDER,StartIndex>&
317347
operator *= (T a)
318348
{
319349
for (auto& x : m_mat) {
@@ -324,32 +354,32 @@ namespace amrex {
324354

325355
//! Returns the product of a matrix and a scalar
326356
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
327-
friend SmallMatrix<T,NRows,NCols,ORDER>
328-
operator* (SmallMatrix<T,NRows,NCols,ORDER> m, T a)
357+
friend SmallMatrix<T,NRows,NCols,ORDER,StartIndex>
358+
operator* (SmallMatrix<T,NRows,NCols,ORDER,StartIndex> m, T a)
329359
{
330360
m *= a;
331361
return m;
332362
}
333363

334364
//! Returns the product of a scalar and a matrix
335365
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
336-
friend SmallMatrix<T,NRows,NCols,ORDER>
337-
operator* (T a, SmallMatrix<T,NRows,NCols,ORDER> m)
366+
friend SmallMatrix<T,NRows,NCols,ORDER,StartIndex>
367+
operator* (T a, SmallMatrix<T,NRows,NCols,ORDER,StartIndex> m)
338368
{
339369
m *= a;
340370
return m;
341371
}
342372

343373
//! Returns matrix product of two matrices
344-
template <class U, int N1, int N2, int N3, Order Ord>
374+
template <class U, int N1, int N2, int N3, Order Ord, int SI>
345375
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
346-
friend SmallMatrix<U,N1,N3,Ord>
347-
operator* (SmallMatrix<U,N1,N2,Ord> const& lhs,
348-
SmallMatrix<U,N2,N3,Ord> const& rhs);
376+
friend SmallMatrix<U,N1,N3,Ord,SI>
377+
operator* (SmallMatrix<U,N1,N2,Ord,SI> const& lhs,
378+
SmallMatrix<U,N2,N3,Ord,SI> const& rhs);
349379

350380
//! Returns the dot product of two vectors
351381
[[nodiscard]] AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
352-
T dot (SmallMatrix<T,NRows,NCols,ORDER> const& rhs) const
382+
T dot (SmallMatrix<T,NRows,NCols,ORDER,StartIndex> const& rhs) const
353383
{
354384
T r = 0;
355385
for (int n = 0; n < NRows*NCols; ++n) {
@@ -362,30 +392,31 @@ namespace amrex {
362392
T m_mat[NRows*NCols];
363393
};
364394

365-
template <class U, int N1, int N2, int N3, Order Ord>
395+
template <class U, int N1, int N2, int N3, Order Ord, int SI>
366396
[[nodiscard]] AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
367-
SmallMatrix<U,N1,N3,Ord>
368-
operator* (SmallMatrix<U,N1,N2,Ord> const& lhs,
369-
SmallMatrix<U,N2,N3,Ord> const& rhs)
397+
SmallMatrix<U,N1,N3,Ord,SI>
398+
operator* (SmallMatrix<U,N1,N2,Ord,SI> const& lhs,
399+
SmallMatrix<U,N2,N3,Ord,SI> const& rhs)
370400
{
371-
SmallMatrix<U,N1,N3,Ord> r;
401+
static_assert(SI == 0 || SI == 1);
402+
SmallMatrix<U,N1,N3,Ord,SI> r;
372403
if constexpr (Ord == Order::F) {
373-
for (int j = 0; j < N3; ++j) {
374-
constexpr_for<0,N1>([&] (int i) { r(i,j) = U(0); });
375-
for (int k = 0; k < N2; ++k) {
404+
for (int j = SI; j < N3+SI; ++j) {
405+
constexpr_for<SI,N1+SI>([&] (int i) { r(i,j) = U(0); });
406+
for (int k = SI; k < N2+SI; ++k) {
376407
auto b = rhs(k,j);
377-
constexpr_for<0,N1>([&] (int i)
408+
constexpr_for<SI,N1+SI>([&] (int i)
378409
{
379410
r(i,j) += lhs(i,k) * b;
380411
});
381412
}
382413
}
383414
} else {
384-
for (int i = 0; i < N1; ++i) {
385-
constexpr_for<0,N3>([&] (int j) { r(i,j) = U(0); });
386-
for (int k = 0; k < N2; ++k) {
415+
for (int i = SI; i < N1+SI; ++i) {
416+
constexpr_for<SI,N3+SI>([&] (int j) { r(i,j) = U(0); });
417+
for (int k = SI; k < N2+SI; ++k) {
387418
auto a = lhs(i,k);
388-
constexpr_for<0,N3>([&] (int j)
419+
constexpr_for<SI,N3+SI>([&] (int j)
389420
{
390421
r(i,j) += a * rhs(k,j);
391422
});
@@ -395,25 +426,25 @@ namespace amrex {
395426
return r;
396427
}
397428

398-
template <class T, int NRows, int NCols, Order ORDER>
429+
template <class T, int NRows, int NCols, Order ORDER, int SI>
399430
std::ostream& operator<< (std::ostream& os,
400-
SmallMatrix<T,NRows,NCols,ORDER> const& mat)
431+
SmallMatrix<T,NRows,NCols,ORDER,SI> const& mat)
401432
{
402-
for (int i = 0; i < NRows; ++i) {
403-
os << mat(i,0);
404-
for (int j = 1; j < NCols; ++j) {
433+
for (int i = SI; i < NRows+SI; ++i) {
434+
os << mat(i,SI);
435+
for (int j = 1+SI; j < NCols+SI; ++j) {
405436
os << " " << mat(i,j);
406437
}
407438
os << "\n";
408439
}
409440
return os;
410441
}
411442

412-
template <class T, int N>
413-
using SmallVector = SmallMatrix<T,N,1>;
443+
template <class T, int N, int StartIndex = 0>
444+
using SmallVector = SmallMatrix<T,N,1,Order::F,StartIndex>;
414445

415-
template <class T, int N>
416-
using SmallRowVector = SmallMatrix<T,1,N>;
446+
template <class T, int N, int StartIndex = 0>
447+
using SmallRowVector = SmallMatrix<T,1,N,Order::F,StartIndex>;
417448
}
418449

419450
#endif

0 commit comments

Comments
 (0)