 #include "dispatch_utils.h"

 #include "cuda_utils.h"
+#include "nvfp4_utils.cuh"

 namespace vllm {

-// Get type2 from type or vice versa (applied to half and bfloat16)
-template <typename T>
-struct TypeConverter {
-  using Type = half2;
-};  // keep for generality
-
-template <>
-struct TypeConverter<half2> {
-  using Type = c10::Half;
-};
-
-template <>
-struct TypeConverter<c10::Half> {
-  using Type = half2;
-};
-
-template <>
-struct TypeConverter<__nv_bfloat162> {
-  using Type = c10::BFloat16;
-};
-
-template <>
-struct TypeConverter<c10::BFloat16> {
-  using Type = __nv_bfloat162;
-};
-
-#define ELTS_PER_THREAD 8
-
-constexpr int CVT_FP4_ELTS_PER_THREAD = 8;
-constexpr int CVT_FP4_SF_VEC_SIZE = 16;
-
-// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
-inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) {
-#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
-  uint32_t val;
-  asm volatile(
-      "{\n"
-      ".reg .b8 byte0;\n"
-      ".reg .b8 byte1;\n"
-      ".reg .b8 byte2;\n"
-      ".reg .b8 byte3;\n"
-      "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
-      "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
-      "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
-      "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
-      "mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
-      "}"
-      : "=r"(val)
-      : "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]),
-        "f"(array[4]), "f"(array[5]), "f"(array[6]), "f"(array[7]));
-  return val;
-#else
-  return 0;
-#endif
-}
-
-// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t).
-inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) {
-#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
-  uint32_t val;
-  asm volatile(
-      "{\n"
-      ".reg .b8 byte0;\n"
-      ".reg .b8 byte1;\n"
-      ".reg .b8 byte2;\n"
-      ".reg .b8 byte3;\n"
-      "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
-      "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
-      "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
-      "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
-      "mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
-      "}"
-      : "=r"(val)
-      : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y),
-        "f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y));
-  return val;
-#else
-  return 0;
-#endif
-}
-
-// Fast reciprocal.
-inline __device__ float reciprocal_approximate_ftz(float a) {
-  float b;
-  asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a));
-  return b;
-}
-
-template <class SFType, int CVT_FP4_NUM_THREADS_PER_SF>
-__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx,
-                                                       int numCols,
-                                                       SFType* SFout) {
-#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
-  static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 ||
-                CVT_FP4_NUM_THREADS_PER_SF == 2);
-
-  // One pair of threads write one SF to global memory.
-  // TODO: stage through smem for packed STG.32
-  // is it better than STG.8 from 4 threads?
-  if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) {
-    // SF vector index (16 elements share one SF in the K dimension).
-    int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF;
-    int32_t mIdx = rowIdx;
-
-    // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4 (kTile)]
-    // --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]
-
-    int32_t mTileIdx = mIdx / (32 * 4);
-    // SF vector size 16.
-    int factor = CVT_FP4_SF_VEC_SIZE * 4;
-    int32_t numKTiles = (numCols + factor - 1) / factor;
-    int64_t mTileStride = numKTiles * 32 * 4 * 4;
-
-    int32_t kTileIdx = (kIdx / 4);
-    int64_t kTileStride = 32 * 4 * 4;
-
-    // M tile layout [32, 4] is column-major.
-    int32_t outerMIdx = (mIdx % 32);
-    int64_t outerMStride = 4 * 4;
-
-    int32_t innerMIdx = (mIdx % (32 * 4)) / 32;
-    int64_t innerMStride = 4;
-
-    int32_t innerKIdx = (kIdx % 4);
-    int64_t innerKStride = 1;
-
-    // Compute the global offset.
-    int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride +
-                       outerMIdx * outerMStride + innerMIdx * innerMStride +
-                       innerKIdx * innerKStride;
-
-    return reinterpret_cast<uint8_t*>(SFout) + SFOffset;
-  }
-#endif
-  return nullptr;
-}
-
-// Define a 16 bytes packed data type.
-template <class Type>
-struct PackedVec {
-  typename TypeConverter<Type>::Type elts[4];
-};
-
-template <>
-struct PackedVec<__nv_fp8_e4m3> {
-  __nv_fp8x2_e4m3 elts[8];
-};
-
 template <class Type>
 __inline__ __device__ PackedVec<Type> compute_silu(PackedVec<Type>& vec,
                                                    PackedVec<Type>& vec2) {
   PackedVec<Type> result;
 #pragma unroll
   for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; ++i) {
-    if constexpr (std::is_same_v<Type, c10::Half>) {
+    if constexpr (std::is_same_v<Type, half>) {
       half2 val(0.5f, 0.5f);
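+      // 0.5f * tanh(0.5f * x) + 0.5f == sigmoid(x), so t1 below holds
+      // sigmoid(vec.elts[i]) computed entirely in half2 arithmetic.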
       half2 t0 = __hmul2(vec.elts[i], val);
       half2 t1 = __hfma2(h2tanh(t0), val, val);
@@ -206,13 +59,12 @@ __device__ uint32_t silu_and_cvt_warp_fp16_to_fp4(PackedVec<Type>& vec,
                                                    PackedVec<Type>& vec2,
                                                    float SFScaleVal,
                                                    uint8_t* SFout) {
-#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
   PackedVec<Type> out_silu = compute_silu(vec, vec2);
   // Get absolute maximum values among the local 8 values.
   auto localMax = __habs2(out_silu.elts[0]);

-  // Local maximum value.
-#pragma unroll
+  // Local maximum value.
+#pragma unroll
   for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
     localMax = __hmax2(localMax, __habs2(out_silu.elts[i]));
   }
@@ -259,9 +111,9 @@ __device__ uint32_t silu_and_cvt_warp_fp16_to_fp4(PackedVec<Type>& vec,
   // Convert the input to float.
   float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2];

-#pragma unroll
+#pragma unroll
   for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
-    if constexpr (std::is_same_v<Type, c10::Half>) {
+    if constexpr (std::is_same_v<Type, half>) {
       fp2Vals[i] = __half22float2(out_silu.elts[i]);
     } else {
       fp2Vals[i] = __bfloat1622float2(out_silu.elts[i]);
@@ -275,22 +127,14 @@ __device__ uint32_t silu_and_cvt_warp_fp16_to_fp4(PackedVec<Type>& vec,

   // Write the e2m1 values to global memory.
   return e2m1Vec;
-#else
-  return 0;
-#endif
 }

 // Use UE4M3 by default.
 template <class Type, bool UE8M0_SF = false>
-__global__ void
-#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
-__launch_bounds__(1024, 4) silu_and_cvt_fp16_to_fp4(
-#else
-silu_and_cvt_fp16_to_fp4(
-#endif
-    int32_t numRows, int32_t numCols, Type const* in, float const* SFScale,
-    uint32_t* out, uint32_t* SFout) {
-#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
+__global__ void __launch_bounds__(1024, 4)
+silu_and_cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in,
+                         float const* SFScale, uint32_t* out,
+                         uint32_t* SFout) {
   using PackedVec = PackedVec<Type>;
   static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
       (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
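+  // CVT_FP4_SF_VEC_SIZE (16) elements share one scale factor and each thread
+  // converts CVT_FP4_ELTS_PER_THREAD (8), so two neighboring threads
+  // cooperate on each scale factor.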
@@ -328,22 +172,25 @@ silu_and_cvt_fp16_to_fp4(
         in_vec, in_vec2, SFScaleVal, sf_out);
     }
   }
-#endif
 }

 }  // namespace vllm

-void silu_and_mul_nvfp4_quant(torch::Tensor& output,  // [..., d]
-                              torch::Tensor& output_sf,
-                              torch::Tensor& input,  // [..., 2 * d]
-                              torch::Tensor& input_sf) {
-  TORCH_CHECK(input.dtype() == torch::kFloat16 ||
-              input.dtype() == torch::kBFloat16);
+void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output,  // [..., d]
+                                     torch::Tensor& output_sf,
+                                     torch::Tensor& input,  // [..., 2 * d]
+                                     torch::Tensor& input_sf) {
   int32_t m = input.size(0);
   int32_t n = input.size(1) / 2;
+
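+  // Each FP4 scale factor covers CVT_FP4_SF_VEC_SIZE (16) elements along N.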
   TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16.");
+  TORCH_CHECK(input.scalar_type() == at::ScalarType::Half ||
+                  input.scalar_type() == at::ScalarType::BFloat16,
+              "Unsupported input data type for quantize_to_fp4.");
+
   int multiProcessorCount =
       get_device_attribute(cudaDevAttrMultiProcessorCount, -1);
+
   auto input_sf_ptr = static_cast<float const*>(input_sf.data_ptr());
   auto sf_out = static_cast<int32_t*>(output_sf.data_ptr());
   auto output_ptr = static_cast<int64_t*>(output.data_ptr());
@@ -352,17 +199,14 @@ void silu_and_mul_nvfp4_quant(torch::Tensor& output, // [..., d]
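+  // Each thread converts ELTS_PER_THREAD (8) elements, so the block is
+  // n / 8 threads capped at 1024; the grid targets roughly 2048 resident
+  // threads per SM via numBlocksPerSM.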
   dim3 block(std::min(int(n / ELTS_PER_THREAD), 1024));
   int const numBlocksPerSM = 2048 / block.x;
   dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM));
+
   VLLM_DISPATCH_HALF_TYPES(
-      input.scalar_type(), "act_and_mul_quant_kernel", [&] {
-        auto input_ptr = reinterpret_cast<scalar_t const*>(input.data_ptr());
-        VLLM_DISPATCH_BYTE_TYPES(
-            output.scalar_type(), "fused_act_and_mul_quant_kernel_nvfp4_type",
-            [&] {
-              vllm::silu_and_cvt_fp16_to_fp4<scalar_t>
-                  <<<grid, block, 0, stream>>>(
-                      m, n, input_ptr, input_sf_ptr,
-                      reinterpret_cast<uint32_t*>(output_ptr),
-                      reinterpret_cast<uint32_t*>(sf_out));
-            });
+      input.scalar_type(), "silu_and_mul_nvfp4_quant_kernel", [&] {
+        using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
+        auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
+        vllm::silu_and_cvt_fp16_to_fp4<cuda_type><<<grid, block, 0, stream>>>(
+            m, n, input_ptr, input_sf_ptr,
+            reinterpret_cast<uint32_t*>(output_ptr),
+            reinterpret_cast<uint32_t*>(sf_out));
       });
 }
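
For reference, a minimal caller-side sketch of how the renamed entry point might be driven from C++ once it is registered as a custom op. The single-element global scale in input_sf and the scale-factor buffer sizing (one byte per 16 elements, rows padded to 128 and scale-factor columns padded to 4, following the swizzled layout above) are assumptions inferred from the kernel, not part of this commit.

#include <torch/torch.h>

// Signature taken from this file; normally exposed through vLLM's op registration.
void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output,
                                     torch::Tensor& output_sf,
                                     torch::Tensor& input,
                                     torch::Tensor& input_sf);

void run_silu_and_mul_nvfp4(int64_t m, int64_t n) {
  auto opts = torch::TensorOptions().device(torch::kCUDA);
  // Gate/up activations packed as [m, 2 * n]; n must be a multiple of 16.
  auto input = torch::randn({m, 2 * n}, opts.dtype(torch::kHalf));
  // Assumed: a single global FP32 scale consumed as SFScale by the kernel.
  auto input_sf = torch::ones({1}, opts.dtype(torch::kFloat));
  // Two e2m1 values per byte -> n / 2 bytes per row.
  auto output = torch::empty({m, n / 2}, opts.dtype(torch::kUInt8));
  // Assumed scale-factor sizing: rows padded to 128, SF columns (n / 16) padded to 4.
  int64_t m_pad = (m + 127) / 128 * 128;
  int64_t sf_cols = (n / 16 + 3) / 4 * 4;
  auto output_sf = torch::empty({m_pad * sf_cols}, opts.dtype(torch::kUInt8));
  silu_and_mul_nvfp4_quant_sm1xxa(output, output_sf, input, input_sf);
}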