Skip to content

Commit afc0d65

Browse files
committed
feat: support flashinfer sample kernel on AMD
1 parent bf3be40 commit afc0d65

File tree

9 files changed

+1793
-0
lines changed

9 files changed

+1793
-0
lines changed

rtp_llm/cpp/devices/rocm_impl/ROCmSampleOp.cc

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "rtp_llm/cpp/kernels/sampling_topp_kernels.h"
55
#include "rtp_llm/cpp/kernels/sampling_penalty_kernels.h"
66
#include "rtp_llm/cpp/core/torch_utils/BufferTorchUtils.h"
7+
#include "rtp_llm/cpp/kernels/rocm/sampling/sampling.h"
78

89
using namespace std;
910

@@ -19,6 +20,7 @@ using SamplerT = float;
1920
// topk should has higher proirity than topp.
2021

2122
GreedyOutput ROCmDevice::sampleGreedy(const GreedyParams& params) {
23+
bool enable_flashinfer = init_params_.sampler_config.enable_flashinfer_sample_kernel;
2224
const auto& logits = params.logits;
2325
const auto batch_size = logits.shape()[0];
2426
const auto vocab_size_padded = logits.shape()[1];
@@ -219,6 +221,81 @@ GreedyOutput ROCmDevice::sampleGreedy(const GreedyParams& params) {
219221
return GreedyOutput{};
220222
}
221223

224+
if (enable_flashinfer) {
225+
const auto batch_size = params.logits.shape()[0];
226+
auto& top_k = params.top_k;
227+
auto& top_p = params.top_p;
228+
229+
auto logits_ref = params.logits.slice(0, params.logits.shape()[0]);
230+
auto probs = softmax({logits_ref, std::nullopt, std::nullopt, 1.0f, DataType::TYPE_INVALID, std::nullopt});
231+
auto samples = transposed_tokens->view(transposed_tokens->shape()[0] - 1, 1);
232+
torch::TensorOptions options =
233+
torch::TensorOptions(dataTypeToTorchType(probs->type())).device(torch::Device(torch::kCUDA));
234+
bool deterministic = false;
235+
bool need_output_all_probs = params.output_all_probs.has_value();
236+
torch::Tensor probs_t = Buffer2torchTensor(probs, false);
237+
torch::Tensor samples_t = Buffer2torchTensor(samples, false).flatten();
238+
torch::Tensor top_k_t = Buffer2torchTensor(top_k, false);
239+
torch::Tensor top_p_t = Buffer2torchTensor(top_p, false);
240+
torch::Tensor output_all_probs_t;
241+
if (need_output_all_probs) {
242+
output_all_probs_t = Buffer2torchTensor(params.output_all_probs.value().get(), false);
243+
}
244+
std::transform(top_p.data<float>(), top_p.data<float>() + batch_size, top_p.data<float>(), [&](auto t) {
245+
return std::abs(t) < 1e-7 ? 1.0 : t;
246+
});
247+
if (std::all_of(top_k.data<uint32_t>(), top_k.data<uint32_t>() + batch_size, [&](auto t) { return t == 1; })) {
248+
torch::Tensor selected_tokens = torch::argmax(probs_t, -1, /*keepdim=*/false);
249+
samples_t.copy_(selected_tokens);
250+
if (need_output_all_probs) {
251+
top_k_renorm_probs(probs_t, output_all_probs_t, top_k_t, 0, reinterpret_cast<uintptr_t>(stream_));
252+
}
253+
} else if (std::all_of(
254+
top_k.data<uint32_t>(), top_k.data<uint32_t>() + batch_size, [&](auto t) { return t <= 0; })) {
255+
top_p_sampling_from_probs(probs_t, samples_t, std::nullopt, top_p_t, 1.0, deterministic, 0, 0, reinterpret_cast<uintptr_t>(stream_));
256+
if (need_output_all_probs) {
257+
top_p_renorm_probs(probs_t, output_all_probs_t, top_p_t, 1.0, reinterpret_cast<uintptr_t>(stream_));
258+
}
259+
} else if (std::all_of(top_p.data<float>(), top_p.data<float>() + batch_size, [&](auto t) {
260+
return std::abs(t - 1.0f) < 1e-7;
261+
})) {
262+
std::transform(top_k.data<uint32_t>(),
263+
top_k.data<uint32_t>() + batch_size,
264+
top_k.data<uint32_t>(),
265+
[&](auto t) { return t <= 0 ? 1 << 30 : t; });
266+
top_k_sampling_from_probs(
267+
probs_t, samples_t, std::nullopt, top_k_t, 0, deterministic, 0, 0, reinterpret_cast<uintptr_t>(stream_));
268+
if (need_output_all_probs) {
269+
top_k_renorm_probs(probs_t, output_all_probs_t, top_k_t, 0, reinterpret_cast<uintptr_t>(stream_));
270+
}
271+
} else {
272+
std::transform(top_k.data<uint32_t>(),
273+
top_k.data<uint32_t>() + batch_size,
274+
top_k.data<uint32_t>(),
275+
[&](auto t) { return t <= 0 ? 1 << 30 : t; });
276+
top_k_top_p_sampling_from_probs(probs_t,
277+
samples_t,
278+
std::nullopt,
279+
top_k_t,
280+
0,
281+
top_p_t,
282+
1.0,
283+
deterministic,
284+
0,
285+
0,
286+
reinterpret_cast<uintptr_t>(stream_));
287+
if (need_output_all_probs) {
288+
torch::Tensor temp_t = torch::zeros_like(output_all_probs_t);
289+
top_k_renorm_probs(probs_t, temp_t, top_k_t, 1.0, reinterpret_cast<uintptr_t>(stream_));
290+
top_p_renorm_probs(temp_t, output_all_probs_t, top_p_t, 1.0, reinterpret_cast<uintptr_t>(stream_));
291+
}
292+
}
293+
auto output_tokens = transpose({*transposed_tokens});
294+
copy({params.token_ids, *output_tokens});
295+
check_cuda_error();
296+
return GreedyOutput{};
297+
}
298+
222299
// 4. run sampling
223300
// 4.1 run top_k
224301
invokeSetupTopKRuntimeArgs(batch_size,

rtp_llm/cpp/kernels/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,7 @@ cc_library(
235235
deps = [
236236
":rocm_basic",
237237
":rocm_mla",
238+
"//rtp_llm/cpp/kernels/rocm/sampling:sampling",
238239
],
239240
visibility = ["//visibility:public"],
240241
)
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
load("//:def.bzl", "rocm_copts")
2+
load("//bazel:arch_select.bzl", "torch_deps")
3+
4+
cc_library(
5+
name = "sampling",
6+
srcs = ["api.cc"],
7+
hdrs = [
8+
"kernel.cuh",
9+
"sampling.h",
10+
"utils.h",
11+
],
12+
deps = [
13+
"//rtp_llm/cpp/kernels:rocm_utils",
14+
"@local_config_rocm//rocm:rocm_headers",
15+
] + torch_deps(),
16+
copts = rocm_copts() + ["-DUSE_ROCM=1"],
17+
visibility = ["//visibility:public"],
18+
)
19+
20+
cc_binary(
21+
name = "bind.so",
22+
srcs = [
23+
"bind.cc",
24+
],
25+
deps = [
26+
":sampling",
27+
"//rtp_llm/cpp/pybind:py_utils",
28+
],
29+
linkshared = True,
30+
linkstatic = False,
31+
)
32+
33+
py_test(
34+
name = "test",
35+
srcs = ["test.py"],
36+
data = [
37+
":bind.so",
38+
],
39+
imports = ["."],
40+
python_version = "PY3",
41+
)
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
// based on flashinfer 0.4.1 https://github.com/flashinfer-ai/flashinfer/tree/a88349f9f43df74d31d1d52ad5aa20c28824a790
2+
/*
3+
* Copyright (c) 2024 by FlashInfer team.
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
#include "sampling.h"
18+
#include "utils.h"
19+
#include "kernel.cuh"
20+
21+
namespace rtp_llm {
22+
23+
void top_p_sampling_from_probs(torch::Tensor probs, torch::Tensor output,
24+
std::optional<torch::Tensor> maybe_indices,
25+
std::optional<torch::Tensor> maybe_top_p_arr, double top_p_val,
26+
bool deterministic, uint64_t philox_seed, uint64_t philox_offset, uintptr_t stream) {
27+
CHECK_INPUT(probs);
28+
CHECK_DIM(2, probs); // probs: (batch_size, vocab_size)
29+
unsigned int batch_size = output.sizes()[0];
30+
unsigned int vocab_size = probs.sizes()[1];
31+
bool has_top_p_arr = maybe_top_p_arr.has_value();
32+
33+
hipSetDevice(probs.get_device());
34+
hipError_t status = sampling::TopPSamplingFromProb<float, int>(
35+
static_cast<float*>(probs.data_ptr()), static_cast<int*>(output.data_ptr()),
36+
maybe_indices.has_value() ? static_cast<int*>(maybe_indices->data_ptr()) : nullptr,
37+
has_top_p_arr ? static_cast<float*>(maybe_top_p_arr->data_ptr()) : nullptr, batch_size,
38+
top_p_val, vocab_size, deterministic, philox_seed, philox_offset, reinterpret_cast<hipStream_t>(stream));
39+
TORCH_CHECK(status == hipSuccess, "TopPSamplingFromProbs failed with error code " + std::string(hipGetErrorString(status)));
40+
}
41+
42+
void top_k_sampling_from_probs(torch::Tensor probs, torch::Tensor output,
43+
std::optional<torch::Tensor> maybe_indices,
44+
std::optional<torch::Tensor> maybe_top_k_arr, int64_t top_k_val,
45+
bool deterministic, uint64_t philox_seed, uint64_t philox_offset, uintptr_t stream) {
46+
CHECK_INPUT(probs);
47+
CHECK_INPUT(output);
48+
CHECK_DEVICE(output, probs);
49+
CHECK_DIM(2, probs); // probs: (batch_size, vocab_size)
50+
CHECK_DIM(1, output); // output: (batch_size)
51+
unsigned int batch_size = output.sizes()[0];
52+
unsigned int vocab_size = probs.sizes()[1];
53+
bool has_top_k_arr = maybe_top_k_arr.has_value();
54+
55+
hipSetDevice(probs.get_device());
56+
hipError_t status = sampling::TopKSamplingFromProb<float, int>(
57+
static_cast<float*>(probs.data_ptr()), static_cast<int*>(output.data_ptr()),
58+
maybe_indices.has_value() ? static_cast<int*>(maybe_indices->data_ptr()) : nullptr,
59+
has_top_k_arr ? static_cast<float*>(maybe_top_k_arr->data_ptr()) : nullptr, batch_size,
60+
top_k_val, vocab_size, deterministic, philox_seed, philox_offset, reinterpret_cast<hipStream_t>(stream));
61+
TORCH_CHECK(status == hipSuccess, "TopKSamplingFromProbs failed with error code " + std::string(hipGetErrorString(status)));
62+
}
63+
64+
void top_k_top_p_sampling_from_probs(torch::Tensor probs, torch::Tensor output,
65+
std::optional<torch::Tensor> maybe_indices,
66+
std::optional<torch::Tensor> maybe_top_k_arr, double top_k_val,
67+
std::optional<torch::Tensor> maybe_top_p_arr, double top_p_val,
68+
bool deterministic, uint64_t philox_seed,
69+
uint64_t philox_offset, uintptr_t stream) {
70+
CHECK_INPUT(probs);
71+
CHECK_INPUT(output);
72+
CHECK_DEVICE(output, probs);
73+
CHECK_DIM(2, probs); // probs: (batch_size, vocab_size)
74+
CHECK_DIM(1, output); // output: (batch_size)
75+
unsigned int batch_size = output.sizes()[0];
76+
unsigned int vocab_size = probs.sizes()[1];
77+
bool has_top_k_arr = maybe_top_k_arr.has_value();
78+
bool has_top_p_arr = maybe_top_p_arr.has_value();
79+
80+
hipSetDevice(probs.get_device());
81+
hipError_t status = sampling::TopKTopPSamplingFromProb<float, int>(
82+
static_cast<float*>(probs.data_ptr()),
83+
has_top_k_arr ? static_cast<int*>(maybe_top_k_arr->data_ptr()) : nullptr,
84+
has_top_p_arr ? static_cast<float*>(maybe_top_p_arr->data_ptr()) : nullptr,
85+
static_cast<int*>(output.data_ptr()),
86+
maybe_indices.has_value() ? static_cast<int*>(maybe_indices->data_ptr()) : nullptr,
87+
batch_size, top_k_val, top_p_val, vocab_size, deterministic, philox_seed, philox_offset,
88+
reinterpret_cast<hipStream_t>(stream));
89+
TORCH_CHECK(status == hipSuccess, "TopKTopPSamplingFromProb failed with error code " + std::string(hipGetErrorString(status)));
90+
}
91+
92+
void top_p_renorm_probs(torch::Tensor probs, torch::Tensor renorm_probs,
93+
std::optional<torch::Tensor> maybe_top_p_arr, double top_p_val, uintptr_t stream) {
94+
CHECK_INPUT(probs);
95+
CHECK_DIM(2, probs); // probs: (batch_size, vocab_size)
96+
unsigned int batch_size = probs.sizes()[0];
97+
unsigned int vocab_size = probs.sizes()[1];
98+
bool has_top_p_arr = maybe_top_p_arr.has_value();
99+
100+
hipSetDevice(probs.get_device());
101+
hipError_t status = sampling::TopPRenormProb<float>(
102+
static_cast<float*>(probs.data_ptr()), static_cast<float*>(renorm_probs.data_ptr()),
103+
has_top_p_arr ? static_cast<float*>(maybe_top_p_arr->data_ptr()) : nullptr, batch_size,
104+
top_p_val, vocab_size, reinterpret_cast<hipStream_t>(stream));
105+
106+
TORCH_CHECK(status == hipSuccess, "TopPRenormProb failed with error code " + std::string(hipGetErrorString(status)));
107+
}
108+
109+
void top_k_renorm_probs(torch::Tensor probs, torch::Tensor renorm_probs,
110+
std::optional<torch::Tensor> maybe_top_k_arr, int64_t top_k_val, uintptr_t stream) {
111+
CHECK_INPUT(probs);
112+
CHECK_DIM(2, probs); // probs: (batch_size, vocab_size)
113+
unsigned int batch_size = probs.sizes()[0];
114+
unsigned int vocab_size = probs.sizes()[1];
115+
bool has_top_k_arr = maybe_top_k_arr.has_value();
116+
117+
hipSetDevice(probs.get_device());
118+
hipError_t status = sampling::TopKRenormProb<float>(
119+
static_cast<float*>(probs.data_ptr()), static_cast<float*>(renorm_probs.data_ptr()),
120+
has_top_k_arr ? static_cast<int*>(maybe_top_k_arr->data_ptr()) : nullptr, batch_size,
121+
top_k_val, vocab_size, reinterpret_cast<hipStream_t>(stream));
122+
123+
TORCH_CHECK(status == hipSuccess, "TopKRenormProb failed with error code " + std::string(hipGetErrorString(status)));
124+
}
125+
126+
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#include <pybind11/pybind11.h>
2+
3+
#include "sampling.h"
4+
5+
namespace py = pybind11;
6+
7+
namespace rtp_llm {
8+
9+
PYBIND11_MODULE(bind, m) {
10+
m.doc() = "sampling c++ api for test";
11+
m.def("top_p_renorm_probs", &top_p_renorm_probs, py::arg(), py::arg(), py::arg(), py::arg(), py::arg("stream") = 0, "top_p_renorm_probs");
12+
m.def("top_k_renorm_probs", &top_k_renorm_probs, py::arg(), py::arg(), py::arg(), py::arg(), py::arg("stream") = 0, "top_k_renorm_probs");
13+
m.def("top_p_sampling_from_probs", &top_p_sampling_from_probs, py::arg(), py::arg(), py::arg(), py::arg(), py::arg(), py::arg(), py::arg(), py::arg(), py::arg("stream") = 0, "top_p_sampling_from_probs");
14+
m.def("top_k_sampling_from_probs", &top_k_sampling_from_probs, py::arg(), py::arg(), py::arg(), py::arg(), py::arg(), py::arg(), py::arg(), py::arg(), py::arg("stream") = 0, "top_k_sampling_from_probs");
15+
m.def("top_k_top_p_sampling_from_probs", &top_k_top_p_sampling_from_probs, py::arg(), py::arg(), py::arg(), py::arg(), py::arg(), py::arg(), py::arg(), py::arg(), py::arg(), py::arg(), py::arg("stream") = 0, "top_k_top_p_sampling_from_probs");
16+
}
17+
18+
}

0 commit comments

Comments
 (0)