1+ // based on flashinfer 0.4.1 https://github.com/flashinfer-ai/flashinfer/tree/a88349f9f43df74d31d1d52ad5aa20c28824a790
2+ /*
3+ * Copyright (c) 2024 by FlashInfer team.
4+ *
5+ * Licensed under the Apache License, Version 2.0 (the "License");
6+ * you may not use this file except in compliance with the License.
7+ * You may obtain a copy of the License at
8+ *
9+ * http://www.apache.org/licenses/LICENSE-2.0
10+ *
11+ * Unless required by applicable law or agreed to in writing, software
12+ * distributed under the License is distributed on an "AS IS" BASIS,
13+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+ * See the License for the specific language governing permissions and
15+ * limitations under the License.
16+ */
17+ #include " sampling.h"
18+ #include " utils.h"
19+ #include " kernel.cuh"
20+
21+ namespace rtp_llm {
22+
23+ void top_p_sampling_from_probs (torch::Tensor probs, torch::Tensor output,
24+ std::optional<torch::Tensor> maybe_indices,
25+ std::optional<torch::Tensor> maybe_top_p_arr, double top_p_val,
26+ bool deterministic, uint64_t philox_seed, uint64_t philox_offset, uintptr_t stream) {
27+ CHECK_INPUT (probs);
28+ CHECK_DIM (2 , probs); // probs: (batch_size, vocab_size)
29+ unsigned int batch_size = output.sizes ()[0 ];
30+ unsigned int vocab_size = probs.sizes ()[1 ];
31+ bool has_top_p_arr = maybe_top_p_arr.has_value ();
32+
33+ hipSetDevice (probs.get_device ());
34+ hipError_t status = sampling::TopPSamplingFromProb<float , int >(
35+ static_cast <float *>(probs.data_ptr ()), static_cast <int *>(output.data_ptr ()),
36+ maybe_indices.has_value () ? static_cast <int *>(maybe_indices->data_ptr ()) : nullptr ,
37+ has_top_p_arr ? static_cast <float *>(maybe_top_p_arr->data_ptr ()) : nullptr , batch_size,
38+ top_p_val, vocab_size, deterministic, philox_seed, philox_offset, reinterpret_cast <hipStream_t>(stream));
39+ TORCH_CHECK (status == hipSuccess, " TopPSamplingFromProbs failed with error code " + std::string (hipGetErrorString (status)));
40+ }
41+
42+ void top_k_sampling_from_probs (torch::Tensor probs, torch::Tensor output,
43+ std::optional<torch::Tensor> maybe_indices,
44+ std::optional<torch::Tensor> maybe_top_k_arr, int64_t top_k_val,
45+ bool deterministic, uint64_t philox_seed, uint64_t philox_offset, uintptr_t stream) {
46+ CHECK_INPUT (probs);
47+ CHECK_INPUT (output);
48+ CHECK_DEVICE (output, probs);
49+ CHECK_DIM (2 , probs); // probs: (batch_size, vocab_size)
50+ CHECK_DIM (1 , output); // output: (batch_size)
51+ unsigned int batch_size = output.sizes ()[0 ];
52+ unsigned int vocab_size = probs.sizes ()[1 ];
53+ bool has_top_k_arr = maybe_top_k_arr.has_value ();
54+
55+ hipSetDevice (probs.get_device ());
56+ hipError_t status = sampling::TopKSamplingFromProb<float , int >(
57+ static_cast <float *>(probs.data_ptr ()), static_cast <int *>(output.data_ptr ()),
58+ maybe_indices.has_value () ? static_cast <int *>(maybe_indices->data_ptr ()) : nullptr ,
59+ has_top_k_arr ? static_cast <float *>(maybe_top_k_arr->data_ptr ()) : nullptr , batch_size,
60+ top_k_val, vocab_size, deterministic, philox_seed, philox_offset, reinterpret_cast <hipStream_t>(stream));
61+ TORCH_CHECK (status == hipSuccess, " TopKSamplingFromProbs failed with error code " + std::string (hipGetErrorString (status)));
62+ }
63+
64+ void top_k_top_p_sampling_from_probs (torch::Tensor probs, torch::Tensor output,
65+ std::optional<torch::Tensor> maybe_indices,
66+ std::optional<torch::Tensor> maybe_top_k_arr, double top_k_val,
67+ std::optional<torch::Tensor> maybe_top_p_arr, double top_p_val,
68+ bool deterministic, uint64_t philox_seed,
69+ uint64_t philox_offset, uintptr_t stream) {
70+ CHECK_INPUT (probs);
71+ CHECK_INPUT (output);
72+ CHECK_DEVICE (output, probs);
73+ CHECK_DIM (2 , probs); // probs: (batch_size, vocab_size)
74+ CHECK_DIM (1 , output); // output: (batch_size)
75+ unsigned int batch_size = output.sizes ()[0 ];
76+ unsigned int vocab_size = probs.sizes ()[1 ];
77+ bool has_top_k_arr = maybe_top_k_arr.has_value ();
78+ bool has_top_p_arr = maybe_top_p_arr.has_value ();
79+
80+ hipSetDevice (probs.get_device ());
81+ hipError_t status = sampling::TopKTopPSamplingFromProb<float , int >(
82+ static_cast <float *>(probs.data_ptr ()),
83+ has_top_k_arr ? static_cast <int *>(maybe_top_k_arr->data_ptr ()) : nullptr ,
84+ has_top_p_arr ? static_cast <float *>(maybe_top_p_arr->data_ptr ()) : nullptr ,
85+ static_cast <int *>(output.data_ptr ()),
86+ maybe_indices.has_value () ? static_cast <int *>(maybe_indices->data_ptr ()) : nullptr ,
87+ batch_size, top_k_val, top_p_val, vocab_size, deterministic, philox_seed, philox_offset,
88+ reinterpret_cast <hipStream_t>(stream));
89+ TORCH_CHECK (status == hipSuccess, " TopKTopPSamplingFromProb failed with error code " + std::string (hipGetErrorString (status)));
90+ }
91+
92+ void top_p_renorm_probs (torch::Tensor probs, torch::Tensor renorm_probs,
93+ std::optional<torch::Tensor> maybe_top_p_arr, double top_p_val, uintptr_t stream) {
94+ CHECK_INPUT (probs);
95+ CHECK_DIM (2 , probs); // probs: (batch_size, vocab_size)
96+ unsigned int batch_size = probs.sizes ()[0 ];
97+ unsigned int vocab_size = probs.sizes ()[1 ];
98+ bool has_top_p_arr = maybe_top_p_arr.has_value ();
99+
100+ hipSetDevice (probs.get_device ());
101+ hipError_t status = sampling::TopPRenormProb<float >(
102+ static_cast <float *>(probs.data_ptr ()), static_cast <float *>(renorm_probs.data_ptr ()),
103+ has_top_p_arr ? static_cast <float *>(maybe_top_p_arr->data_ptr ()) : nullptr , batch_size,
104+ top_p_val, vocab_size, reinterpret_cast <hipStream_t>(stream));
105+
106+ TORCH_CHECK (status == hipSuccess, " TopPRenormProb failed with error code " + std::string (hipGetErrorString (status)));
107+ }
108+
109+ void top_k_renorm_probs (torch::Tensor probs, torch::Tensor renorm_probs,
110+ std::optional<torch::Tensor> maybe_top_k_arr, int64_t top_k_val, uintptr_t stream) {
111+ CHECK_INPUT (probs);
112+ CHECK_DIM (2 , probs); // probs: (batch_size, vocab_size)
113+ unsigned int batch_size = probs.sizes ()[0 ];
114+ unsigned int vocab_size = probs.sizes ()[1 ];
115+ bool has_top_k_arr = maybe_top_k_arr.has_value ();
116+
117+ hipSetDevice (probs.get_device ());
118+ hipError_t status = sampling::TopKRenormProb<float >(
119+ static_cast <float *>(probs.data_ptr ()), static_cast <float *>(renorm_probs.data_ptr ()),
120+ has_top_k_arr ? static_cast <int *>(maybe_top_k_arr->data_ptr ()) : nullptr , batch_size,
121+ top_k_val, vocab_size, reinterpret_cast <hipStream_t>(stream));
122+
123+ TORCH_CHECK (status == hipSuccess, " TopKRenormProb failed with error code " + std::string (hipGetErrorString (status)));
124+ }
125+
126+ }
0 commit comments