@@ -66,8 +66,6 @@ class SfaCpContext:
     local_start: int
     local_end: int
     local_end_with_pad: int
-    pad_size: int
-    local_pad_size: int
     slot_mapping_cp: torch.Tensor
     actual_seq_lengths_query: torch.Tensor
     actual_seq_lengths_key: torch.Tensor
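Note for reviewers: `pad_size` and `local_pad_size` leave the dataclass because the build() changes below stop passing them; `local_pad_size` in particular stays derivable from the fields that remain. A minimal sketch of the recomputation, assuming `ctx` is an `SfaCpContext` instance (the variable name is illustrative):

    # Hedged sketch: recover the removed per-rank pad from surviving fields.
    local_pad_size = ctx.local_end_with_pad - ctx.local_end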
@@ -206,23 +204,41 @@ def build(
         sfa_cp_context = None
         if self.enable_sfa_cp:
             global_tp_size = get_tp_group().world_size
-            num_tokens = num_actual_tokens
-            num_tokens_pad = _round_up(num_actual_tokens, global_tp_size)
+            num_tokens = num_input_tokens
+            num_tokens_pad = _round_up(num_tokens, global_tp_size)
             num_tokens_per_device = num_tokens_pad // global_tp_size
-            pad_size = num_tokens_pad - num_tokens
             local_start = get_tp_group().rank_in_group * num_tokens_per_device
             local_end_with_pad = local_start + num_tokens_per_device
             local_end = min(local_end_with_pad, num_actual_tokens)
-            local_pad_size = local_end_with_pad - local_end
+
+            pad_size = num_tokens_pad - cos.shape[0]
+            assert cos.shape == sin.shape, \
+                f"cos.shape must be equal to sin.shape, got {cos.shape} and {sin.shape}"
 
             if pad_size > 0:
                 cos = nn.functional.pad(cos, (0, 0, 0, 0, 0, 0, 0, pad_size))
                 sin = nn.functional.pad(sin, (0, 0, 0, 0, 0, 0, 0, pad_size))
-                slot_mapping = nn.functional.pad(slot_mapping, (0, pad_size),
+
+            pad_size_slot = num_tokens_pad - slot_mapping.shape[0]
+            if pad_size_slot > 0:
+                slot_mapping = nn.functional.pad(slot_mapping,
+                                                 (0, pad_size_slot),
                                                  value=-1)
+            else:
+                slot_mapping = slot_mapping[:num_tokens_pad]
+
             cos = cos[local_start:local_end_with_pad]
             sin = sin[local_start:local_end_with_pad]
             slot_mapping_cp = slot_mapping[local_start:local_end_with_pad]
+            assert cos.shape[0] == num_tokens_per_device, (
+                f"cos.shape[0] must be equal to num_tokens_per_device, "
+                f"got {cos.shape[0]} and {num_tokens_per_device}")
+            assert slot_mapping_cp.shape[0] == num_tokens_per_device, (
+                f"slot_mapping_cp.shape[0] must be equal to num_tokens_per_device, "
+                f"got {slot_mapping_cp.shape[0]} and {num_tokens_per_device}")
+            assert slot_mapping.shape[0] == num_tokens_pad, (
+                f"slot_mapping.shape[0] must be equal to num_tokens_pad, "
+                f"got {slot_mapping.shape[0]} and {num_tokens_pad}")
 
             actual_seq_lengths_query = torch.empty_like(cum_query_lens)
             actual_seq_lengths_key = torch.empty_like(seq_lens)
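For context, a hedged sketch of the partition arithmetic this hunk implements; the concrete numbers are hypothetical, and `_round_up(x, n)` is assumed to round x up to the next multiple of n, as its call site above suggests:

    # Hypothetical sizes, for illustration only.
    global_tp_size = 4
    num_input_tokens = 10                  # padded up to a multiple of TP size
    num_actual_tokens = 10
    num_tokens_pad = ((num_input_tokens + global_tp_size - 1)
                      // global_tp_size) * global_tp_size      # 12
    num_tokens_per_device = num_tokens_pad // global_tp_size   # 3
    rank = 3                               # last rank owns the padded tail
    local_start = rank * num_tokens_per_device                 # 9
    local_end_with_pad = local_start + num_tokens_per_device   # 12
    local_end = min(local_end_with_pad, num_actual_tokens)     # 10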
@@ -254,8 +270,6 @@ def build(
                 local_start=local_start,
                 local_end=local_end,
                 local_end_with_pad=local_end_with_pad,
-                pad_size=pad_size,
-                local_pad_size=local_pad_size,
                 slot_mapping_cp=slot_mapping_cp,
                 actual_seq_lengths_query=actual_seq_lengths_query,
                 actual_seq_lengths_key=actual_seq_lengths_key,
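And a self-contained sketch of the new `pad_size_slot` branch, which normalizes `slot_mapping` to exactly `num_tokens_pad` entries; the shapes here are made up, and the `-1` fill value presumably marks pad slots that the cache-write path ignores:

    import torch
    import torch.nn.functional as F

    num_tokens_pad = 8
    slot_mapping = torch.arange(6)         # shorter than the padded length
    pad_size_slot = num_tokens_pad - slot_mapping.shape[0]
    if pad_size_slot > 0:
        # Right-pad the 1-D tensor with -1 up to the padded length.
        slot_mapping = F.pad(slot_mapping, (0, pad_size_slot), value=-1)
    else:
        # Already long enough: truncate to the padded length.
        slot_mapping = slot_mapping[:num_tokens_pad]
    print(slot_mapping)  # tensor([ 0,  1,  2,  3,  4,  5, -1, -1])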