Skip to content

Commit e31946f

Browse files
authored
[flashinfer] fix FI all2all with FI cutlass moe (#28166)
Signed-off-by: Xiaozhu <[email protected]>
1 parent bde5039 commit e31946f

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -233,12 +233,13 @@ def flashinfer_alltoall_dispatch(
233233
max_num_token = (
234234
max(global_num_tokens_cpu) if global_num_tokens_cpu is not None else x.shape[0]
235235
)
236+
orig_topk_weights_dtype = topk_weights.dtype
236237
alltoall_info, topk_ids, topk_weights, _ = (
237238
MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather(
238239
topk_ids,
239240
topk_weights,
240241
None,
241-
all2all_manager.prepare_workspace,
242+
all2all_manager.prepare_workspace_tensor,
242243
max_num_token,
243244
ep_rank,
244245
ep_size,
@@ -247,6 +248,7 @@ def flashinfer_alltoall_dispatch(
247248
top_k,
248249
)
249250
)
251+
topk_weights = topk_weights.view(dtype=orig_topk_weights_dtype)
250252

251253
x, x_sf = moe_kernel_quantize_input(
252254
x,

0 commit comments

Comments
 (0)