2929#include < include/ntt_gpu/ntt.cuh>
3030#include < limits>
3131#include < vector>
32+ #include < algorithm>
3233
3334namespace cufhe {
3435template <class P = TFHEpp::lvl1param>
@@ -321,8 +322,8 @@ __device__ inline void __SampleExtractIndex__(typename P::T* const res, const ty
321322 }else {
322323 const uint k = i >> P::nbit;
323324 const uint n = i & nmask;
324- if (n <= index) res[index ] = in[k*P::n + index - n];
325- else res[index ] = -in[k*P::n + P::n + index-n];
325+ if (n <= index) res[i ] = in[k*P::n + index - n];
326+ else res[i ] = -in[k*P::n + P::n + index-n];
326327 }
327328 }
328329}
@@ -335,9 +336,12 @@ __device__ inline void __HomGate__(typename brP::targetP::T* const out,
335336 const CuNTTHandler<> ntt)
336337{
337338 __shared__ typename iksP::targetP::T tlwe[iksP::targetP::k*iksP::targetP::n+1 ];
338- __shared__ typename brP::targetP::T trlwe[(brP::targetP::k+1 )*brP::targetP::n];
339339
340340 IdentityKeySwitchPreAdd<iksP, casign, cbsign, offset>(tlwe, in0, in1, ksk);
341+ __syncthreads ();
342+
343+ __shared__ typename brP::targetP::T trlwe[(brP::targetP::k+1 )*brP::targetP::n];
344+
341345 __BlindRotate__<brP>(trlwe, tlwe, μ, bk,ntt);
342346 __SampleExtractIndex__<typename brP::targetP,0 >(out,trlwe);
343347 __threadfence ();
@@ -601,8 +605,10 @@ template<class P>
601605__global__ __launch_bounds__ (NUM_THREAD4HOMGATE) void __CopyBootstrap__(
602606 typename P::T* const out, const typename P::T* const in)
603607{
604- const uint32_t tid = ThisThreadRankInBlock ();
605- out[tid] = in[tid];
608+ const uint tid = ThisThreadRankInBlock ();
609+ const uint bdim = ThisBlockSize ();
610+ for (int i = tid; i <= P::k*P::n; i += bdim)
611+ out[i] = in[i];
606612 __syncthreads ();
607613 __threadfence ();
608614}
@@ -611,8 +617,10 @@ template<class P>
611617__global__ __launch_bounds__ (NUM_THREAD4HOMGATE) void __NotBootstrap__(
612618 typename P::T* const out, const typename P::T* const in)
613619{
614- const uint32_t tid = ThisThreadRankInBlock ();
615- out[tid] = -in[tid];
620+ const uint tid = ThisThreadRankInBlock ();
621+ const uint bdim = ThisBlockSize ();
622+ for (int i = tid; i <= P::k*P::n; i += bdim)
623+ out[i] = -in[i];
616624 __syncthreads ();
617625 __threadfence ();
618626}
@@ -627,17 +635,18 @@ __global__ __launch_bounds__(NUM_THREAD4HOMGATE) void __MuxBootstrap__(
627635 __shared__ typename iksP::targetP::T tlwelvl0[iksP::targetP::k*iksP::targetP::n+1 ];
628636
629637 IdentityKeySwitchPreAdd<iksP, 1 , 1 , -iksP::domainP::μ>(tlwelvl0, inc, in1, ksk);
630- __threadfence ();
638+ __syncthreads ();
631639 __shared__ typename brP::targetP::T tlwe1[(brP::targetP::k+1 )*brP::targetP::n];
632640 __BlindRotate__<brP>(tlwe1,tlwelvl0,μ,bk,ntt);
633641 __SampleExtractIndex__<typename brP::targetP,0 >(out, tlwe1);
634642
635- IdentityKeySwitchPreAdd<iksP, 1 , 1 , -iksP::domainP::μ>(tlwelvl0, inc, in0, ksk);
636- __threadfence ();
643+ IdentityKeySwitchPreAdd<iksP, - 1 , 1 , -iksP::domainP::μ>(tlwelvl0, inc, in0, ksk);
644+ __syncthreads ();
637645 __shared__ typename brP::targetP::T tlwe0[(brP::targetP::k+1 )*brP::targetP::n];
638646 __BlindRotate__<brP>(tlwe0,tlwelvl0,μ,bk,ntt);
639647 __SampleExtractIndex__<typename brP::targetP,0 >(tlwe1, tlwe0);
640- __threadfence ();
648+
649+ __syncthreads ();
641650
642651 volatile const uint32_t tid = ThisThreadRankInBlock ();
643652 volatile const uint32_t bdim = ThisBlockSize ();
@@ -661,18 +670,19 @@ __global__ __launch_bounds__(NUM_THREAD4HOMGATE) void __NMuxBootstrap__(
661670 __shared__ typename iksP::targetP::T tlwelvl0[iksP::targetP::k*iksP::targetP::n+1 ];
662671
663672 IdentityKeySwitchPreAdd<iksP, 1 , 1 , -iksP::domainP::μ>(tlwelvl0, inc, in1, ksk);
664- __threadfence ();
673+ __syncthreads ();
665674 __shared__ typename brP::targetP::T tlwe1[(brP::targetP::k+1 )*brP::targetP::n];
666675 __BlindRotate__<brP>(tlwe1,tlwelvl0,μ,bk,ntt);
667676 __SampleExtractIndex__<typename brP::targetP,0 >(out, tlwe1);
668677
669- IdentityKeySwitchPreAdd<iksP, 1 , 1 , -iksP::domainP::μ>(tlwelvl0, inc, in0, ksk);
670- __threadfence ();
678+ IdentityKeySwitchPreAdd<iksP, - 1 , 1 , -iksP::domainP::μ>(tlwelvl0, inc, in0, ksk);
679+ __syncthreads ();
671680 __shared__ typename brP::targetP::T tlwe0[(brP::targetP::k+1 )*brP::targetP::n];
672681 __BlindRotate__<brP>(tlwe0,tlwelvl0,μ,bk,ntt);
673682 __SampleExtractIndex__<typename brP::targetP,0 >(tlwe1, tlwe0);
674683
675- __threadfence ();
684+ __syncthreads ();
685+
676686
677687 volatile const uint32_t tid = ThisThreadRankInBlock ();
678688 volatile const uint32_t bdim = ThisBlockSize ();
@@ -1090,7 +1100,7 @@ template<class P>
10901100void CopyBootstrap (typename P::T* const out, const typename P::T* const in,
10911101 const cudaStream_t st, const int gpuNum)
10921102{
1093- __CopyBootstrap__<P><<<1 , P::n + 1 , 0 , st>>> (out, in);
1103+ __CopyBootstrap__<P><<<1 , std::min( P::n + 1 ,NUM_THREAD4HOMGATE) , 0 , st>>> (out, in);
10941104 CuCheckError ();
10951105}
10961106#define INST (P ) \
@@ -1104,7 +1114,7 @@ template<class P>
11041114void NotBootstrap (typename P::T* const out, const typename P::T* const in,
11051115 const cudaStream_t st, const int gpuNum)
11061116{
1107- __NotBootstrap__<P><<<1 , P::n + 1 , 0 , st>>> (out, in);
1117+ __NotBootstrap__<P><<<1 , std::min( P::n + 1 ,NUM_THREAD4HOMGATE) , 0 , st>>> (out, in);
11081118 CuCheckError ();
11091119}
11101120#define INST (P ) \
0 commit comments