Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions rtp_llm/cpp/devices/rocm_impl/ROCmDevice.cc
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,6 @@ ROCmDevice::~ROCmDevice() {
ROCM_CHECK(hipStreamDestroy(assist_stream_));
ROCM_CHECK(hipblasDestroy(hipblas_handle_));
ROCM_CHECK(hipblasLtDestroy(hipblaslt_handle_));
curandstate_buf_.reset();

if (stream_ != nullptr) {
ROCM_CHECK(hipStreamDestroy(stream_));
Expand All @@ -170,8 +169,6 @@ ROCmDevice::~ROCmDevice() {

// Initializes the ROCm device: runs the base-class initialization, logs the
// configured maximum batch size, and allocates the buffer stored in
// curandstate_buf_.
void ROCmDevice::init() {
DeviceBase::init();
// NOTE(review): "%d" assumes max_batch_size is int-sized — confirm its declared type.
RTP_LLM_LOG_INFO("max batch size: %d", init_params_.max_batch_size);
// Requested size is max_batch_size * sizeof(curandState_t), tagged "curandstate";
// presumably one RNG state per batch slot for the sampler — TODO confirm.
curandstate_buf_ = allocateBuffer({init_params_.max_batch_size * sizeof(curandState_t)}, {"curandstate"});
}

DeviceProperties ROCmDevice::getDeviceProperties() {
Expand Down
2 changes: 0 additions & 2 deletions rtp_llm/cpp/devices/rocm_impl/ROCmDevice.h
Original file line number Diff line number Diff line change
Expand Up @@ -280,8 +280,6 @@ class ROCmDevice: public DeviceBase {
hipStream_t current_stream_ = nullptr;
hipDeviceProp_t device_prop_;

BufferPtr curandstate_buf_; // for sampler use.

rocm::hipblasMMWrapper* hipblasMMWrapperPtr() const {
return hipblas_mm_wrapper_.get();
}
Expand Down
301 changes: 98 additions & 203 deletions rtp_llm/cpp/devices/rocm_impl/ROCmSampleOp.cc

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion rtp_llm/cpp/devices/rocm_impl/test/ROCmSamplerTest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,6 @@ TEST_F(CudaSamplerTest, testTopP) {
ASSERT_NEAR(cum_log_probs_host[2], -5.02131, 1e-3);
ASSERT_NEAR(cum_log_probs_host[3], -5.2682, 1e-3);

params.random_seed = nullopt;
for (int i = 0; i < 100; i++) {
device_->sampleGreedy(params);
// printBuffer<int32_t>(*output_token_ids, "output_token_ids");
Expand Down
1 change: 1 addition & 0 deletions rtp_llm/cpp/kernels/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ cc_library(
deps = [
":rocm_basic",
":rocm_mla",
"//rtp_llm/cpp/kernels/rocm/sampling:sampling",
],
visibility = ["//visibility:public"],
)
Expand Down
42 changes: 42 additions & 0 deletions rtp_llm/cpp/kernels/rocm/sampling/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
load("//:def.bzl", "rocm_copts")
load("//bazel:arch_select.bzl", "torch_deps")

# C++ library for the ROCm sampling entry points: compiles api.cc together with
# the kernel.cuh / sampling.h / utils.h headers listed below.
cc_library(
name = "sampling",
srcs = ["api.cc"],
hdrs = [
"kernel.cuh",
"sampling.h",
"utils.h",
],
# ROCm utility library, ROCm headers, plus whatever torch_deps() expands to.
deps = [
"//rtp_llm/cpp/kernels:rocm_utils",
"@local_config_rocm//rocm:rocm_headers",
] + torch_deps(),
# Standard ROCm compile options with USE_ROCM=1 defined for this target.
copts = rocm_copts() + ["-DUSE_ROCM=1"],
visibility = ["//visibility:public"],
)

# Shared object built from bind.cc (the pybind11 wrapper) on top of :sampling;
# it is consumed as a data dependency by the py_test in this package.
cc_binary(
name = "bind.so",
srcs = [
"bind.cc",
],
deps = [
":sampling",
"//rtp_llm/cpp/pybind:py_utils",
],
# Produce a shared library rather than an executable.
linkshared = True,
linkstatic = False,
)

# Python test for the sampling bindings; bind.so is made available at runtime
# through `data`.
py_test(
name = "test",
srcs = ["test.py"],
data = [
":bind.so",
],
# Adds this package directory to the Python import path — presumably so that
# test.py can import the `bind` extension module; confirm against test.py.
imports = ["."],
python_version = "PY3",
# tags = ["rocm"],
)
139 changes: 139 additions & 0 deletions rtp_llm/cpp/kernels/rocm/sampling/api.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
// based on flashinfer 0.4.1 https://github.com/flashinfer-ai/flashinfer/tree/a88349f9f43df74d31d1d52ad5aa20c28824a790
/*
* Copyright (c) 2024 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/hip/HIPGeneratorImpl.h>

#include "sampling.h"
#include "utils.h"
#include "kernel.cuh"

namespace rtp_llm {

std::tuple<uint64_t, uint64_t> get_seed_and_offset(int increment_size, std::optional<at::Generator> generator) {
uint64_t philox_seed, philox_offset;
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
generator, at::cuda::detail::getDefaultCUDAGenerator());
std::lock_guard<std::mutex> lock(gen->mutex_);
at::PhiloxCudaState rng_engine_inputs = gen->philox_cuda_state(increment_size);
philox_seed = rng_engine_inputs.seed_.val;
philox_offset = rng_engine_inputs.offset_.val;
return std::make_tuple(philox_seed, philox_offset);
}

/// Host wrapper that launches `sampling::TopPSamplingFromProb<float, int>`.
///
/// @param probs            Validated with CHECK_INPUT and required to be 2-D; the
///                         dimensions are treated as (batch_size, vocab_size).
/// @param output           Validated with CHECK_INPUT, required to be 1-D and on the
///                         same device as `probs`; its data is passed as `int*`.
/// @param maybe_indices    Optional tensor passed as `int*` (nullptr when absent).
/// @param maybe_top_p_arr  Optional tensor passed as `float*` (nullptr when absent).
/// @param top_p_val        Scalar top-p value forwarded to the kernel launcher.
/// @param deterministic    Forwarded to the kernel launcher.
/// @param philox_seed      Tensor whose data pointer is passed as `uint64_t*`.
/// @param philox_offset    Tensor whose data pointer is passed as `uint64_t*`.
/// @param stream           Integer handle reinterpreted as `hipStream_t`.
/// @throws c10::Error (via TORCH_CHECK) if selecting the device or the launch fails.
void top_p_sampling_from_probs(torch::Tensor probs, torch::Tensor output,
                               std::optional<torch::Tensor> maybe_indices,
                               std::optional<torch::Tensor> maybe_top_p_arr, double top_p_val,
                               bool deterministic, torch::Tensor philox_seed, torch::Tensor philox_offset, uintptr_t stream) {
    CHECK_INPUT(probs);
    // Validate `output` the same way the top-k and top-k/top-p wrappers do: its raw
    // pointer is written by the kernel and its first dimension defines batch_size.
    CHECK_INPUT(output);
    CHECK_DEVICE(output, probs);
    CHECK_DIM(2, probs);   // probs: (batch_size, vocab_size)
    CHECK_DIM(1, output);  // output: (batch_size)
    // batch_size comes from `output`, not `probs` — presumably because `maybe_indices`
    // may map several outputs onto one probs row; TODO confirm against kernel.cuh.
    unsigned int batch_size    = output.sizes()[0];
    unsigned int vocab_size    = probs.sizes()[1];
    bool         has_top_p_arr = maybe_top_p_arr.has_value();

    // Do not launch on the wrong device if the switch fails.
    const hipError_t device_status = hipSetDevice(probs.get_device());
    TORCH_CHECK(device_status == hipSuccess, "hipSetDevice failed with error code " + std::string(hipGetErrorString(device_status)));

    hipError_t status = sampling::TopPSamplingFromProb<float, int>(
        static_cast<float*>(probs.data_ptr()), static_cast<int*>(output.data_ptr()),
        maybe_indices.has_value() ? static_cast<int*>(maybe_indices->data_ptr()) : nullptr,
        has_top_p_arr ? static_cast<float*>(maybe_top_p_arr->data_ptr()) : nullptr, batch_size,
        top_p_val, vocab_size, deterministic, static_cast<uint64_t*>(philox_seed.data_ptr()), static_cast<uint64_t*>(philox_offset.data_ptr()), reinterpret_cast<hipStream_t>(stream));
    TORCH_CHECK(status == hipSuccess, "TopPSamplingFromProbs failed with error code " + std::string(hipGetErrorString(status)));
}

/// Host wrapper that launches `sampling::TopKSamplingFromProb<float, int>`.
///
/// @param probs            Validated with CHECK_INPUT and required to be 2-D; the
///                         dimensions are treated as (batch_size, vocab_size).
/// @param output           Validated with CHECK_INPUT, required to be 1-D and on the
///                         same device as `probs`; its data is passed as `int*`.
/// @param maybe_indices    Optional tensor passed as `int*` (nullptr when absent).
/// @param maybe_top_k_arr  Optional tensor; its data pointer is passed as `float*`
///                         (nullptr when absent).
/// @param top_k_val        Scalar top-k value forwarded to the kernel launcher.
/// @param deterministic    Forwarded to the kernel launcher.
/// @param philox_seed      Tensor whose data pointer is passed as `uint64_t*`.
/// @param philox_offset    Tensor whose data pointer is passed as `uint64_t*`.
/// @param stream           Integer handle reinterpreted as `hipStream_t`.
/// @throws c10::Error (via TORCH_CHECK) if selecting the device or the launch fails.
void top_k_sampling_from_probs(torch::Tensor probs, torch::Tensor output,
                               std::optional<torch::Tensor> maybe_indices,
                               std::optional<torch::Tensor> maybe_top_k_arr, int64_t top_k_val,
                               bool deterministic, torch::Tensor philox_seed, torch::Tensor philox_offset, uintptr_t stream) {
    CHECK_INPUT(probs);
    CHECK_INPUT(output);
    CHECK_DEVICE(output, probs);
    CHECK_DIM(2, probs);   // probs: (batch_size, vocab_size)
    CHECK_DIM(1, output);  // output: (batch_size)
    unsigned int batch_size    = output.sizes()[0];
    unsigned int vocab_size    = probs.sizes()[1];
    bool         has_top_k_arr = maybe_top_k_arr.has_value();

    // Do not launch on the wrong device if the switch fails.
    const hipError_t device_status = hipSetDevice(probs.get_device());
    TORCH_CHECK(device_status == hipSuccess, "hipSetDevice failed with error code " + std::string(hipGetErrorString(device_status)));

    // NOTE(review): the top-k array is cast to float* here, while the other top-k
    // wrappers in this file cast it to int*. Kept as-is because the launcher's
    // signature in kernel.cuh is not visible from here — confirm the element type.
    hipError_t status = sampling::TopKSamplingFromProb<float, int>(
        static_cast<float*>(probs.data_ptr()), static_cast<int*>(output.data_ptr()),
        maybe_indices.has_value() ? static_cast<int*>(maybe_indices->data_ptr()) : nullptr,
        has_top_k_arr ? static_cast<float*>(maybe_top_k_arr->data_ptr()) : nullptr, batch_size,
        top_k_val, vocab_size, deterministic, static_cast<uint64_t*>(philox_seed.data_ptr()), static_cast<uint64_t*>(philox_offset.data_ptr()), reinterpret_cast<hipStream_t>(stream));
    TORCH_CHECK(status == hipSuccess, "TopKSamplingFromProbs failed with error code " + std::string(hipGetErrorString(status)));
}

/// Host wrapper that launches `sampling::TopKTopPSamplingFromProb<float, int>`.
///
/// @param probs            Validated with CHECK_INPUT and required to be 2-D; the
///                         dimensions are treated as (batch_size, vocab_size).
/// @param output           Validated with CHECK_INPUT, required to be 1-D and on the
///                         same device as `probs`; its data is passed as `int*`.
/// @param maybe_indices    Optional tensor passed as `int*` (nullptr when absent).
/// @param maybe_top_k_arr  Optional tensor passed as `int*` (nullptr when absent).
/// @param top_k_val        Scalar top-k value forwarded to the kernel launcher.
/// @param maybe_top_p_arr  Optional tensor passed as `float*` (nullptr when absent).
/// @param top_p_val        Scalar top-p value forwarded to the kernel launcher.
/// @param deterministic    Forwarded to the kernel launcher.
/// @param philox_seed      Tensor whose data pointer is passed as `uint64_t*`.
/// @param philox_offset    Tensor whose data pointer is passed as `uint64_t*`.
/// @param stream           Integer handle reinterpreted as `hipStream_t`.
/// @throws c10::Error (via TORCH_CHECK) if selecting the device or the launch fails.
void top_k_top_p_sampling_from_probs(torch::Tensor probs, torch::Tensor output,
                                     std::optional<torch::Tensor> maybe_indices,
                                     std::optional<torch::Tensor> maybe_top_k_arr, double top_k_val,
                                     std::optional<torch::Tensor> maybe_top_p_arr, double top_p_val,
                                     bool deterministic, torch::Tensor philox_seed,
                                     torch::Tensor philox_offset, uintptr_t stream) {
    CHECK_INPUT(probs);
    CHECK_INPUT(output);
    CHECK_DEVICE(output, probs);
    CHECK_DIM(2, probs);   // probs: (batch_size, vocab_size)
    CHECK_DIM(1, output);  // output: (batch_size)
    unsigned int batch_size    = output.sizes()[0];
    unsigned int vocab_size    = probs.sizes()[1];
    bool         has_top_k_arr = maybe_top_k_arr.has_value();
    bool         has_top_p_arr = maybe_top_p_arr.has_value();

    // Do not launch on the wrong device if the switch fails.
    const hipError_t device_status = hipSetDevice(probs.get_device());
    TORCH_CHECK(device_status == hipSuccess, "hipSetDevice failed with error code " + std::string(hipGetErrorString(device_status)));

    // Argument order differs from the single-filter wrappers: the top-k / top-p
    // arrays come before the output and indices pointers.
    hipError_t status = sampling::TopKTopPSamplingFromProb<float, int>(
        static_cast<float*>(probs.data_ptr()),
        has_top_k_arr ? static_cast<int*>(maybe_top_k_arr->data_ptr()) : nullptr,
        has_top_p_arr ? static_cast<float*>(maybe_top_p_arr->data_ptr()) : nullptr,
        static_cast<int*>(output.data_ptr()),
        maybe_indices.has_value() ? static_cast<int*>(maybe_indices->data_ptr()) : nullptr,
        batch_size, top_k_val, top_p_val, vocab_size, deterministic, static_cast<uint64_t*>(philox_seed.data_ptr()), static_cast<uint64_t*>(philox_offset.data_ptr()),
        reinterpret_cast<hipStream_t>(stream));
    TORCH_CHECK(status == hipSuccess, "TopKTopPSamplingFromProb failed with error code " + std::string(hipGetErrorString(status)));
}

/// Host wrapper that launches `sampling::TopPRenormProb<float>`.
///
/// @param probs            Validated with CHECK_INPUT and required to be 2-D; the
///                         dimensions are treated as (batch_size, vocab_size).
/// @param renorm_probs     Destination tensor, written through a raw `float*`. Must
///                         pass CHECK_INPUT, live on the same device as `probs`, and
///                         hold the same number of elements.
/// @param maybe_top_p_arr  Optional tensor passed as `float*` (nullptr when absent).
/// @param top_p_val        Scalar top-p value forwarded to the kernel launcher.
/// @param stream           Integer handle reinterpreted as `hipStream_t`.
/// @throws c10::Error (via TORCH_CHECK) on invalid arguments, or if selecting the
///         device or the launch fails.
void top_p_renorm_probs(torch::Tensor probs, torch::Tensor renorm_probs,
                        std::optional<torch::Tensor> maybe_top_p_arr, double top_p_val, uintptr_t stream) {
    CHECK_INPUT(probs);
    CHECK_DIM(2, probs);  // probs: (batch_size, vocab_size)
    // The kernel writes batch_size * vocab_size floats into renorm_probs, so an
    // undersized or foreign-device destination would corrupt memory silently.
    CHECK_INPUT(renorm_probs);
    CHECK_DEVICE(renorm_probs, probs);
    TORCH_CHECK(renorm_probs.numel() == probs.numel(), "renorm_probs must have the same number of elements as probs");

    unsigned int batch_size    = probs.sizes()[0];
    unsigned int vocab_size    = probs.sizes()[1];
    bool         has_top_p_arr = maybe_top_p_arr.has_value();

    // Do not launch on the wrong device if the switch fails.
    const hipError_t device_status = hipSetDevice(probs.get_device());
    TORCH_CHECK(device_status == hipSuccess, "hipSetDevice failed with error code " + std::string(hipGetErrorString(device_status)));

    hipError_t status = sampling::TopPRenormProb<float>(
        static_cast<float*>(probs.data_ptr()), static_cast<float*>(renorm_probs.data_ptr()),
        has_top_p_arr ? static_cast<float*>(maybe_top_p_arr->data_ptr()) : nullptr, batch_size,
        top_p_val, vocab_size, reinterpret_cast<hipStream_t>(stream));

    TORCH_CHECK(status == hipSuccess, "TopPRenormProb failed with error code " + std::string(hipGetErrorString(status)));
}

/// Host wrapper that launches `sampling::TopKRenormProb<float>`.
///
/// @param probs            Validated with CHECK_INPUT and required to be 2-D; the
///                         dimensions are treated as (batch_size, vocab_size).
/// @param renorm_probs     Destination tensor, written through a raw `float*`. Must
///                         pass CHECK_INPUT, live on the same device as `probs`, and
///                         hold the same number of elements.
/// @param maybe_top_k_arr  Optional tensor passed as `int*` (nullptr when absent).
/// @param top_k_val        Scalar top-k value forwarded to the kernel launcher.
/// @param stream           Integer handle reinterpreted as `hipStream_t`.
/// @throws c10::Error (via TORCH_CHECK) on invalid arguments, or if selecting the
///         device or the launch fails.
void top_k_renorm_probs(torch::Tensor probs, torch::Tensor renorm_probs,
                        std::optional<torch::Tensor> maybe_top_k_arr, int64_t top_k_val, uintptr_t stream) {
    CHECK_INPUT(probs);
    CHECK_DIM(2, probs);  // probs: (batch_size, vocab_size)
    // The kernel writes batch_size * vocab_size floats into renorm_probs, so an
    // undersized or foreign-device destination would corrupt memory silently.
    CHECK_INPUT(renorm_probs);
    CHECK_DEVICE(renorm_probs, probs);
    TORCH_CHECK(renorm_probs.numel() == probs.numel(), "renorm_probs must have the same number of elements as probs");

    unsigned int batch_size    = probs.sizes()[0];
    unsigned int vocab_size    = probs.sizes()[1];
    bool         has_top_k_arr = maybe_top_k_arr.has_value();

    // Do not launch on the wrong device if the switch fails.
    const hipError_t device_status = hipSetDevice(probs.get_device());
    TORCH_CHECK(device_status == hipSuccess, "hipSetDevice failed with error code " + std::string(hipGetErrorString(device_status)));

    hipError_t status = sampling::TopKRenormProb<float>(
        static_cast<float*>(probs.data_ptr()), static_cast<float*>(renorm_probs.data_ptr()),
        has_top_k_arr ? static_cast<int*>(maybe_top_k_arr->data_ptr()) : nullptr, batch_size,
        top_k_val, vocab_size, reinterpret_cast<hipStream_t>(stream));

    TORCH_CHECK(status == hipSuccess, "TopKRenormProb failed with error code " + std::string(hipGetErrorString(status)));
}

}
19 changes: 19 additions & 0 deletions rtp_llm/cpp/kernels/rocm/sampling/bind.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#include <pybind11/pybind11.h>

#include "sampling.h"

namespace py = pybind11;

namespace rtp_llm {

// Python extension module `bind`: registers the sampling entry points declared in
// sampling.h. Leading parameters are bound positionally without names; the trailing
// parameter of each function is a keyword with a default (`generator` = nullopt,
// `stream` = 0).
PYBIND11_MODULE(bind, m) {
    m.doc() = "sampling c++ api for test";

    // RNG helper.
    m.def("get_seed_and_offset",
          &get_seed_and_offset,
          py::arg(),
          py::arg("generator") = std::nullopt,
          "get_seed_and_offset");

    // Renormalization: four positional arguments followed by the stream.
    m.def("top_p_renorm_probs",
          &top_p_renorm_probs,
          py::arg(),
          py::arg(),
          py::arg(),
          py::arg(),
          py::arg("stream") = 0,
          "top_p_renorm_probs");
    m.def("top_k_renorm_probs",
          &top_k_renorm_probs,
          py::arg(),
          py::arg(),
          py::arg(),
          py::arg(),
          py::arg("stream") = 0,
          "top_k_renorm_probs");

    // Single-filter sampling: eight positional arguments followed by the stream.
    m.def("top_p_sampling_from_probs",
          &top_p_sampling_from_probs,
          py::arg(),
          py::arg(),
          py::arg(),
          py::arg(),
          py::arg(),
          py::arg(),
          py::arg(),
          py::arg(),
          py::arg("stream") = 0,
          "top_p_sampling_from_probs");
    m.def("top_k_sampling_from_probs",
          &top_k_sampling_from_probs,
          py::arg(),
          py::arg(),
          py::arg(),
          py::arg(),
          py::arg(),
          py::arg(),
          py::arg(),
          py::arg(),
          py::arg("stream") = 0,
          "top_k_sampling_from_probs");

    // Combined top-k/top-p sampling: ten positional arguments followed by the stream.
    m.def("top_k_top_p_sampling_from_probs",
          &top_k_top_p_sampling_from_probs,
          py::arg(),
          py::arg(),
          py::arg(),
          py::arg(),
          py::arg(),
          py::arg(),
          py::arg(),
          py::arg(),
          py::arg(),
          py::arg(),
          py::arg("stream") = 0,
          "top_k_top_p_sampling_from_probs");
}

}
Loading
Loading