Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -209,13 +209,14 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
FetchContent_Declare(
cutlass
GIT_REPOSITORY https://github.com/nvidia/cutlass.git
GIT_TAG v3.5.1
# GIT_TAG v3.5.1
GIT_TAG dbdae514e03f83968f8b7dd4fb064071b9bfbdd1
GIT_PROGRESS TRUE

# Speed up the CUTLASS download by fetching only the specified GIT_TAG instead of the full history.
# Important: when GIT_SHALLOW is enabled, GIT_TAG only works with branch names and tags.
# If the GIT_TAG above is changed to a commit hash, GIT_SHALLOW must be set to FALSE.
GIT_SHALLOW TRUE
GIT_SHALLOW FALSE
)
FetchContent_MakeAvailable(cutlass)

Expand Down Expand Up @@ -261,7 +262,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# CUDA 12.0 or later (and only work on Hopper, 9.0/9.0a for now).
cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0;9.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu")
set(SRCS
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
"csrc/quantization/cutlass_w8a8/grouped_gemm_c3x.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_3X_ARCHS}")
Expand Down
8 changes: 8 additions & 0 deletions csrc/cpu/torch_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor b, Tensor a_scales,"
" Tensor b_scales, Tensor? bias) -> ()");
ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm);
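// CUTLASS w8a8 grouped GEMM. TODO: complete this.
// NOTE(review): this hunk is in csrc/cpu/torch_bindings.cpp, yet the impl
// below is registered for torch::kCUDA; confirm this is intended.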
ops.def(
"cutlass_grouped_mm(Tensor! out, Tensor a, Tensor b, Tensor a_scales, "
" Tensor b_scales, Tensor problem_sizes, "
" Tensor out_offsets, Tensor a_offsets, "
" Tensor b_offsets, Tensor a_scales_offsets, "
" Tensor b_scales_offsets) -> ()");
ops.impl("cutlass_grouped_mm", torch::kCUDA, &cutlass_grouped_mm);
// w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
// quantization.
ops.def(
Expand Down
Loading