Skip to content

Commit 6515610

Browse files
CUDA: fix should_use_mmvf for ne11 == 1 (#17085)
* CUDA: fix should_use_mmvf for ne11 == 1 * Apply suggestion from @am17an Co-authored-by: Aman Gupta <[email protected]> --------- Co-authored-by: Aman Gupta <[email protected]>
1 parent 7956bb4 commit 6515610

File tree

2 files changed

+15
-2
lines changed

2 files changed

+15
-2
lines changed

ggml/src/ggml-cuda/mmf.cu

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,13 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
129129
if (src0_ne[0] % (warp_size * (4/ts)) != 0) {
130130
return false;
131131
}
132-
for (size_t i = 0; i < GGML_MAX_DIMS; ++i) {
132+
133+
if (src0_nb[0] != ts) {
134+
return false;
135+
}
136+
137+
// Pointers not aligned to the size of half2/nv_bfloat162/float2 would result in a crash:
138+
for (size_t i = 1; i < GGML_MAX_DIMS; ++i) {
133139
if (src0_nb[i] % (2*ts) != 0) {
134140
return false;
135141
}

ggml/src/ggml-cuda/mmvf.cu

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -720,12 +720,19 @@ bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0
720720
if (src0_ne[0] % 2 != 0) {
721721
return false;
722722
}
723+
723724
const size_t ts = ggml_type_size(type);
724-
for (size_t i = 0; i < GGML_MAX_DIMS; ++i) {
725+
if (src0_nb[0] != ts) {
726+
return false;
727+
}
728+
729+
// Pointers not aligned to the size of half2/nv_bfloat162/float2 would result in a crash:
730+
for (size_t i = 1; i < GGML_MAX_DIMS; ++i) {
725731
if (src0_nb[i] % (2*ts) != 0) {
726732
return false;
727733
}
728734
}
735+
729736
switch (type) {
730737
case GGML_TYPE_F32:
731738
if (GGML_CUDA_CC_IS_NVIDIA(cc)) {

0 commit comments

Comments
 (0)