Commit 05576a1

[GPU] Implement padding extension for SDPA shape canonicalization (#32402)
### Details:
- Update the SDPA shape canonicalization process by adding support for padding extension when transforming tensor shapes to the target 4D rank
- Extend the padding information when the input rank differs from the target rank (4D by default)
- Insert a padding dimension at the `num_heads_dim` position (index 1)

### Tickets:
- [CVS-173728](https://jira.devtools.intel.com/browse/CVS-173728)
1 parent 0d4e61f commit 05576a1

File tree

1 file changed: +21, -2 lines


src/plugins/intel_gpu/src/graph/impls/ocl_v2/sdpa/sdpa_base.cpp

Lines changed: 21 additions & 2 deletions

```diff
@@ -348,6 +348,21 @@ kernel_impl_params SDPABase::static_canonicalize_shapes(const kernel_impl_params
         return pshape;
     };
 
+    auto extend_padding_to_rank_in_num_heads_dim = [](const padding& input_padding, size_t input_rank, size_t target_rank = 4) {
+        if (input_rank == target_rank) {
+            return input_padding;
+        }
+
+        std::vector<ov::Dimension::value_type> pad_low(input_padding._lower_size.begin(), input_padding._lower_size.begin() + input_rank);
+        std::vector<ov::Dimension::value_type> pad_up(input_padding._upper_size.begin(), input_padding._upper_size.begin() + input_rank);
+
+        const size_t num_heads_dim = 1;
+        pad_low.insert(pad_low.begin() + num_heads_dim, 0);
+        pad_up.insert(pad_up.begin() + num_heads_dim, 0);
+
+        return padding(pad_low, pad_up, input_padding._dynamic_dims_mask);
+    };
+
     const auto attn_mask_idx = 3;
     if (updated_impl_params.input_layouts.size() > attn_mask_idx) {
         const auto attn_mask_shape = updated_impl_params.input_layouts[attn_mask_idx].get_partial_shape();
@@ -356,8 +371,12 @@ kernel_impl_params SDPABase::static_canonicalize_shapes(const kernel_impl_params
 
     // For scale of 1D tensor or attention_mask of empty shape, use extend_shape_to_rank_from_end as before
     for (auto& input_layout : updated_impl_params.input_layouts) {
-        input_layout.set_partial_shape(input_layout.get_partial_shape().size() <= 1 ? extend_shape_to_rank_from_end(input_layout.get_partial_shape())
-                                                                                    : extend_pshape_to_rank_in_num_heads_dim(input_layout.get_partial_shape()));
+        size_t input_rank = input_layout.get_partial_shape().size();
+        input_layout.set_partial_shape(input_rank <= 1 ? extend_shape_to_rank_from_end(input_layout.get_partial_shape())
+                                                       : extend_pshape_to_rank_in_num_heads_dim(input_layout.get_partial_shape()));
+        if (input_layout.data_padding) {
+            input_layout.data_padding = extend_padding_to_rank_in_num_heads_dim(input_layout.data_padding, input_rank);
+        }
     }
 
     auto& output_layout = updated_impl_params.output_layouts[0];
```
