
Commit e2752ac

test:profiling

1 parent 15df633 commit e2752ac

4 files changed, +57 -35 lines changed

rtp_llm/cpp/core/torch_utils/BufferTorchUtils.h

Lines changed: 1 addition & 0 deletions
```diff
@@ -63,6 +63,7 @@ inline std::vector<int64_t> bufferShapeToTorchShape(const Buffer& buffer) {
     F(TYPE_INT16, torch::kShort)   \
     F(TYPE_INT32, torch::kInt)     \
     F(TYPE_INT64, torch::kLong)    \
+    F(TYPE_UINT64, torch::kUInt64) \
     F(TYPE_FP16, torch::kHalf)     \
     F(TYPE_FP32, torch::kFloat)    \
     F(TYPE_FP64, torch::kDouble)   \
```
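This one-line mapping is what the sampler change below depends on: with TYPE_UINT64 registered, the new philox seed/offset buffers can pass through Buffer2torchTensor. On the Python side the matching dtype is torch.uint64, which the updated test already uses; a minimal check (recent PyTorch; unsigned 64-bit coverage is limited to a subset of ops, but creation and copies suffice here):

```python
import torch

# torch::kUInt64 on the C++ side corresponds to torch.uint64 in Python.
seed = torch.zeros(4, dtype=torch.uint64)
assert seed.dtype == torch.uint64
```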

rtp_llm/cpp/devices/rocm_impl/ROCmSampleOp.cc

Lines changed: 26 additions & 17 deletions
```diff
@@ -12,6 +12,14 @@ namespace rtp_llm {
 
 using SamplerT = float;
 
+void _saveTorchDataTofile(const torch::Tensor& tensor, const std::string& fileName) {
+    auto tensor_cpu = tensor.contiguous().cpu();
+    auto pickled    = torch::pickle_save(tensor_cpu);
+    std::ofstream fout(fileName, std::ios::out | std::ios::binary);
+    fout.write(pickled.data(), pickled.size());
+    fout.close();
+}
+
 // batch sampling explained:
 // topk = [4, 0, 4]. topp = [0.0, 0.5, 0.5]
 // then topk_decode handles [4, x, 4 + 0.5]
```
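Bytes written with torch::pickle_save load back with torch.load on the Python side, which is how the updated test.py consumes these dumps; a minimal readback sketch (the path is illustrative, wherever XBJ_DUMP_PROBS pointed when the dump was written):

```python
import torch

# Read back a tensor written by _saveTorchDataTofile; the path is illustrative.
probs = torch.load("/tmp/probs/probs1.pt", weights_only=True)
print(probs.shape, probs.dtype)
```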
```diff
@@ -226,26 +234,22 @@ GreedyOutput ROCmDevice::sampleGreedy(const GreedyParams& params) {
     auto& top_k = params.top_k;
     auto& top_p = params.top_p;
 
-    auto logits_ref = params.logits.slice(0, params.logits.shape()[0]);
-    auto probs = softmax({logits_ref, std::nullopt, std::nullopt, 1.0f, DataType::TYPE_INVALID, std::nullopt});
-    auto samples = transposed_tokens->view(transposed_tokens->shape()[0] - 1, 1);
-
     bool deterministic = true;
-    std::vector<uint64_t> seed_v;
-    std::vector<uint64_t> offset_v;
+    auto seed_h   = allocateBuffer({DataType::TYPE_UINT64, {batch_size}, AllocationType::HOST});
+    auto offset_h = allocateBuffer({DataType::TYPE_UINT64, {batch_size}, AllocationType::HOST});
     for (int i = 0; i < batch_size; i++) {
-        if (params.generator[i].defined()) {
-            auto [sd, ofst] = get_seed_and_offset(batch_size * 32, params.generator[i]);
-            seed_v.push_back(sd);
-            offset_v.push_back(ofst);
-        } else {
-            seed_v.push_back(0);
-            offset_v.push_back(0);
-        }
+        std::tie(seed_h->data<uint64_t>()[i], offset_h->data<uint64_t>()[i]) = params.generator[i].defined() ?
+            get_seed_and_offset(batch_size * 32, params.generator[i]) :
+            std::make_pair(0ULL, 0ULL);
     }
-    auto seed = torch::from_blob(seed_v.data(), {static_cast<long>(batch_size)}, torch::kUInt64).to(torch::kCUDA);
-    auto offset =
-        torch::from_blob(offset_v.data(), {static_cast<long>(batch_size)}, torch::kUInt64).to(torch::kCUDA);
+    auto seed_d   = clone({*seed_h, AllocationType::DEVICE});
+    auto offset_d = clone({*offset_h, AllocationType::DEVICE});
+    auto seed     = Buffer2torchTensor(seed_d, false);
+    auto offset   = Buffer2torchTensor(offset_d, false);
+
+    auto logits_ref = params.logits.slice(0, params.logits.shape()[0]);
+    auto probs = softmax({logits_ref, std::nullopt, std::nullopt, 1.0f, DataType::TYPE_INVALID, std::nullopt});
+    auto samples = transposed_tokens->view(transposed_tokens->shape()[0] - 1, 1);
 
     bool need_output_all_probs = params.output_all_probs.has_value();
     torch::Tensor probs_t = Buffer2torchTensor(probs, false);
```
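The rewrite fills host buffers in place and clones each to the device in a single copy, replacing the std::vector + torch::from_blob + .to(torch::kCUDA) round trip. A rough Python sketch of the per-row fallback logic, with seed_and_offset as a hypothetical stand-in for the C++ get_seed_and_offset:

```python
def gather_philox_state(generators, batch_size):
    # One (seed, offset) pair per batch row; rows without a defined
    # generator fall back to (0, 0), matching the C++ ternary.
    pairs = [g.seed_and_offset(batch_size * 32) if g is not None else (0, 0)
             for g in generators]
    seeds, offsets = zip(*pairs)
    return list(seeds), list(offsets)
```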
```diff
@@ -267,6 +271,11 @@ GreedyOutput ROCmDevice::sampleGreedy(const GreedyParams& params) {
         }
     } else if (std::all_of(
                    top_k.data<uint32_t>(), top_k.data<uint32_t>() + batch_size, [&](auto t) { return t <= 0; })) {
+        static int fwd = 0;
+        ++fwd;
+        if (std::getenv("XBJ_DUMP_PROBS")) {
+            _saveTorchDataTofile(probs_t, std::string(std::getenv("XBJ_DUMP_PROBS")) + "/probs" + std::to_string(fwd) + ".pt");
+        }
         top_p_sampling_from_probs(probs_t,
                                   samples_t,
                                   std::nullopt,
```
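With XBJ_DUMP_PROBS set to a writable directory, each forward pass that reaches this all-top-p branch writes one probs<N>.pt file, numbered by the static fwd counter; these are the files the updated test.py sweeps in its __main__ block.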

rtp_llm/cpp/kernels/rocm/sampling/kernel.cuh

Lines changed: 6 additions & 13 deletions
```diff
@@ -58,14 +58,7 @@ using namespace hipcub;
         __VA_ARGS__                                                 \
     }
 
-#define DISPATCH_SOFTMAX_CACHE_INPUT(cache_input, CACHE_INPUT, ...) \
-    if (cache_input) {                                              \
-        constexpr bool CACHE_INPUT = true;                          \
-        __VA_ARGS__                                                 \
-    } else {                                                        \
-        constexpr bool CACHE_INPUT = false;                         \
-        __VA_ARGS__                                                 \
-    }
+#define VEC_BYTES 64
 
 constexpr BlockScanAlgorithm SCAN_ALGO = BLOCK_SCAN_WARP_SCANS;
 constexpr BlockReduceAlgorithm REDUCE_ALGO = BLOCK_REDUCE_WARP_REDUCTIONS;
```
```diff
@@ -650,7 +643,7 @@ hipError_t TopKSamplingFromProb(T* probs, IdType* output, IdType* indices, T* to
                                 uint32_t batch_size, uint32_t top_k_val, uint32_t d,
                                 bool deterministic, uint64_t* philox_seed, uint64_t* philox_offset,
                                 hipStream_t stream = 0) {
-    const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
+    const uint32_t vec_size = std::gcd(VEC_BYTES / sizeof(T), d);
 
     auto compute_capacity = GetCudaComputeCapability();
     DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, {
```
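The effect of the constant: vec_size = std::gcd(budget / sizeof(T), d) picks the widest per-thread vector that still divides the row length d, so raising the budget from the hard-coded 16 bytes to VEC_BYTES = 64 allows loads up to 4x wider whenever d permits. A quick worked check (the vocab size is illustrative):

```python
from math import gcd

VEC_BYTES = 64
d = 152064  # illustrative vocab size, divisible by 32
for name, size in [("fp32", 4), ("fp16", 2)]:
    old = gcd(16 // size, d)         # previous 16-byte budget
    new = gcd(VEC_BYTES // size, d)  # new 64-byte budget
    print(f"{name}: vec_size {old} -> {new}")
# fp32: vec_size 4 -> 16
# fp16: vec_size 8 -> 32
```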
```diff
@@ -678,7 +671,7 @@ hipError_t TopPSamplingFromProb(T* probs, IdType* output, IdType* indices, T* to
                                 uint32_t batch_size, T top_p_val, uint32_t d, bool deterministic,
                                 uint64_t* philox_seed, uint64_t* philox_offset,
                                 hipStream_t stream = 0) {
-    const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
+    const uint32_t vec_size = std::gcd(VEC_BYTES / sizeof(T), d);
 
     auto compute_capacity = GetCudaComputeCapability();
     DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, {
@@ -708,7 +701,7 @@ hipError_t TopKTopPSamplingFromProb(T* probs, IdType* top_k_arr, T* top_p_arr, I
                                     T top_p_val, uint32_t d, bool deterministic,
                                     uint64_t* philox_seed, uint64_t* philox_offset,
                                     hipStream_t stream = 0) {
-    const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
+    const uint32_t vec_size = std::gcd(VEC_BYTES / sizeof(T), d);
 
     auto compute_capacity = GetCudaComputeCapability();
     DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, {
@@ -1053,7 +1046,7 @@ template <typename DType>
 hipError_t TopPRenormProb(DType* probs, DType* renormed_prob, float* top_p_arr,
                           uint32_t batch_size, float top_p_val, uint32_t d,
                           hipStream_t stream = 0) {
-    const uint32_t vec_size = std::gcd(16 / sizeof(DType), d);
+    const uint32_t vec_size = std::gcd(VEC_BYTES / sizeof(DType), d);
 
     auto compute_capacity = GetCudaComputeCapability();
     DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, {
@@ -1075,7 +1068,7 @@ template <typename DType, typename IdType>
 hipError_t TopKRenormProb(DType* probs, DType* renormed_prob, IdType* top_k_arr,
                           uint32_t batch_size, uint32_t top_k_val, uint32_t d,
                           hipStream_t stream = 0) {
-    const uint32_t vec_size = std::gcd(16 / sizeof(DType), d);
+    const uint32_t vec_size = std::gcd(VEC_BYTES / sizeof(DType), d);
 
     auto compute_capacity = GetCudaComputeCapability();
     DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, {
```

rtp_llm/cpp/kernels/rocm/sampling/test.py

Lines changed: 24 additions & 5 deletions
```diff
@@ -170,6 +170,8 @@ def test_top_p_sampling(batch_size, vocab_size, p):
     realdata = Path(realdata)
     assert realdata.is_file()
     normalized_prob = torch.load(realdata, weights_only=True).to("cuda:0")
+    batch_size = normalized_prob.shape[0]
+    vocab_size = normalized_prob.shape[1]
 
     info["prob"] = normalized_prob.cpu().numpy().tolist()
     sorted_prob, indices = torch.sort(normalized_prob, descending=False)
```
```diff
@@ -183,10 +185,14 @@ def test_top_p_sampling(batch_size, vocab_size, p):
         file = Path(file)
         with file.open("w") as f:
             json.dump(info, f, ensure_ascii=False, indent=4)
-    num_trials = 1000
+    num_trials = 10
     info["out"] = []
+    samples = torch.empty(batch_size, dtype=torch.int32, device="cuda:0")
+
+    start_event = torch.cuda.Event(enable_timing=True)
+    end_event = torch.cuda.Event(enable_timing=True)
+    start_event.record()
     for _ in range(num_trials):
-        samples = torch.empty(batch_size, dtype=torch.int32, device="cuda:0")
         top_p_sampling_from_probs(
             normalized_prob,
             samples,
```
```diff
@@ -197,9 +203,15 @@ def test_top_p_sampling(batch_size, vocab_size, p):
             torch.zeros(batch_size, dtype=torch.uint64, device="cuda:0"),
             torch.zeros(batch_size, dtype=torch.uint64, device="cuda:0"),
         )
-        assert torch.all(samples < vocab_size) and torch.all(samples >= 0)
-        assert torch.all(mask[torch.arange(batch_size), samples] == 1)
-        info["out"].append(samples.cpu().numpy().tolist())
+        # assert torch.all(samples < vocab_size) and torch.all(samples >= 0)
+        # assert torch.all(mask[torch.arange(batch_size), samples] == 1)
+        # info["out"].append(samples.cpu().numpy().tolist())
+    end_event.record()
+    torch.cuda.synchronize()
+    elapsed_time = start_event.elapsed_time(end_event) / num_trials
+    print(f"elapsed_time: {elapsed_time} ms")
+    return elapsed_time
+
     if file:
         with file.open("w") as f:
             json.dump(info, f, ensure_ascii=False, indent=4)
```
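The test now times the kernel rather than validating it: the asserts are commented out, samples is allocated once outside the loop so allocation cost is excluded, and CUDA events bracket the launches. The pattern, extracted as a reusable sketch:

```python
import torch

def time_cuda(fn, iters=10):
    # CUDA events bracket the launch loop; synchronize before reading
    # elapsed_time so every queued kernel has actually finished.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # average ms per launch
```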
```diff
@@ -376,4 +388,11 @@ def test_top_k_renorm_probs(batch_size, vocab_size, k):
 
 
 if __name__ == "__main__":
+    for i in range(1, 377):
+        os.environ["REALDATA"] = f"/home/xiebaijie.xbj/qwen-vl/probs/probs{i}.pt"
+        rt = test_top_p_sampling(1, 1, 0.95)
+        if rt > 0.5:
+            exit(1)
+    print("no found")
+    exit(1)
     exit(pytest.main([__file__, "-s", "-v"]))
```
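The __main__ block turns the file into a search driver: it sweeps the dumped probs1.pt through probs376.pt (the hard-coded path is the author's dump directory), reruns top-p sampling on each, and exits nonzero at the first input whose average time exceeds 0.5 ms; "no found" signals that no slow input turned up. The (1, 1, 0.95) arguments are placeholders, since batch_size and vocab_size are now re-derived from the loaded tensor, and either way the script exits before pytest.main runs.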
