Skip to content

Commit 5dfddfd

Browse files
Use int64_t instead of int32_t in decode attention kernel to avoid index calc overflow (#5095)
Summary: Pull Request resolved: #5095 X-link: https://github.com/facebookresearch/FBGEMM/pull/2101 For KV cache sizes larger than 4 GB, int32_t is too narrow for the index calculation and overflows, so the affected strides now use int64_t. Reviewed By: Aya-ZIbra Differential Revision: D85631787 fbshipit-source-id: a9695be5a6d79b9abd3a349f780878d766ac2ceb
1 parent 61def1a commit 5dfddfd

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/77_blackwell_fmha_gen.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,7 @@ struct ExampleRunner {
335335

336336
using StrideQ = Stride<_0, _1, Stride<Stride<int, int>, int>>;
337337
using StrideNewK = Stride<_0, _1, Stride<Stride<_0, int>, int>>;
338-
using StrideCacheK = Stride<int, _1, Stride<Stride<_0, int>, int>>;
338+
using StrideCacheK = Stride<int, _1, Stride<Stride<_0, int64_t>, int64_t>>;
339339
using StrideNewV = StrideNewK;
340340
using StrideCacheV = StrideCacheK;
341341
using StrideO = StrideQ;
@@ -474,7 +474,7 @@ struct ExampleRunner {
474474

475475
stride_q = make_stride(_0{}, _1{}, make_stride(make_stride(options.d, options.d * size<3,0,0>(result)), options.d * size<3,0>(result)));
476476
stride_new_k = make_stride(_0{}, _1{}, make_stride(make_stride(_0{}, options.d), options.d * size<3,0,1>(result)));
477-
stride_cache_k = make_stride(options.d * size<3,0,1>(result), _1{}, make_stride(make_stride(_0{}, options.d), options.d * size<3,0,1>(result) * get<1>(result)));
477+
stride_cache_k = make_stride(options.d * size<3,0,1>(result), _1{}, make_stride(make_stride(_0{}, static_cast<int64_t>(options.d)), static_cast<int64_t>(options.d) * size<3,0,1>(result) * get<1>(result)));
478478

479479
stride_new_v = stride_new_k;
480480
stride_cache_v = stride_cache_k;

fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_gen_impl.cu

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,9 @@ struct GenRunner {
9393

9494
using StrideQ =
9595
Stride<_0, _1, Stride<Stride<int, int>, int>>; // Q D ((H, Hr), B)
96-
using StrideNewK = Stride<_0, _1, Stride<Stride<_0, int>, int>>;
96+
using StrideNewK = Stride<_0, _1, Stride<Stride<_0, int64_t>, int64_t>>;
9797
using StrideCacheK =
98-
Stride<int, _1, Stride<Stride<_0, int>, int>>; // K D ((H, Hr), B)
98+
Stride<int, _1, Stride<Stride<_0, int64_t>, int64_t>>; // K D ((H, Hr), B)
9999
using StrideNewV = StrideNewK;
100100
using StrideCacheV = StrideCacheK;
101101
using StrideO = StrideQ;
@@ -198,13 +198,13 @@ struct GenRunner {
198198
stride_new_k = make_stride(
199199
_0{},
200200
_1{},
201-
make_stride(make_stride(_0{}, options.d), options.d * options.h_k));
201+
make_stride(make_stride(_0{}, static_cast<int64_t>(options.d)), static_cast<int64_t>(options.d) * options.h_k)); // widen before multiplying so the product is formed in 64-bit
202202
stride_cache_k = make_stride(
203203
options.d * options.h_k,
204204
_1{},
205205
make_stride(
206-
make_stride(_0{}, options.d),
207-
options.d * options.h_k * options.sk));
206+
make_stride(_0{}, static_cast<int64_t>(options.d)),
207+
static_cast<int64_t>(options.d) * options.h_k * options.sk)); // widen the first operand so the whole product is formed in 64-bit; casting the finished product would be too late to prevent overflow
208208

209209
stride_new_v = stride_new_k;
210210
stride_cache_v = stride_cache_k;

0 commit comments

Comments
 (0)