Skip to content

Commit cea6817

Browse files
committed
bump TFHEpp to latest and fix keyswitching algorithm
1 parent 4fca695 commit cea6817

File tree

4 files changed

+61
-40
lines changed

4 files changed

+61
-40
lines changed

CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
cmake_minimum_required(VERSION 3.18)
2-
project(cuFHE LANGUAGES CUDA CXX)
2+
project(cuFHE LANGUAGES CUDA CXX C)
33

44
set(CMAKE_CXX_STANDARD 20)
55
find_package(CUDAToolkit REQUIRED)
66
set(CMAKE_CUDA_FLAGS "--ptxas-options=-v")
77

8-
option(USE_RANDEN "Use randen as CSPRNG" ON)
8+
option(USE_RANDEN "Use randen as CSPRNG" OFF)
99
option(USE_80BIT_SECURITY "Use 80bit security parameter(faster)" OFF)
1010
option(USE_CGGI19 "Use the parameter set proposed in CGGI19" OFF)
1111
option(USE_CONCRETE "Use the parameter set proposed in CONCRETE" OFF)

include/cufhe_gpu.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ struct Ctxt {
125125
Ctxt(const Ctxt& that) = delete;
126126
Ctxt& operator=(const Ctxt& that) = delete;
127127

128-
TFHEpp::TLWE<P> tlwehost;
128+
alignas(64) TFHEpp::TLWE<P> tlwehost;
129129

130130
std::vector<typename P::T*> tlwedevices;
131131
};

include/keyswitch_gpu.cuh

Lines changed: 57 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -9,24 +9,40 @@ namespace cufhe{
99

1010
extern std::vector<TFHEpp::lvl0param::T*> ksk_devs;
1111

12+
13+
template <class P>
14+
__device__ constexpr typename P::domainP::T iksoffsetgen()
15+
{
16+
typename P::domainP::T offset = 0;
17+
for (int i = 1; i <= P::t; i++)
18+
offset +=
19+
(1ULL << P::basebit) / 2 *
20+
(1ULL << (std::numeric_limits<typename P::domainP::T>::digits -
21+
i * P::basebit));
22+
return offset;
23+
}
24+
1225
template <class P>
1326
__device__ inline void KeySwitch(typename P::targetP::T* const lwe,
1427
const typename P::domainP::T* const tlwe,
1528
const typename P::targetP::T* const ksk)
1629
{
17-
constexpr typename P::domainP::T decomp_mask = (1U << P::basebit) - 1;
18-
constexpr typename P::domainP::T decomp_offset =
19-
1U << (std::numeric_limits<typename P::domainP::T>::digits - 1 -
20-
P::t * P::basebit);
30+
constexpr uint domain_digit =
31+
std::numeric_limits<typename P::domainP::T>::digits;
32+
constexpr uint target_digit =
33+
std::numeric_limits<typename P::targetP::T>::digits;
34+
constexpr typename P::domainP::T roundoffset =
35+
(P::basebit * P::t) < domain_digit
36+
? 1ULL << (domain_digit - (1 + P::basebit * P::t))
37+
: 0;
38+
constexpr typename P::domainP::T decompoffset = iksoffsetgen<P>();
39+
constexpr typename P::domainP::T mask = (1ULL << P::basebit) - 1;
40+
constexpr typename P::domainP::T halfbase = 1ULL << (P::basebit - 1);
2141
const uint32_t tid = ThisThreadRankInBlock();
2242
const uint32_t bdim = ThisBlockSize();
2343
for (int i = tid; i <= P::targetP::k*P::targetP::n; i += bdim) {
2444
typename P::targetP::T res = 0;
2545
if (i == P::targetP::k*P::targetP::n){
26-
constexpr uint domain_digit =
27-
std::numeric_limits<typename P::domainP::T>::digits;
28-
constexpr uint target_digit =
29-
std::numeric_limits<typename P::targetP::T>::digits;
3046
if constexpr (domain_digit == target_digit)
3147
res = tlwe[P::domainP::k * P::domainP::n];
3248
else if constexpr (domain_digit > target_digit)
@@ -40,20 +56,20 @@ __device__ inline void KeySwitch(typename P::targetP::T* const lwe,
4056
tmp = tlwe[0];
4157
else
4258
tmp = -tlwe[P::domainP::k*P::domainP::n - j];
43-
tmp += decomp_offset;
59+
tmp += decompoffset + roundoffset;
4460
for (int k = 0; k < P::t; k++) {
45-
typename P::domainP::T val =
46-
(tmp >>
61+
const int32_t val =
62+
((tmp >>
4763
(std::numeric_limits<typename P::domainP::T>::digits -
4864
(k + 1) * P::basebit)) &
49-
decomp_mask;
50-
if (val != 0) {
51-
constexpr int numbase = (1 << P::basebit) - 1;
52-
res -= ksk[j * (P::t * numbase *
65+
mask) - halfbase;
66+
constexpr int numbase = 1 << (P::basebit-1);
67+
const typename P::targetP::T kskelem = ksk[j * (P::t * numbase *
5368
(P::targetP::k*P::targetP::n + 1)) +
5469
k * (numbase * (P::targetP::k*P::targetP::n + 1)) +
55-
(val - 1) * (P::targetP::k*P::targetP::n + 1) + i];
56-
}
70+
(abs(val) - 1) * (P::targetP::k*P::targetP::n + 1) + i];
71+
if (val > 0) res -= kskelem;
72+
else if (val < 0) res += kskelem;
5773
}
5874
}
5975
lwe[i] = res;
@@ -66,19 +82,22 @@ __device__ inline void IdentityKeySwitchPreAdd(typename P::targetP::T* const lwe
6682
const typename P::domainP::T* const inb,
6783
const typename P::targetP::T* const ksk)
6884
{
69-
constexpr typename P::domainP::T decomp_mask = (1U << P::basebit) - 1;
70-
constexpr typename P::domainP::T decomp_offset =
71-
1U << (std::numeric_limits<typename P::domainP::T>::digits - 1 -
72-
P::t * P::basebit);
85+
constexpr uint domain_digit =
86+
std::numeric_limits<typename P::domainP::T>::digits;
87+
constexpr uint target_digit =
88+
std::numeric_limits<typename P::targetP::T>::digits;
89+
constexpr typename P::domainP::T roundoffset =
90+
(P::basebit * P::t) < domain_digit
91+
? 1ULL << (domain_digit - (1 + P::basebit * P::t))
92+
: 0;
93+
constexpr typename P::domainP::T decompoffset = iksoffsetgen<P>();
94+
constexpr typename P::domainP::T mask = (1ULL << P::basebit) - 1;
95+
constexpr typename P::domainP::T halfbase = 1ULL << (P::basebit - 1);
7396
const uint32_t tid = ThisThreadRankInBlock();
7497
const uint32_t bdim = ThisBlockSize();
7598
for (int i = tid; i <= P::targetP::k*P::targetP::n; i += bdim) {
7699
typename P::targetP::T res = 0;
77100
if (i == P::targetP::k*P::targetP::n){
78-
constexpr uint domain_digit =
79-
std::numeric_limits<typename P::domainP::T>::digits;
80-
constexpr uint target_digit =
81-
std::numeric_limits<typename P::targetP::T>::digits;
82101
const typename P::domainP::T added = casign*ina[P::domainP::k*P::domainP::n]+ cbsign*inb[P::domainP::k*P::domainP::n] + offset;
83102
if constexpr (domain_digit == target_digit)
84103
res = added;
@@ -89,20 +108,22 @@ __device__ inline void IdentityKeySwitchPreAdd(typename P::targetP::T* const lwe
89108
}
90109
for (int j = 0; j < P::domainP::k*P::domainP::n; j++) {
91110
typename P::domainP::T tmp;
92-
tmp = casign*ina[j]+ cbsign*inb[j] + 0 + decomp_offset;
111+
tmp = casign*ina[j] + cbsign*inb[j] + 0 + decompoffset + roundoffset;
93112
for (int k = 0; k < P::t; k++) {
94-
typename P::domainP::T val =
95-
(tmp >>
113+
const int32_t val =
114+
((tmp >>
96115
(std::numeric_limits<typename P::domainP::T>::digits -
97116
(k + 1) * P::basebit)) &
98-
decomp_mask;
99-
if (val != 0) {
100-
constexpr int numbase = (1 << P::basebit) - 1;
101-
res -= ksk[j * (P::t * numbase *
102-
(P::targetP::k*P::targetP::n + 1)) +
103-
k * (numbase * (P::targetP::k*P::targetP::n + 1)) +
104-
(val - 1) * (P::targetP::k*P::targetP::n + 1) + i];
105-
}
117+
mask) - halfbase;
118+
constexpr int numbase = 1 << (P::basebit-1);
119+
if (val > 0) res -= ksk[j * (P::t * numbase *
120+
(P::targetP::k*P::targetP::n + 1)) +
121+
k * (numbase * (P::targetP::k*P::targetP::n + 1)) +
122+
(val - 1) * (P::targetP::k*P::targetP::n + 1) + i];
123+
else if (val < 0) res += ksk[j * (P::t * numbase *
124+
(P::targetP::k*P::targetP::n + 1)) +
125+
k * (numbase * (P::targetP::k*P::targetP::n + 1)) +
126+
(-val - 1) * (P::targetP::k*P::targetP::n + 1) + i];
106127
}
107128
}
108129
lwe[i] = res;

thirdparties/TFHEpp

Submodule TFHEpp updated 85 files

0 commit comments

Comments
 (0)