Skip to content

Commit 15df633

Browse files
committed
refact: support random_seed on AMD
1 parent afc0d65 commit 15df633

File tree

7 files changed

+172
-78
lines changed

7 files changed

+172
-78
lines changed

rtp_llm/cpp/devices/rocm_impl/ROCmDevice.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,9 @@ ROCmDevice::~ROCmDevice() {
170170

171171
void ROCmDevice::init() {
172172
DeviceBase::init();
173-
RTP_LLM_LOG_INFO("max batch size: %d", init_params_.max_batch_size);
174-
curandstate_buf_ = allocateBuffer({init_params_.max_batch_size * sizeof(curandState_t)}, {"curandstate"});
173+
int max_batch_size_deprecated = 128;
174+
RTP_LLM_LOG_INFO("max batch size: %d", max_batch_size_deprecated);
175+
curandstate_buf_ = allocateBuffer({max_batch_size_deprecated * sizeof(curandState_t)}, {"curandstate"});
175176
}
176177

177178
DeviceProperties ROCmDevice::getDeviceProperties() {

rtp_llm/cpp/devices/rocm_impl/ROCmSampleOp.cc

Lines changed: 66 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,11 @@ using SamplerT = float;
2020
// topk should have higher priority than topp.
2121

2222
GreedyOutput ROCmDevice::sampleGreedy(const GreedyParams& params) {
23-
bool enable_flashinfer = init_params_.sampler_config.enable_flashinfer_sample_kernel;
24-
const auto& logits = params.logits;
25-
const auto batch_size = logits.shape()[0];
26-
const auto vocab_size_padded = logits.shape()[1];
27-
const auto step = params.step;
23+
bool disable_dprs = std::getenv("DISABLE_ROCM_DPRS") && std::string(std::getenv("DISABLE_ROCM_DPRS")) == "1";
24+
const auto& logits = params.logits;
25+
const auto batch_size = logits.shape()[0];
26+
const auto vocab_size_padded = logits.shape()[1];
27+
const auto step = params.step;
2828
RUNTIME_ASSERT_OP_ARG(batch_size == params.token_ids.shape()[0],
2929
"logits.shape[0] should equal to token_ids.shape[0], but %d vs %d",
3030
batch_size,
@@ -40,7 +40,7 @@ GreedyOutput ROCmDevice::sampleGreedy(const GreedyParams& params) {
4040
auto& top_k = params.top_k;
4141
auto& top_p = params.top_p;
4242
auto& temperature = params.temperature;
43-
auto& random_seed = params.random_seed;
43+
// auto& random_seed = params.random_seed;
4444
ROCM_CHECK_VALUE(top_k.size() == batch_size, "top_k.size() != batch_size");
4545
ROCM_CHECK_VALUE(top_p.size() == batch_size, "top_p.size() != batch_size");
4646
ROCM_CHECK_VALUE(temperature.size() == batch_size, "temperature.size() != batch_size");
@@ -129,24 +129,24 @@ GreedyOutput ROCmDevice::sampleGreedy(const GreedyParams& params) {
129129
// 3. prepare common inputs
130130

131131
// 3.1. setup random seeds
132-
if (random_seed) {
133-
auto& seeds = random_seed.value().get();
134-
if (seeds.size() == 1) {
135-
invokeCurandInitialize(
136-
(curandState_t*)curandstate_buf_->data(), batch_size, seeds.data<uint64_t>()[0], stream_);
137-
} else {
138-
auto random_seeds_buf = allocateBuffer({DataType::TYPE_UINT64, {batch_size}});
139-
RUNTIME_ASSERT_OP_ARG((seeds.size() == batch_size),
140-
"random_seed.size() should equal to batch_size, but %d vs %d",
141-
seeds.size(),
142-
batch_size);
143-
copy({*random_seeds_buf, seeds});
144-
invokeCurandBatchInitialize((curandState_t*)curandstate_buf_->data(),
145-
batch_size,
146-
(unsigned long long*)random_seeds_buf->data(),
147-
stream_);
148-
}
149-
}
132+
// if (random_seed) {
133+
// auto& seeds = random_seed.value().get();
134+
// if (seeds.size() == 1) {
135+
// invokeCurandInitialize(
136+
// (curandState_t*)curandstate_buf_->data(), batch_size, seeds.data<uint64_t>()[0], stream_);
137+
// } else {
138+
// auto random_seeds_buf = allocateBuffer({DataType::TYPE_UINT64, {batch_size}});
139+
// RUNTIME_ASSERT_OP_ARG((seeds.size() == batch_size),
140+
// "random_seed.size() should equal to batch_size, but %d vs %d",
141+
// seeds.size(),
142+
// batch_size);
143+
// copy({*random_seeds_buf, seeds});
144+
// invokeCurandBatchInitialize((curandState_t*)curandstate_buf_->data(),
145+
// batch_size,
146+
// (unsigned long long*)random_seeds_buf->data(),
147+
// stream_);
148+
// }
149+
// }
150150

151151
// 3.2. compute logits penalty
152152
if (std::any_of(
@@ -221,17 +221,32 @@ GreedyOutput ROCmDevice::sampleGreedy(const GreedyParams& params) {
221221
return GreedyOutput{};
222222
}
223223

224-
if (enable_flashinfer) {
224+
if (!disable_dprs) {
225225
const auto batch_size = params.logits.shape()[0];
226226
auto& top_k = params.top_k;
227227
auto& top_p = params.top_p;
228228

229-
auto logits_ref = params.logits.slice(0, params.logits.shape()[0]);
230-
auto probs = softmax({logits_ref, std::nullopt, std::nullopt, 1.0f, DataType::TYPE_INVALID, std::nullopt});
231-
auto samples = transposed_tokens->view(transposed_tokens->shape()[0] - 1, 1);
232-
torch::TensorOptions options =
233-
torch::TensorOptions(dataTypeToTorchType(probs->type())).device(torch::Device(torch::kCUDA));
234-
bool deterministic = false;
229+
auto logits_ref = params.logits.slice(0, params.logits.shape()[0]);
230+
auto probs = softmax({logits_ref, std::nullopt, std::nullopt, 1.0f, DataType::TYPE_INVALID, std::nullopt});
231+
auto samples = transposed_tokens->view(transposed_tokens->shape()[0] - 1, 1);
232+
233+
bool deterministic = true;
234+
std::vector<uint64_t> seed_v;
235+
std::vector<uint64_t> offset_v;
236+
for (int i = 0; i < batch_size; i++) {
237+
if (params.generator[i].defined()) {
238+
auto [sd, ofst] = get_seed_and_offset(batch_size * 32, params.generator[i]);
239+
seed_v.push_back(sd);
240+
offset_v.push_back(ofst);
241+
} else {
242+
seed_v.push_back(0);
243+
offset_v.push_back(0);
244+
}
245+
}
246+
auto seed = torch::from_blob(seed_v.data(), {static_cast<long>(batch_size)}, torch::kUInt64).to(torch::kCUDA);
247+
auto offset =
248+
torch::from_blob(offset_v.data(), {static_cast<long>(batch_size)}, torch::kUInt64).to(torch::kCUDA);
249+
235250
bool need_output_all_probs = params.output_all_probs.has_value();
236251
torch::Tensor probs_t = Buffer2torchTensor(probs, false);
237252
torch::Tensor samples_t = Buffer2torchTensor(samples, false).flatten();
@@ -252,7 +267,15 @@ GreedyOutput ROCmDevice::sampleGreedy(const GreedyParams& params) {
252267
}
253268
} else if (std::all_of(
254269
top_k.data<uint32_t>(), top_k.data<uint32_t>() + batch_size, [&](auto t) { return t <= 0; })) {
255-
top_p_sampling_from_probs(probs_t, samples_t, std::nullopt, top_p_t, 1.0, deterministic, 0, 0, reinterpret_cast<uintptr_t>(stream_));
270+
top_p_sampling_from_probs(probs_t,
271+
samples_t,
272+
std::nullopt,
273+
top_p_t,
274+
1.0,
275+
deterministic,
276+
seed,
277+
offset,
278+
reinterpret_cast<uintptr_t>(stream_));
256279
if (need_output_all_probs) {
257280
top_p_renorm_probs(probs_t, output_all_probs_t, top_p_t, 1.0, reinterpret_cast<uintptr_t>(stream_));
258281
}
@@ -263,8 +286,15 @@ GreedyOutput ROCmDevice::sampleGreedy(const GreedyParams& params) {
263286
top_k.data<uint32_t>() + batch_size,
264287
top_k.data<uint32_t>(),
265288
[&](auto t) { return t <= 0 ? 1 << 30 : t; });
266-
top_k_sampling_from_probs(
267-
probs_t, samples_t, std::nullopt, top_k_t, 0, deterministic, 0, 0, reinterpret_cast<uintptr_t>(stream_));
289+
top_k_sampling_from_probs(probs_t,
290+
samples_t,
291+
std::nullopt,
292+
top_k_t,
293+
0,
294+
deterministic,
295+
seed,
296+
offset,
297+
reinterpret_cast<uintptr_t>(stream_));
268298
if (need_output_all_probs) {
269299
top_k_renorm_probs(probs_t, output_all_probs_t, top_k_t, 0, reinterpret_cast<uintptr_t>(stream_));
270300
}
@@ -281,8 +311,8 @@ GreedyOutput ROCmDevice::sampleGreedy(const GreedyParams& params) {
281311
top_p_t,
282312
1.0,
283313
deterministic,
284-
0,
285-
0,
314+
seed,
315+
offset,
286316
reinterpret_cast<uintptr_t>(stream_));
287317
if (need_output_all_probs) {
288318
torch::Tensor temp_t = torch::zeros_like(output_all_probs_t);

rtp_llm/cpp/kernels/rocm/sampling/api.cc

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,29 @@
1414
* See the License for the specific language governing permissions and
1515
* limitations under the License.
1616
*/
17+
#include <ATen/hip/HIPGeneratorImpl.h>
18+
1719
#include "sampling.h"
1820
#include "utils.h"
1921
#include "kernel.cuh"
2022

2123
namespace rtp_llm {
2224

25+
std::tuple<uint64_t, uint64_t> get_seed_and_offset(int increment_size, std::optional<at::Generator> generator) {
26+
uint64_t philox_seed, philox_offset;
27+
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
28+
generator, at::cuda::detail::getDefaultCUDAGenerator());
29+
std::lock_guard<std::mutex> lock(gen->mutex_);
30+
at::PhiloxCudaState rng_engine_inputs = gen->philox_cuda_state(increment_size);
31+
philox_seed = rng_engine_inputs.seed_.val;
32+
philox_offset = rng_engine_inputs.offset_.val;
33+
return std::make_tuple(philox_seed, philox_offset);
34+
}
35+
2336
void top_p_sampling_from_probs(torch::Tensor probs, torch::Tensor output,
2437
std::optional<torch::Tensor> maybe_indices,
2538
std::optional<torch::Tensor> maybe_top_p_arr, double top_p_val,
26-
bool deterministic, uint64_t philox_seed, uint64_t philox_offset, uintptr_t stream) {
39+
bool deterministic, torch::Tensor philox_seed, torch::Tensor philox_offset, uintptr_t stream) {
2740
CHECK_INPUT(probs);
2841
CHECK_DIM(2, probs); // probs: (batch_size, vocab_size)
2942
unsigned int batch_size = output.sizes()[0];
@@ -35,14 +48,14 @@ void top_p_sampling_from_probs(torch::Tensor probs, torch::Tensor output,
3548
static_cast<float*>(probs.data_ptr()), static_cast<int*>(output.data_ptr()),
3649
maybe_indices.has_value() ? static_cast<int*>(maybe_indices->data_ptr()) : nullptr,
3750
has_top_p_arr ? static_cast<float*>(maybe_top_p_arr->data_ptr()) : nullptr, batch_size,
38-
top_p_val, vocab_size, deterministic, philox_seed, philox_offset, reinterpret_cast<hipStream_t>(stream));
51+
top_p_val, vocab_size, deterministic, static_cast<uint64_t*>(philox_seed.data_ptr()), static_cast<uint64_t*>(philox_offset.data_ptr()), reinterpret_cast<hipStream_t>(stream));
3952
TORCH_CHECK(status == hipSuccess, "TopPSamplingFromProbs failed with error code " + std::string(hipGetErrorString(status)));
4053
}
4154

4255
void top_k_sampling_from_probs(torch::Tensor probs, torch::Tensor output,
4356
std::optional<torch::Tensor> maybe_indices,
4457
std::optional<torch::Tensor> maybe_top_k_arr, int64_t top_k_val,
45-
bool deterministic, uint64_t philox_seed, uint64_t philox_offset, uintptr_t stream) {
58+
bool deterministic, torch::Tensor philox_seed, torch::Tensor philox_offset, uintptr_t stream) {
4659
CHECK_INPUT(probs);
4760
CHECK_INPUT(output);
4861
CHECK_DEVICE(output, probs);
@@ -57,16 +70,16 @@ void top_k_sampling_from_probs(torch::Tensor probs, torch::Tensor output,
5770
static_cast<float*>(probs.data_ptr()), static_cast<int*>(output.data_ptr()),
5871
maybe_indices.has_value() ? static_cast<int*>(maybe_indices->data_ptr()) : nullptr,
5972
has_top_k_arr ? static_cast<float*>(maybe_top_k_arr->data_ptr()) : nullptr, batch_size,
60-
top_k_val, vocab_size, deterministic, philox_seed, philox_offset, reinterpret_cast<hipStream_t>(stream));
73+
top_k_val, vocab_size, deterministic, static_cast<uint64_t*>(philox_seed.data_ptr()), static_cast<uint64_t*>(philox_offset.data_ptr()), reinterpret_cast<hipStream_t>(stream));
6174
TORCH_CHECK(status == hipSuccess, "TopKSamplingFromProbs failed with error code " + std::string(hipGetErrorString(status)));
6275
}
6376

6477
void top_k_top_p_sampling_from_probs(torch::Tensor probs, torch::Tensor output,
6578
std::optional<torch::Tensor> maybe_indices,
6679
std::optional<torch::Tensor> maybe_top_k_arr, double top_k_val,
6780
std::optional<torch::Tensor> maybe_top_p_arr, double top_p_val,
68-
bool deterministic, uint64_t philox_seed,
69-
uint64_t philox_offset, uintptr_t stream) {
81+
bool deterministic, torch::Tensor philox_seed,
82+
torch::Tensor philox_offset, uintptr_t stream) {
7083
CHECK_INPUT(probs);
7184
CHECK_INPUT(output);
7285
CHECK_DEVICE(output, probs);
@@ -84,7 +97,7 @@ void top_k_top_p_sampling_from_probs(torch::Tensor probs, torch::Tensor output,
8497
has_top_p_arr ? static_cast<float*>(maybe_top_p_arr->data_ptr()) : nullptr,
8598
static_cast<int*>(output.data_ptr()),
8699
maybe_indices.has_value() ? static_cast<int*>(maybe_indices->data_ptr()) : nullptr,
87-
batch_size, top_k_val, top_p_val, vocab_size, deterministic, philox_seed, philox_offset,
100+
batch_size, top_k_val, top_p_val, vocab_size, deterministic, static_cast<uint64_t*>(philox_seed.data_ptr()), static_cast<uint64_t*>(philox_offset.data_ptr()),
88101
reinterpret_cast<hipStream_t>(stream));
89102
TORCH_CHECK(status == hipSuccess, "TopKTopPSamplingFromProb failed with error code " + std::string(hipGetErrorString(status)));
90103
}

rtp_llm/cpp/kernels/rocm/sampling/bind.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ namespace rtp_llm {
88

99
PYBIND11_MODULE(bind, m) {
1010
m.doc() = "sampling c++ api for test";
11+
m.def("get_seed_and_offset", &get_seed_and_offset, py::arg(), py::arg("generator") = std::nullopt, "get_seed_and_offset");
1112
m.def("top_p_renorm_probs", &top_p_renorm_probs, py::arg(), py::arg(), py::arg(), py::arg(), py::arg("stream") = 0, "top_p_renorm_probs");
1213
m.def("top_k_renorm_probs", &top_k_renorm_probs, py::arg(), py::arg(), py::arg(), py::arg(), py::arg("stream") = 0, "top_k_renorm_probs");
1314
m.def("top_p_sampling_from_probs", &top_p_sampling_from_probs, py::arg(), py::arg(), py::arg(), py::arg(), py::arg(), py::arg(), py::arg(), py::arg(), py::arg("stream") = 0, "top_p_sampling_from_probs");

0 commit comments

Comments
 (0)