Skip to content

Commit c7ba1c9

Browse files
spcypptmeta-codesync[bot]
authored andcommitted
Add proper check for generic_histogram_binning_calibration_by_feature_cpu (#5042)
Summary: Pull Request resolved: #5042 X-link: https://github.com/facebookresearch/FBGEMM/pull/2052 Add check to catch any index overflow Reviewed By: ionuthristodorescu Differential Revision: D85206442 fbshipit-source-id: b6e38949344e7c450263933dbe1add2bad057cef
1 parent 06285d6 commit c7ba1c9

File tree

1 file changed

+22
-3
lines changed

1 file changed

+22
-3
lines changed

fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,19 @@ void _to_dense_representation(
4242
const int64_t num_lengths,
4343
const SegmentValueType* const segment_value_data,
4444
const SegmentLengthType* const segment_lengths_data,
45-
SegmentValueType* const dense_segment_value_data) {
45+
SegmentValueType* const dense_segment_value_data,
46+
const int64_t num_segment_value) {
4647
int k = 0;
4748
for (const auto i : c10::irange(num_lengths)) {
4849
if (segment_lengths_data[i] == 1) {
4950
// Add 1 to distinguish between 0 inserted by densification vs. original
5051
// value.
52+
TORCH_CHECK(
53+
k < num_segment_value,
54+
"k should be less than num_segment_value ",
55+
num_segment_value,
56+
" but found k = ",
57+
k);
5158
dense_segment_value_data[i] = segment_value_data[k] + 1;
5259
} else {
5360
dense_segment_value_data[i] = 0;
@@ -2504,7 +2511,8 @@ std::tuple<Tensor, Tensor> histogram_binning_calibration_by_feature_cpu(
25042511
segment_lengths.numel(),
25052512
segment_value.data_ptr<segment_value_t>(),
25062513
segment_lengths.data_ptr<segment_length_t>(),
2507-
dense_segment_value.data_ptr<segment_value_t>());
2514+
dense_segment_value.data_ptr<segment_value_t>(),
2515+
segment_value.numel());
25082516
});
25092517
});
25102518

@@ -2613,6 +2621,16 @@ std::tuple<Tensor, Tensor> generic_histogram_binning_calibration_by_feature_cpu(
26132621
// dense_segment_value is used as a temporary storage.
26142622
Tensor dense_segment_value =
26152623
at::empty({logit.numel()}, segment_value.options());
2624+
2625+
// _to_dense_representation will access dense_segment_value[i] where i <=
2626+
// num_length, so num_length should be within the range of dense_segment_value
2627+
TORCH_CHECK(
2628+
segment_lengths.numel() <= dense_segment_value.numel(),
2629+
"segment_lengths numel (num_length) should be less than dense_segment_value numel of ",
2630+
dense_segment_value.numel(),
2631+
" but found num_length = ",
2632+
segment_lengths.numel());
2633+
26162634
AT_DISPATCH_INDEX_TYPES(
26172635
segment_value.scalar_type(), "to_dense_representation_cpu_wrapper", [&] {
26182636
using segment_value_t = index_t;
@@ -2623,7 +2641,8 @@ std::tuple<Tensor, Tensor> generic_histogram_binning_calibration_by_feature_cpu(
26232641
segment_lengths.numel(),
26242642
segment_value.data_ptr<segment_value_t>(),
26252643
segment_lengths.data_ptr<segment_length_t>(),
2626-
dense_segment_value.data_ptr<segment_value_t>());
2644+
dense_segment_value.data_ptr<segment_value_t>(),
2645+
segment_value.numel());
26272646
});
26282647
});
26292648

0 commit comments

Comments
 (0)