@@ -221,13 +221,12 @@ uint32_t rocmFmhaWrapper::runCKFmha(void* q,
221221 // nullptr, // o_acc_buf.GetDeviceBuffer(),
222222 softmax_lse_,
223223 output,
224+ seqstart_q, // seqstart_q_ptr
225+ seqstart_k, // seqstart_k_ptr
226+ nullptr , // seqlen_q_ptr
227+ nullptr , // seqlen_k_ptr
224228 nullptr , // cu_seqlen_q_ptr
225- nullptr , // cu_seqlen_kv_ptr
226- seqstart_q,
227- seqstart_k,
228- nullptr , // seqlen_kpads
229- nullptr , // seqstart_padded_q_ptr
230- nullptr , // seqstart_padded_k_ptr
229+ nullptr , // cu_seqlen_k_ptr
231230 shape_seqlen_q,
232231 shape_seqlen_k,
233232 batch,
@@ -489,13 +488,12 @@ uint32_t rocmFmhaWrapper::runCKFmhaV2(void* q,
489488 // nullptr, // o_acc_buf.GetDeviceBuffer(),
490489 softmax_lse_,
491490 output,
491+ seqstart_q, // seqstart_q_ptr
492+ seqstart_k, // seqstart_k_ptr
493+ nullptr , // seqlen_q_ptr
494+ nullptr , // seqlen_k_ptr
492495 nullptr , // cu_seqlen_q_ptr
493- nullptr , // cu_seqlen_kv_ptr
494- seqstart_q,
495- seqstart_k,
496- nullptr , // seqlen_kpads
497- nullptr , // seqstart_padded_q_ptr
498- nullptr , // seqstart_padded_k_ptr
496+ nullptr , // cu_seqlen_k_ptr
499497 shape_seqlen_q,
500498 shape_seqlen_k,
501499 batch,
@@ -759,13 +757,12 @@ uint32_t rocmFmhaWrapper::runCKFmhaMLA(void* q,
759757 // nullptr, // o_acc_buf.GetDeviceBuffer(),
760758 softmax_lse_,
761759 output,
760+ seqstart_q, // seqstart_q_ptr
761+ seqstart_k, // seqstart_k_ptr
762+ nullptr , // seqlen_q_ptr
763+ nullptr , // seqlen_k_ptr
762764 nullptr , // cu_seqlen_q_ptr
763- nullptr , // cu_seqlen_kv_ptr
764- seqstart_q,
765- seqstart_k,
766- nullptr , // seqlen_kpads
767- nullptr , // seqstart_padded_q_ptr
768- nullptr , // seqstart_padded_k_ptr
765+ nullptr , // cu_seqlen_k_ptr
769766 shape_seqlen_q,
770767 shape_seqlen_k,
771768 batch,
0 commit comments