Skip to content

Commit 702a814

Browse files
committed
fix: seed int64_t
1 parent 712085a commit 702a814

File tree

4 files changed

+6
-6
lines changed

4 files changed

+6
-6
lines changed

rtp_llm/cpp/core/torch_utils/BufferTorchUtils.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,6 @@ inline std::vector<int64_t> bufferShapeToTorchShape(const Buffer& buffer) {
6363
F(TYPE_INT16, torch::kShort) \
6464
F(TYPE_INT32, torch::kInt) \
6565
F(TYPE_INT64, torch::kLong) \
66-
F(TYPE_UINT64, torch::kUInt64) \
6766
F(TYPE_FP16, torch::kHalf) \
6867
F(TYPE_FP32, torch::kFloat) \
6968
F(TYPE_FP64, torch::kDouble) \

rtp_llm/cpp/devices/rocm_impl/ROCmSampleOp.cc

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ GreedyOutput ROCmDevice::sampleGreedy(const GreedyParams& params) {
3838
auto& top_k = params.top_k;
3939
auto& top_p = params.top_p;
4040
auto& temperature = params.temperature;
41-
// auto& random_seed = params.random_seed;
4241
ROCM_CHECK_VALUE(top_k.size() == batch_size, "top_k.size() != batch_size");
4342
ROCM_CHECK_VALUE(top_p.size() == batch_size, "top_p.size() != batch_size");
4443
ROCM_CHECK_VALUE(temperature.size() == batch_size, "temperature.size() != batch_size");
@@ -116,12 +115,14 @@ GreedyOutput ROCmDevice::sampleGreedy(const GreedyParams& params) {
116115
}
117116

118117
bool deterministic = true;
119-
auto seed_h = allocateBuffer({DataType::TYPE_UINT64, {batch_size}, AllocationType::HOST});
120-
auto offset_h = allocateBuffer({DataType::TYPE_UINT64, {batch_size}, AllocationType::HOST});
118+
auto seed_h = allocateBuffer({DataType::TYPE_INT64, {batch_size}, AllocationType::HOST});
119+
auto offset_h = allocateBuffer({DataType::TYPE_INT64, {batch_size}, AllocationType::HOST});
121120
for (int i = 0; i < batch_size; i++) {
122-
std::tie(seed_h->data<uint64_t>()[i], offset_h->data<uint64_t>()[i]) = params.generator[i].defined() ?
121+
auto [sd, ofst] = params.generator[i].defined() ?
123122
get_seed_and_offset(batch_size * 32, params.generator[i]) :
124123
std::make_pair(0ULL, 0ULL);
124+
seed_h->data<int64_t>()[i] = static_cast<int64_t>(sd);
125+
offset_h->data<int64_t>()[i] = static_cast<int64_t>(ofst);
125126
}
126127
auto seed = Buffer2torchTensor(clone({*seed_h, AllocationType::DEVICE}), false);
127128
auto offset = Buffer2torchTensor(clone({*offset_h, AllocationType::DEVICE}), false);

rtp_llm/cpp/devices/rocm_impl/test/ROCmSamplerTest.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,6 @@ TEST_F(CudaSamplerTest, testTopP) {
136136
ASSERT_NEAR(cum_log_probs_host[2], -5.02131, 1e-3);
137137
ASSERT_NEAR(cum_log_probs_host[3], -5.2682, 1e-3);
138138

139-
params.random_seed = nullopt;
140139
for (int i = 0; i < 100; i++) {
141140
device_->sampleGreedy(params);
142141
// printBuffer<int32_t>(*output_token_ids, "output_token_ids");

rtp_llm/cpp/kernels/rocm/sampling/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,4 +38,5 @@ py_test(
3838
],
3939
imports = ["."],
4040
python_version = "PY3",
41+
# tags = ["rocm"],
4142
)

0 commit comments

Comments
 (0)