Skip to content

Commit a32cf4b

Browse files
committed
finally working with lvl1param!
1 parent 125b24b commit a32cf4b

File tree

7 files changed

+85
-50
lines changed

7 files changed

+85
-50
lines changed

include/gatebootstrapping_gpu.cuh

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ __device__ constexpr typename P::T offsetgen()
2727
}
2828

2929
template <class P>
30-
__device__ inline void RotatedTestVector(TFHEpp::lvl1param::T* tlwe,
30+
__device__ inline void RotatedTestVector(typename P::T* tlwe,
3131
const int32_t bar,
3232
const typename P::T μ)
3333
{
@@ -202,13 +202,13 @@ __device__ inline void __BlindRotatePreAdd__(typename P::targetP::T* const out,
202202
{
203203
const uint32_t bar =
204204
2 * P::targetP::n -
205-
modSwitchFromTorus<P::targetP>(offset + casign * in0[P::domainP::k*P::domainP::n] +
205+
modSwitchFromTorus<typename P::targetP>(offset + casign * in0[P::domainP::k*P::domainP::n] +
206206
cbsign * in1[P::domainP::k*P::domainP::n]);
207-
RotatedTestVector<P::targetP>(out, bar, P::targetP::μ);
207+
RotatedTestVector<typename P::targetP>(out, bar, P::targetP::μ);
208208
}
209209

210210
// accumulate
211-
for (int i = 0; i < P::domainP::n; i++) { // lvl0param::n iterations
211+
for (int i = 0; i < P::domainP::k*P::domainP::n; i++) { // lvl0param::n iterations
212212
const uint32_t bar = modSwitchFromTorus<P::targetP>(0 + casign * in0[i] +
213213
cbsign * in1[i]);
214214
Accumulate<P>(out, sh_acc_ntt, bar,

include/keyswitch_gpu.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ __device__ inline void KeySwitch(typename P::targetP::T* const lwe,
2222
const uint32_t bdim = ThisBlockSize();
2323
for (int i = tid; i <= P::targetP::k*P::targetP::n; i += bdim) {
2424
typename P::targetP::T res = 0;
25-
if (i == P::targetP::n) res = tlwe[P::domainP::n];
25+
if (i == P::targetP::k*P::targetP::n) res = tlwe[P::domainP::k*P::domainP::n];
2626
for (int j = 0; j < P::domainP::k*P::domainP::n; j++) {
2727
typename P::domainP::T tmp;
2828
if (j == 0)
@@ -66,7 +66,7 @@ __device__ inline void IdentityKeySwitchPreAdd(typename P::targetP::T* const lwe
6666
if (i == P::targetP::k*P::targetP::n) res = casign*ina[P::domainP::k*P::domainP::n]+ cbsign*inb[P::domainP::k*P::domainP::n] + offset;
6767
for (int j = 0; j < P::domainP::k*P::domainP::n; j++) {
6868
typename P::domainP::T tmp;
69-
tmp = casign*ina[j]+ cbsign*inb[j] + decomp_offset;
69+
tmp = casign*ina[j]+ cbsign*inb[j] + 0 + decomp_offset;
7070
for (int k = 0; k < P::t; k++) {
7171
typename P::domainP::T val =
7272
(tmp >>

src/bootstrap_gpu.cu

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include <include/ntt_gpu/ntt.cuh>
3030
#include <limits>
3131
#include <vector>
32+
#include <algorithm>
3233

3334
namespace cufhe {
3435
template<class P = TFHEpp::lvl1param>
@@ -321,8 +322,8 @@ __device__ inline void __SampleExtractIndex__(typename P::T* const res, const ty
321322
}else {
322323
const uint k = i >> P::nbit;
323324
const uint n = i & nmask;
324-
if (n <= index) res[index] = in[k*P::n + index - n];
325-
else res[index] = -in[k*P::n + P::n + index-n];
325+
if (n <= index) res[i] = in[k*P::n + index - n];
326+
else res[i] = -in[k*P::n + P::n + index-n];
326327
}
327328
}
328329
}
@@ -335,9 +336,12 @@ __device__ inline void __HomGate__(typename brP::targetP::T* const out,
335336
const CuNTTHandler<> ntt)
336337
{
337338
__shared__ typename iksP::targetP::T tlwe[iksP::targetP::k*iksP::targetP::n+1];
338-
__shared__ typename brP::targetP::T trlwe[(brP::targetP::k+1)*brP::targetP::n];
339339

340340
IdentityKeySwitchPreAdd<iksP, casign, cbsign, offset>(tlwe, in0, in1, ksk);
341+
__syncthreads();
342+
343+
__shared__ typename brP::targetP::T trlwe[(brP::targetP::k+1)*brP::targetP::n];
344+
341345
__BlindRotate__<brP>(trlwe, tlwe, μ, bk,ntt);
342346
__SampleExtractIndex__<typename brP::targetP,0>(out,trlwe);
343347
__threadfence();
@@ -601,8 +605,10 @@ template<class P>
601605
__global__ __launch_bounds__(NUM_THREAD4HOMGATE) void __CopyBootstrap__(
602606
typename P::T* const out, const typename P::T* const in)
603607
{
604-
const uint32_t tid = ThisThreadRankInBlock();
605-
out[tid] = in[tid];
608+
const uint tid = ThisThreadRankInBlock();
609+
const uint bdim = ThisBlockSize();
610+
for (int i = tid; i <= P::k*P::n; i += bdim)
611+
out[i] = in[i];
606612
__syncthreads();
607613
__threadfence();
608614
}
@@ -611,8 +617,10 @@ template<class P>
611617
__global__ __launch_bounds__(NUM_THREAD4HOMGATE) void __NotBootstrap__(
612618
typename P::T* const out, const typename P::T* const in)
613619
{
614-
const uint32_t tid = ThisThreadRankInBlock();
615-
out[tid] = -in[tid];
620+
const uint tid = ThisThreadRankInBlock();
621+
const uint bdim = ThisBlockSize();
622+
for (int i = tid; i <= P::k*P::n; i += bdim)
623+
out[i] = -in[i];
616624
__syncthreads();
617625
__threadfence();
618626
}
@@ -627,17 +635,18 @@ __global__ __launch_bounds__(NUM_THREAD4HOMGATE) void __MuxBootstrap__(
627635
__shared__ typename iksP::targetP::T tlwelvl0[iksP::targetP::k*iksP::targetP::n+1];
628636

629637
IdentityKeySwitchPreAdd<iksP, 1, 1, -iksP::domainP::μ>(tlwelvl0, inc, in1, ksk);
630-
__threadfence();
638+
__syncthreads();
631639
__shared__ typename brP::targetP::T tlwe1[(brP::targetP::k+1)*brP::targetP::n];
632640
__BlindRotate__<brP>(tlwe1,tlwelvl0,μ,bk,ntt);
633641
__SampleExtractIndex__<typename brP::targetP,0>(out, tlwe1);
634642

635-
IdentityKeySwitchPreAdd<iksP, 1, 1, -iksP::domainP::μ>(tlwelvl0, inc, in0, ksk);
636-
__threadfence();
643+
IdentityKeySwitchPreAdd<iksP, -1, 1, -iksP::domainP::μ>(tlwelvl0, inc, in0, ksk);
644+
__syncthreads();
637645
__shared__ typename brP::targetP::T tlwe0[(brP::targetP::k+1)*brP::targetP::n];
638646
__BlindRotate__<brP>(tlwe0,tlwelvl0,μ,bk,ntt);
639647
__SampleExtractIndex__<typename brP::targetP,0>(tlwe1, tlwe0);
640-
__threadfence();
648+
649+
__syncthreads();
641650

642651
volatile const uint32_t tid = ThisThreadRankInBlock();
643652
volatile const uint32_t bdim = ThisBlockSize();
@@ -661,18 +670,19 @@ __global__ __launch_bounds__(NUM_THREAD4HOMGATE) void __NMuxBootstrap__(
661670
__shared__ typename iksP::targetP::T tlwelvl0[iksP::targetP::k*iksP::targetP::n+1];
662671

663672
IdentityKeySwitchPreAdd<iksP, 1, 1, -iksP::domainP::μ>(tlwelvl0, inc, in1, ksk);
664-
__threadfence();
673+
__syncthreads();
665674
__shared__ typename brP::targetP::T tlwe1[(brP::targetP::k+1)*brP::targetP::n];
666675
__BlindRotate__<brP>(tlwe1,tlwelvl0,μ,bk,ntt);
667676
__SampleExtractIndex__<typename brP::targetP,0>(out, tlwe1);
668677

669-
IdentityKeySwitchPreAdd<iksP, 1, 1, -iksP::domainP::μ>(tlwelvl0, inc, in0, ksk);
670-
__threadfence();
678+
IdentityKeySwitchPreAdd<iksP, -1, 1, -iksP::domainP::μ>(tlwelvl0, inc, in0, ksk);
679+
__syncthreads();
671680
__shared__ typename brP::targetP::T tlwe0[(brP::targetP::k+1)*brP::targetP::n];
672681
__BlindRotate__<brP>(tlwe0,tlwelvl0,μ,bk,ntt);
673682
__SampleExtractIndex__<typename brP::targetP,0>(tlwe1, tlwe0);
674683

675-
__threadfence();
684+
__syncthreads();
685+
676686

677687
volatile const uint32_t tid = ThisThreadRankInBlock();
678688
volatile const uint32_t bdim = ThisBlockSize();
@@ -1090,7 +1100,7 @@ template<class P>
10901100
void CopyBootstrap(typename P::T* const out, const typename P::T* const in,
10911101
const cudaStream_t st, const int gpuNum)
10921102
{
1093-
__CopyBootstrap__<P><<<1, P::n + 1, 0, st>>>(out, in);
1103+
__CopyBootstrap__<P><<<1, std::min(P::n + 1,NUM_THREAD4HOMGATE), 0, st>>>(out, in);
10941104
CuCheckError();
10951105
}
10961106
#define INST(P) \
@@ -1104,7 +1114,7 @@ template<class P>
11041114
void NotBootstrap(typename P::T* const out, const typename P::T* const in,
11051115
const cudaStream_t st, const int gpuNum)
11061116
{
1107-
__NotBootstrap__<P><<<1, P::n + 1, 0, st>>>(out, in);
1117+
__NotBootstrap__<P><<<1, std::min(P::n + 1,NUM_THREAD4HOMGATE), 0, st>>>(out, in);
11081118
CuCheckError();
11091119
}
11101120
#define INST(P) \

0 commit comments

Comments
 (0)