@@ -20,11 +20,11 @@ using SamplerT = float;
 // topk should have higher priority than topp.

 GreedyOutput ROCmDevice::sampleGreedy(const GreedyParams& params) {
-    bool enable_flashinfer = init_params_.sampler_config.enable_flashinfer_sample_kernel;
-    const auto& logits = params.logits;
-    const auto batch_size = logits.shape()[0];
-    const auto vocab_size_padded = logits.shape()[1];
-    const auto step = params.step;
+    bool disable_dprs = std::getenv("DISABLE_ROCM_DPRS") && std::string(std::getenv("DISABLE_ROCM_DPRS")) == "1";
+    const auto& logits = params.logits;
+    const auto batch_size = logits.shape()[0];
+    const auto vocab_size_padded = logits.shape()[1];
+    const auto step = params.step;
     RUNTIME_ASSERT_OP_ARG(batch_size == params.token_ids.shape()[0],
                           "logits.shape[0] should equal to token_ids.shape[0], but %d vs %d",
                           batch_size,
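The gate introduced in this hunk replaces the enable_flashinfer_sample_kernel config flag with an environment check: the DPRS sampling path stays on by default and is skipped only when the process environment carries DISABLE_ROCM_DPRS=1. Below is a minimal standalone sketch of that check, assuming nothing beyond the C++ standard library; the helper name dprs_disabled is illustrative and not part of the codebase.

#include <cstdlib>
#include <string>

// Returns true only when DISABLE_ROCM_DPRS is set to exactly "1"; an unset or
// differently-valued variable keeps the DPRS path enabled, matching the diff.
bool dprs_disabled() {
    const char* v = std::getenv("DISABLE_ROCM_DPRS");
    return v != nullptr && std::string(v) == "1";
}

Reading the variable on every call mirrors the change above; caching the result would be a separate design choice outside this diff.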
@@ -40,7 +40,7 @@ GreedyOutput ROCmDevice::sampleGreedy(const GreedyParams& params) {
     auto& top_k = params.top_k;
     auto& top_p = params.top_p;
     auto& temperature = params.temperature;
-    auto& random_seed = params.random_seed;
+    // auto& random_seed = params.random_seed;
     ROCM_CHECK_VALUE(top_k.size() == batch_size, "top_k.size() != batch_size");
     ROCM_CHECK_VALUE(top_p.size() == batch_size, "top_p.size() != batch_size");
     ROCM_CHECK_VALUE(temperature.size() == batch_size, "temperature.size() != batch_size");
@@ -129,24 +129,24 @@ GreedyOutput ROCmDevice::sampleGreedy(const GreedyParams& params) {
     // 3. prepare common inputs

     // 3.1. setup random seeds
-    if (random_seed) {
-        auto& seeds = random_seed.value().get();
-        if (seeds.size() == 1) {
-            invokeCurandInitialize(
-                (curandState_t*)curandstate_buf_->data(), batch_size, seeds.data<uint64_t>()[0], stream_);
-        } else {
-            auto random_seeds_buf = allocateBuffer({DataType::TYPE_UINT64, {batch_size}});
-            RUNTIME_ASSERT_OP_ARG((seeds.size() == batch_size),
-                                  "random_seed.size() should equal to batch_size, but %d vs %d",
-                                  seeds.size(),
-                                  batch_size);
-            copy({*random_seeds_buf, seeds});
-            invokeCurandBatchInitialize((curandState_t*)curandstate_buf_->data(),
-                                        batch_size,
-                                        (unsigned long long*)random_seeds_buf->data(),
-                                        stream_);
-        }
-    }
+    // if (random_seed) {
+    //     auto& seeds = random_seed.value().get();
+    //     if (seeds.size() == 1) {
+    //         invokeCurandInitialize(
+    //             (curandState_t*)curandstate_buf_->data(), batch_size, seeds.data<uint64_t>()[0], stream_);
+    //     } else {
+    //         auto random_seeds_buf = allocateBuffer({DataType::TYPE_UINT64, {batch_size}});
+    //         RUNTIME_ASSERT_OP_ARG((seeds.size() == batch_size),
+    //                               "random_seed.size() should equal to batch_size, but %d vs %d",
+    //                               seeds.size(),
+    //                               batch_size);
+    //         copy({*random_seeds_buf, seeds});
+    //         invokeCurandBatchInitialize((curandState_t*)curandstate_buf_->data(),
+    //                                     batch_size,
+    //                                     (unsigned long long*)random_seeds_buf->data(),
+    //                                     stream_);
+    //     }
+    // }

     // 3.2. compute logits penalty
     if (std::any_of(
@@ -221,17 +221,32 @@ GreedyOutput ROCmDevice::sampleGreedy(const GreedyParams& params) {
         return GreedyOutput{};
     }

-    if (enable_flashinfer) {
+    if (!disable_dprs) {
         const auto batch_size = params.logits.shape()[0];
         auto& top_k = params.top_k;
         auto& top_p = params.top_p;

-        auto logits_ref = params.logits.slice(0, params.logits.shape()[0]);
-        auto probs = softmax({logits_ref, std::nullopt, std::nullopt, 1.0f, DataType::TYPE_INVALID, std::nullopt});
-        auto samples = transposed_tokens->view(transposed_tokens->shape()[0] - 1, 1);
-        torch::TensorOptions options =
-            torch::TensorOptions(dataTypeToTorchType(probs->type())).device(torch::Device(torch::kCUDA));
-        bool deterministic = false;
+        auto logits_ref = params.logits.slice(0, params.logits.shape()[0]);
+        auto probs = softmax({logits_ref, std::nullopt, std::nullopt, 1.0f, DataType::TYPE_INVALID, std::nullopt});
+        auto samples = transposed_tokens->view(transposed_tokens->shape()[0] - 1, 1);
+
+        bool deterministic = true;
+        std::vector<uint64_t> seed_v;
+        std::vector<uint64_t> offset_v;
+        for (int i = 0; i < batch_size; i++) {
+            if (params.generator[i].defined()) {
+                auto [sd, ofst] = get_seed_and_offset(batch_size * 32, params.generator[i]);
+                seed_v.push_back(sd);
+                offset_v.push_back(ofst);
+            } else {
+                seed_v.push_back(0);
+                offset_v.push_back(0);
+            }
+        }
+        auto seed = torch::from_blob(seed_v.data(), {static_cast<long>(batch_size)}, torch::kUInt64).to(torch::kCUDA);
+        auto offset =
+            torch::from_blob(offset_v.data(), {static_cast<long>(batch_size)}, torch::kUInt64).to(torch::kCUDA);
+
         bool need_output_all_probs = params.output_all_probs.has_value();
         torch::Tensor probs_t = Buffer2torchTensor(probs, false);
         torch::Tensor samples_t = Buffer2torchTensor(samples, false).flatten();
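The per-request RNG plumbing added in this hunk collects one (seed, offset) pair per batch entry on the host, then uploads both vectors to the device with torch::from_blob(...).to(torch::kCUDA) so the sampling kernels can consume them. Below is a minimal sketch of that upload pattern, assuming only libtorch with a CUDA/ROCm backend; the helper name upload_u64 is illustrative and not part of this codebase.

#include <torch/torch.h>

#include <cstdint>
#include <iostream>
#include <vector>

// Wrap a host vector of uint64 values and copy it onto the current GPU.
torch::Tensor upload_u64(std::vector<uint64_t>& host_vals) {
    // from_blob only wraps the host memory (no copy, no ownership), so the
    // vector must stay alive until .to() has materialized the device tensor;
    // the default, synchronous .to() call satisfies that here.
    auto cpu_view = torch::from_blob(host_vals.data(),
                                     {static_cast<long>(host_vals.size())},
                                     torch::kUInt64);
    return cpu_view.to(torch::kCUDA);
}

int main() {
    std::vector<uint64_t> seeds = {1234, 1234, 0, 42};  // one entry per request
    torch::Tensor seed = upload_u64(seeds);
    std::cout << seed.device() << " " << seed.sizes() << std::endl;  // device tensor, shape [4]
    return 0;
}

In the hunk above, requests without a defined generator simply contribute 0 for both seed and offset.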
@@ -252,7 +267,15 @@ GreedyOutput ROCmDevice::sampleGreedy(const GreedyParams& params) {
             }
         } else if (std::all_of(
                        top_k.data<uint32_t>(), top_k.data<uint32_t>() + batch_size, [&](auto t) { return t <= 0; })) {
-            top_p_sampling_from_probs(probs_t, samples_t, std::nullopt, top_p_t, 1.0, deterministic, 0, 0, reinterpret_cast<uintptr_t>(stream_));
+            top_p_sampling_from_probs(probs_t,
+                                      samples_t,
+                                      std::nullopt,
+                                      top_p_t,
+                                      1.0,
+                                      deterministic,
+                                      seed,
+                                      offset,
+                                      reinterpret_cast<uintptr_t>(stream_));
             if (need_output_all_probs) {
                 top_p_renorm_probs(probs_t, output_all_probs_t, top_p_t, 1.0, reinterpret_cast<uintptr_t>(stream_));
             }
@@ -263,8 +286,15 @@ GreedyOutput ROCmDevice::sampleGreedy(const GreedyParams& params) {
                            top_k.data<uint32_t>() + batch_size,
                            top_k.data<uint32_t>(),
                            [&](auto t) { return t <= 0 ? 1 << 30 : t; });
-            top_k_sampling_from_probs(
-                probs_t, samples_t, std::nullopt, top_k_t, 0, deterministic, 0, 0, reinterpret_cast<uintptr_t>(stream_));
+            top_k_sampling_from_probs(probs_t,
+                                      samples_t,
+                                      std::nullopt,
+                                      top_k_t,
+                                      0,
+                                      deterministic,
+                                      seed,
+                                      offset,
+                                      reinterpret_cast<uintptr_t>(stream_));
             if (need_output_all_probs) {
                 top_k_renorm_probs(probs_t, output_all_probs_t, top_k_t, 0, reinterpret_cast<uintptr_t>(stream_));
             }
@@ -281,8 +311,8 @@ GreedyOutput ROCmDevice::sampleGreedy(const GreedyParams& params) {
                                             top_p_t,
                                             1.0,
                                             deterministic,
-                                            0,
-                                            0,
+                                            seed,
+                                            offset,
                                             reinterpret_cast<uintptr_t>(stream_));
             if (need_output_all_probs) {
                 torch::Tensor temp_t = torch::zeros_like(output_all_probs_t);