2727
2828template <int kNThreads_ , int kNItems_ , int kNRows_ , bool kIsEvenLen_ ,
2929 bool kIsVariableB_ , bool kIsVariableC_ ,
30- bool kHasZ_ , bool kVarlen_ , typename input_t_, typename weight_t_>
30+ bool kHasZ_ , bool kVarlen_ , typename input_t_, typename weight_t_, typename state_t_ >
3131struct Selective_Scan_fwd_kernel_traits {
3232 static_assert (kNItems_ % 4 == 0 );
3333 using input_t = input_t_;
3434 using weight_t = weight_t_;
35+ using state_t = state_t_;
3536 static constexpr int kNThreads = kNThreads_ ;
3637 // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy.
3738 static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3 ;
@@ -132,7 +133,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
132133 input_t *Bvar = reinterpret_cast <input_t *>(params.B_ptr ) + sequence_start_index * params.B_batch_stride + group_id * params.B_group_stride ;
133134 weight_t *C = reinterpret_cast <weight_t *>(params.C_ptr ) + dim_id * kNRows * params.C_d_stride ;
134135 input_t *Cvar = reinterpret_cast <input_t *>(params.C_ptr ) + sequence_start_index * params.C_batch_stride + group_id * params.C_group_stride ;
135- input_t *ssm_states = reinterpret_cast <input_t *>(params.ssm_states_ptr ) +
136+ typename Ktraits:: state_t *ssm_states = reinterpret_cast <typename Ktraits:: state_t *>(params.ssm_states_ptr ) +
136137 cache_index * params.ssm_states_batch_stride +
137138 dim_id * kNRows * params.ssm_states_dim_stride ;
138139
@@ -261,7 +262,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
261262 if (threadIdx .x == 0 ) {
262263 smem_running_prefix[state_idx] = prefix_op.running_prefix ;
263264 if (chunk == n_chunks - 1 ) {
264- ssm_states[state_idx * params.ssm_states_dstate_stride ] = input_t (prefix_op.running_prefix .y );
265+ ssm_states[state_idx * params.ssm_states_dstate_stride ] = typename Ktraits::state_t (prefix_op.running_prefix .y );
265266 }
266267 }
267268 #pragma unroll
@@ -310,7 +311,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
310311 }
311312}
312313
313- template <int kNThreads , int kNItems , typename input_t , typename weight_t >
314+ template <int kNThreads , int kNItems , typename input_t , typename weight_t , typename state_t >
314315void selective_scan_fwd_launch (SSMParamsBase ¶ms, cudaStream_t stream) {
315316 // Only kNRows == 1 is tested for now, which ofc doesn't differ from previously when we had each block
316317 // processing 1 row.
@@ -321,7 +322,7 @@ void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) {
321322 BOOL_SWITCH (params.seqlen % (kNThreads * kNItems ) == 0 , kIsEvenLen , [&] {
322323 BOOL_SWITCH (params.z_ptr != nullptr , kHasZ , [&] {
323324 BOOL_SWITCH (params.query_start_loc_ptr != nullptr , kVarlen , [&] {
324- using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads , kNItems , kNRows , kIsEvenLen , kIsVariableB , kIsVariableC , kHasZ , kVarlen , input_t , weight_t >;
325+ using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads , kNItems , kNRows , kIsEvenLen , kIsVariableB , kIsVariableC , kHasZ , kVarlen , input_t , weight_t , state_t >;
325326 constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof (typename Ktraits::scan_t );
326327 dim3 grid (params.batch , params.dim / kNRows );
327328 auto kernel = &selective_scan_fwd_kernel<Ktraits>;
@@ -341,59 +342,78 @@ void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) {
341342 });
342343}
343344
344- template <typename input_t , typename weight_t >
345+ template <typename input_t , typename weight_t , typename state_t >
345346void selective_scan_fwd_cuda (SSMParamsBase ¶ms, cudaStream_t stream) {
346347
347348 #ifndef USE_ROCM
348349 if (params.seqlen <= 128 ) {
349- selective_scan_fwd_launch<32 , 4 , input_t , weight_t >(params, stream);
350+ selective_scan_fwd_launch<32 , 4 , input_t , weight_t , state_t >(params, stream);
350351 } else if (params.seqlen <= 256 ) {
351- selective_scan_fwd_launch<32 , 8 , input_t , weight_t >(params, stream);
352+ selective_scan_fwd_launch<32 , 8 , input_t , weight_t , state_t >(params, stream);
352353 } else if (params.seqlen <= 512 ) {
353- selective_scan_fwd_launch<32 , 16 , input_t , weight_t >(params, stream);
354+ selective_scan_fwd_launch<32 , 16 , input_t , weight_t , state_t >(params, stream);
354355 } else if (params.seqlen <= 1024 ) {
355- selective_scan_fwd_launch<64 , 16 , input_t , weight_t >(params, stream);
356+ selective_scan_fwd_launch<64 , 16 , input_t , weight_t , state_t >(params, stream);
356357 } else {
357- selective_scan_fwd_launch<128 , 16 , input_t , weight_t >(params, stream);
358+ selective_scan_fwd_launch<128 , 16 , input_t , weight_t , state_t >(params, stream);
358359 }
359360 #else
360361 if (params.seqlen <= 256 ) {
361- selective_scan_fwd_launch<64 , 4 , input_t , weight_t >(params, stream);
362+ selective_scan_fwd_launch<64 , 4 , input_t , weight_t , state_t >(params, stream);
362363 } else if (params.seqlen <= 512 ) {
363- selective_scan_fwd_launch<64 , 8 , input_t , weight_t >(params, stream);
364+ selective_scan_fwd_launch<64 , 8 , input_t , weight_t , state_t >(params, stream);
364365 } else if (params.seqlen <= 1024 ) {
365- selective_scan_fwd_launch<64 , 16 , input_t , weight_t >(params, stream);
366+ selective_scan_fwd_launch<64 , 16 , input_t , weight_t , state_t >(params, stream);
366367 } else {
367- selective_scan_fwd_launch<128 , 16 , input_t , weight_t >(params, stream);
368+ selective_scan_fwd_launch<128 , 16 , input_t , weight_t , state_t >(params, stream);
368369 }
369370 #endif
370371}
371372
372- template void selective_scan_fwd_cuda<at::BFloat16, float >(SSMParamsBase ¶ms, cudaStream_t stream);
373- template void selective_scan_fwd_cuda<at::Half, float >(SSMParamsBase ¶ms, cudaStream_t stream);
374- template void selective_scan_fwd_cuda<float , float >(SSMParamsBase ¶ms, cudaStream_t stream);
373+ template void selective_scan_fwd_cuda<at::BFloat16, float , at::BFloat16>(SSMParamsBase ¶ms, cudaStream_t stream);
374+ template void selective_scan_fwd_cuda<at::BFloat16, float , float >(SSMParamsBase ¶ms, cudaStream_t stream);
375+ template void selective_scan_fwd_cuda<at::Half, float , at::Half>(SSMParamsBase ¶ms, cudaStream_t stream);
376+ template void selective_scan_fwd_cuda<at::Half, float , float >(SSMParamsBase ¶ms, cudaStream_t stream);
377+ template void selective_scan_fwd_cuda<float , float , float >(SSMParamsBase ¶ms, cudaStream_t stream);
375378
376379#define CHECK_SHAPE (x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ " )" )
377380
378- #define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16 (ITYPE, NAME, ...) \
381+ #define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16 (ITYPE, STYPE, NAME, ...) \
379382 if (ITYPE == at::ScalarType::Half) { \
380383 using input_t = at::Half; \
381384 using weight_t = float ; \
382- __VA_ARGS__ (); \
385+ if (STYPE == at::ScalarType::Half) { \
386+ using state_t = at::Half; \
387+ __VA_ARGS__ (); \
388+ } else if (STYPE == at::ScalarType::Float) { \
389+ using state_t = float ; \
390+ __VA_ARGS__ (); \
391+ } else { \
392+ AT_ERROR (#NAME, " not implemented for state type '" , toString (STYPE), " '" ); \
393+ } \
383394 } else if (ITYPE == at::ScalarType::BFloat16) { \
384395 using input_t = at::BFloat16; \
385396 using weight_t = float ; \
386- __VA_ARGS__ (); \
397+ if (STYPE == at::ScalarType::BFloat16) { \
398+ using state_t = at::BFloat16; \
399+ __VA_ARGS__ (); \
400+ } else if (STYPE == at::ScalarType::Float) { \
401+ using state_t = float ; \
402+ __VA_ARGS__ (); \
403+ } else { \
404+ AT_ERROR (#NAME, " not implemented for state type '" , toString (STYPE), " '" ); \
405+ } \
387406 } else if (ITYPE == at::ScalarType::Float) { \
388407 using input_t = float ; \
389408 using weight_t = float ; \
409+ using state_t = float ; \
390410 __VA_ARGS__ (); \
391411 } else { \
392412 AT_ERROR (#NAME, " not implemented for input type '" , toString (ITYPE), " '" ); \
393413 }
394414
395415
396- template <typename input_t , typename weight_t >
416+ template <typename input_t , typename weight_t , typename state_t >
397417void selective_scan_fwd_cuda (SSMParamsBase ¶ms, cudaStream_t stream);
398418
399419void set_ssm_params_fwd (SSMParamsBase ¶ms,
@@ -648,7 +668,9 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
648668
649669 // Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
650670 at::Tensor out = delta;
651- TORCH_CHECK (ssm_states.scalar_type () == input_type);
671+ // ssm_states can now be either the same as input_type or float32
672+ auto state_type = ssm_states.scalar_type ();
673+ TORCH_CHECK (state_type == input_type || state_type == at::ScalarType::Float);
652674 TORCH_CHECK (ssm_states.is_cuda ());
653675 TORCH_CHECK (ssm_states.stride (-1 ) == 1 );
654676
@@ -670,7 +692,7 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
670692
671693 const at::cuda::OptionalCUDAGuard device_guard (device_of (u));
672694 auto stream = at::cuda::getCurrentCUDAStream ().stream ();
673- DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16 (u.scalar_type (), " selective_scan_fwd" , [&] {
674- selective_scan_fwd_cuda<input_t , weight_t >(params, stream);
695+ DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16 (u.scalar_type (), ssm_states. scalar_type (), " selective_scan_fwd" , [&] {
696+ selective_scan_fwd_cuda<input_t , weight_t , state_t >(params, stream);
675697 });
676698}
0 commit comments