|
| 1 | +#ifndef AMREX_FFT_OPENBC_SOLVER_H_ |
| 2 | +#define AMREX_FFT_OPENBC_SOLVER_H_ |
| 3 | + |
| 4 | +#include <AMReX_FFT_R2C.H> |
| 5 | + |
| 6 | +#include <AMReX_VisMF.H> |
| 7 | + |
| 8 | +namespace amrex::FFT |
| 9 | +{ |
| 10 | + |
| 11 | +template <typename T = Real> |
| 12 | +class OpenBCSolver |
| 13 | +{ |
| 14 | +public: |
| 15 | + using MF = typename R2C<T>::MF; |
| 16 | + using cMF = typename R2C<T>::cMF; |
| 17 | + |
| 18 | + explicit OpenBCSolver (Box const& domain); |
| 19 | + |
| 20 | + template <class F> |
| 21 | + void setGreensFunction (F const& greens_function); |
| 22 | + |
| 23 | + void solve (MF& phi, MF const& rho); |
| 24 | + |
| 25 | + [[nodiscard]] Box const& Domain () const { return m_domain; } |
| 26 | + |
| 27 | +private: |
| 28 | + Box m_domain; |
| 29 | + R2C<T> m_r2c; |
| 30 | + cMF m_G_fft; |
| 31 | +}; |
| 32 | + |
| 33 | +template <typename T> |
| 34 | +OpenBCSolver<T>::OpenBCSolver (Box const& domain) |
| 35 | + : m_domain(domain), |
| 36 | + m_r2c(Box(domain.smallEnd(), domain.bigEnd()+domain.length(), domain.ixType())) |
| 37 | +{ |
| 38 | + auto [sd, ord] = m_r2c.getSpectralData(); |
| 39 | + amrex::ignore_unused(ord); |
| 40 | + m_G_fft.define(sd->boxArray(), sd->DistributionMap(), 1, 0); |
| 41 | +} |
| 42 | + |
| 43 | +template <typename T> |
| 44 | +template <class F> |
| 45 | +void OpenBCSolver<T>::setGreensFunction (F const& greens_function) |
| 46 | +{ |
| 47 | + auto* infab = detail::get_fab(m_r2c.m_rx); |
| 48 | + auto const& lo = m_domain.smallEnd(); |
| 49 | + auto const& lo3 = lo.dim3(); |
| 50 | + auto const& len = m_domain.length3d(); |
| 51 | + if (infab) { |
| 52 | + auto const& a = infab->array(); |
| 53 | + auto box = infab->box(); |
| 54 | + GpuArray<int,3> nimages{1,1,1}; |
| 55 | + for (int idim = 0; idim < AMREX_SPACEDIM; ++idim) { |
| 56 | + if (box.smallEnd(idim) == lo[idim] && box.length(idim) == 2*len[idim]) { |
| 57 | + box.growHi(idim, -len[idim]+1); // +1 to include the middle plane |
| 58 | + nimages[idim] = 2; |
| 59 | + } |
| 60 | + } |
| 61 | + AMREX_ASSERT(nimages[0] == 2); |
| 62 | + box.shift(-lo); |
| 63 | + amrex::ParallelFor(box, [=] AMREX_GPU_DEVICE (int i, int j, int k) |
| 64 | + { |
| 65 | + if (i == len[0] || j == len[1] || k == len[2]) { |
| 66 | + a(i+lo3.x,j+lo3.y,k+lo3.z) = T(0); |
| 67 | + } else { |
| 68 | + auto ii = i; |
| 69 | + auto jj = (j > len[1]) ? 2*len[1]-j : j; |
| 70 | + auto kk = (k > len[2]) ? 2*len[2]-k : k; |
| 71 | + auto G = greens_function(ii+lo3.x,jj+lo3.y,kk+lo3.z); |
| 72 | + for (int koff = 0; koff < nimages[2]; ++koff) { |
| 73 | + int k2 = (koff == 0) ? k : 2*len[2]-k; |
| 74 | + if (k2 == 2*len[2]) { continue; } |
| 75 | + for (int joff = 0; joff < nimages[1]; ++joff) { |
| 76 | + int j2 = (joff == 0) ? j : 2*len[1]-j; |
| 77 | + if (j2 == 2*len[1]) { continue; } |
| 78 | + for (int ioff = 0; ioff < nimages[0]; ++ioff) { |
| 79 | + int i2 = (ioff == 0) ? i : 2*len[0]-i; |
| 80 | + if (i2 == 2*len[0]) { continue; } |
| 81 | + a(i2+lo3.x,j2+lo3.y,k2+lo3.z) = G; |
| 82 | + } |
| 83 | + } |
| 84 | + } |
| 85 | + } |
| 86 | + }); |
| 87 | + } |
| 88 | + |
| 89 | + m_r2c.forward(m_r2c.m_rx); |
| 90 | + |
| 91 | + auto [sd, ord] = m_r2c.getSpectralData(); |
| 92 | + amrex::ignore_unused(ord); |
| 93 | + auto const* srcfab = detail::get_fab(*sd); |
| 94 | + if (srcfab) { |
| 95 | + auto* dstfab = detail::get_fab(m_G_fft); |
| 96 | + if (dstfab) { |
| 97 | +#if defined(AMREX_USE_GPU) |
| 98 | + Gpu::dtod_memcpy_async |
| 99 | +#else |
| 100 | + std::memcpy |
| 101 | +#endif |
| 102 | + (dstfab->dataPtr(), srcfab->dataPtr(), dstfab->nBytes()); |
| 103 | + } else { |
| 104 | + amrex::Abort("FFT::OpenBCSolver: how did this happen"); |
| 105 | + } |
| 106 | + } |
| 107 | +} |
| 108 | + |
| 109 | +template <typename T> |
| 110 | +void OpenBCSolver<T>::solve (MF& phi, MF const& rho) |
| 111 | +{ |
| 112 | + auto& inmf = m_r2c.m_rx; |
| 113 | + inmf.setVal(T(0)); |
| 114 | + inmf.ParallelCopy(rho, 0, 0, 1); |
| 115 | + |
| 116 | + m_r2c.forward(inmf); |
| 117 | + |
| 118 | + auto scaling_factor = T(1) / T(m_r2c.m_real_domain.numPts()); |
| 119 | + |
| 120 | + auto const* gfab = detail::get_fab(m_G_fft); |
| 121 | + if (gfab) { |
| 122 | + auto [sd, ord] = m_r2c.getSpectralData(); |
| 123 | + amrex::ignore_unused(ord); |
| 124 | + auto* rhofab = detail::get_fab(*sd); |
| 125 | + if (rhofab) { |
| 126 | + auto* pdst = rhofab->dataPtr(); |
| 127 | + auto const* psrc = gfab->dataPtr(); |
| 128 | + amrex::ParallelFor(rhofab->box().numPts(), [=] AMREX_GPU_DEVICE (Long i) |
| 129 | + { |
| 130 | + pdst[i] *= psrc[i] * scaling_factor; |
| 131 | + }); |
| 132 | + } else { |
| 133 | + amrex::Abort("FFT::OpenBCSolver::solve: how did this happen?"); |
| 134 | + } |
| 135 | + } |
| 136 | + |
| 137 | + m_r2c.backward_doit(phi, phi.nGrowVect()); |
| 138 | +} |
| 139 | + |
| 140 | +} |
| 141 | + |
| 142 | +#endif |
0 commit comments