
Commit dd092f7

Support pooling models
Signed-off-by: lianyibo <[email protected]>
1 parent 1705501 commit dd092f7

4 files changed: +109 additions, -24 deletions


vllm_ascend/attention/attention_mask.py

Lines changed: 8 additions & 4 deletions

@@ -15,7 +15,9 @@
 import torch
 
 
-def _generate_attn_mask(max_seq_len, dtype):
+def _generate_attn_mask(max_seq_len, dtype, tril):
+    if not tril:
+        return torch.zeros(size=(max_seq_len, max_seq_len)).to(dtype)
     # Construct lower triangle matrix.
     mask_flag = torch.tril(
         torch.ones((max_seq_len, max_seq_len),

@@ -40,12 +42,13 @@ def __init__(
         max_seq_len: int,
         dtype: torch.dtype,
         device: torch.device = None,
+        tril: bool = True,
     ):
         # NOTE: The device argument specifies the target NPU
         # to be used for the newly added FIA operator.
         # Only pass this parameter when using the new FIA operator.
-
-        attn_mask = _generate_attn_mask(max_seq_len, dtype)
+        self.tril = tril
+        attn_mask = _generate_attn_mask(max_seq_len, dtype, self.tril)
 
         self._seq_len_cached = attn_mask.shape[0]
         self.attn_mask_cache = attn_mask

@@ -103,6 +106,7 @@ def get_splitfuse_attn_mask(
     def _update_attn_cache(self, seqlen: int, dtype: torch.dtype):
         if seqlen > self._seq_len_cached:
             self._seq_len_cached = seqlen
-            self.attn_mask_cache = _generate_attn_mask(seqlen, dtype)
+            self.attn_mask_cache = _generate_attn_mask(seqlen, dtype,
+                                                       self.tril)
         if self.attn_mask_cache.dtype != dtype:
            self.attn_mask_cache = self.attn_mask_cache.to(dtype)
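For context, a minimal usage sketch of the new flag, assuming the builder is importable from vllm_ascend.attention.attention_mask and using illustrative sizes and dtypes; only the tril argument and the attn_mask_cache attribute come from the hunk above.

import torch

from vllm_ascend.attention.attention_mask import AttentionMaskBuilder

# Decoder (generative) models keep the default causal, lower-triangular mask.
causal_builder = AttentionMaskBuilder(max_seq_len=1024, dtype=torch.float16)

# Pooling / encoder-only models pass tril=False and get an all-zeros
# (bidirectional) mask, i.e. no position is masked out.
bidir_builder = AttentionMaskBuilder(max_seq_len=1024,
                                     dtype=torch.float16,
                                     tril=False)
assert torch.count_nonzero(bidir_builder.attn_mask_cache) == 0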

vllm_ascend/attention/attention_v1.py

Lines changed: 31 additions & 5 deletions

@@ -287,6 +287,27 @@ def __init__(
         self.key_cache = None
         self.value_cache = None
 
+    def _forward_encoder(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attn_metadata: AscendMetadata,
+        output: Optional[torch.Tensor] = None,
+        num_tokens=0,
+    ) -> torch.Tensor:
+        torch_npu._npu_flash_attention(query=query,
+                                       key=key,
+                                       value=value,
+                                       mask=attn_metadata.attn_mask,
+                                       seq_len=attn_metadata.seq_lens,
+                                       scale_value=self.scale,
+                                       num_heads=self.num_heads,
+                                       num_kv_heads=self.num_kv_heads,
+                                       out=output)
+        assert output is not None
+        return output[:num_tokens, :, :]
+
     def _forward_prefill_no_cache(
         self,
         query: torch.Tensor,

@@ -570,10 +591,11 @@ def forward(
         num_actual_tokens = attn_metadata.num_actual_tokens
         assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
         attn_type = self.attn_type
-        if attn_type != AttentionType.DECODER:
-            raise NotImplementedError("Encoder self-attention and "
-                                      "encoder/decoder cross-attention "
-                                      "are not implemented for "
+        if attn_type not in [
+                AttentionType.DECODER, AttentionType.ENCODER_ONLY
+        ]:
+            raise NotImplementedError("Encoder/Decoder cross-attention "
+                                      "is not implemented for "
                                       "PallasAttentionBackendImpl")
         # View q k v to BSH.
         query = query.view(-1, self.num_heads, self.head_size)

@@ -594,7 +616,11 @@ def forward(
                                            slot_indices=slots)
 
         # V0-Style scheduler situation.
-        if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
+        if attn_type == AttentionType.ENCODER_ONLY:
+            output = self._forward_encoder(query, key, value,
+                                           attn_metadata, output,
+                                           num_tokens)
+        elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
             output = self._forward_prefill_no_cache(
                 query, key, value, attn_metadata, output, num_tokens)
         elif attn_metadata.attn_state == \
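For intuition, _forward_encoder computes full bidirectional attention over the scheduled prompt tokens (the all-zeros mask built above masks nothing out). A rough pure-PyTorch reference of that computation, written as a simplification that ignores the NPU kernel, the variable-length batching driven by seq_lens, and grouped KV heads:

import torch

def reference_encoder_attention(query: torch.Tensor, key: torch.Tensor,
                                value: torch.Tensor, scale: float) -> torch.Tensor:
    # query/key/value: [num_tokens, num_heads, head_size] for one sequence.
    q = query.transpose(0, 1)                      # [heads, tokens, head_size]
    k = key.transpose(0, 1)
    v = value.transpose(0, 1)
    scores = torch.matmul(q, k.transpose(-1, -2)) * scale
    probs = scores.softmax(dim=-1)                 # no causal mask applied
    return torch.matmul(probs, v).transpose(0, 1)  # [tokens, heads, head_size]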

vllm_ascend/platform.py

Lines changed: 2 additions & 1 deletion
@@ -145,7 +145,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                 structured_outputs_config.backend == "auto" and \
                 not getattr(scheduler_config, "scheduler_delay_factor", 0) > 0 and \
                 not scheduler_config.send_delta_data and \
-                scheduler_config.policy == "fcfs":
+                scheduler_config.policy == "fcfs" and \
+                model_config.runner_type == "generate":
             ascend_scheduler_config.enabled = True
         chunked_prefill_enabled_in_ascend_scheduler = getattr(
             ascend_scheduler_config, "enable_chunked_prefill", False)

vllm_ascend/worker/model_runner_v1.py

Lines changed: 68 additions & 14 deletions
@@ -76,9 +76,11 @@
 from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
 # yapf conflicts with isort for this block
 # yapf: disable
-from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
-                                        KVCacheConfig, KVCacheGroupSpec,
-                                        KVCacheSpec, MambaSpec)
+from vllm.v1.kv_cache_interface import (AttentionSpec,
+                                        EncoderOnlyAttentionSpec,
+                                        FullAttentionSpec, KVCacheConfig,
+                                        KVCacheGroupSpec, KVCacheSpec,
+                                        MambaSpec)
 # yapf: enable
 from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
                              DraftTokenIds, LogprobsTensors, ModelRunnerOutput)
@@ -324,13 +326,17 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
             self.block_size,
             use_mla=self.model_config.use_mla,
         )
+        pooler_config = self.model_config.pooler_config
+        tril = self.model_config.runner_type == "generate" or (
+            pooler_config is not None
+            and pooler_config.pooling_type.lower() == "last")
         if torch.version.cann.startswith("8.3"):
             self.attn_mask_builder = AttentionMaskBuilder(
                 self.scheduler_config.max_num_batched_tokens, self.dtype,
-                self.device)
+                self.device, tril)
         else:
             self.attn_mask_builder = AttentionMaskBuilder(
-                self.model_config.max_model_len, self.dtype)
+                self.model_config.max_model_len, self.dtype, tril=tril)
 
         # Set up speculative decoding.
         self.spec_attn_mask = None
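The mask choice implied by this hunk: generative models, and pooling models whose pooler uses last-token pooling, keep the causal mask; other poolers (for example CLS or mean pooling, named here only as typical encoder-style cases) get the bidirectional all-zeros mask. A condensed sketch with an invented helper name:

def wants_causal_mask(runner_type: str, pooling_type: str | None) -> bool:
    # Mirrors the tril expression above: causal for generation and for
    # "last"-token pooling, which runs on decoder-style models.
    return runner_type == "generate" or (pooling_type is not None
                                         and pooling_type.lower() == "last")

assert wants_causal_mask("generate", None)
assert wants_causal_mask("pooling", "LAST")
assert not wants_causal_mask("pooling", "CLS")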
@@ -1487,14 +1493,29 @@ def _prepare_inputs(
         # in the same group share the same metadata.
         for kv_cache_group_id, kv_cache_group_spec in enumerate(
                 self.kv_cache_config.kv_cache_groups):
-            blk_table = self.input_batch.block_table[kv_cache_group_id]
-            blk_table_tensor = blk_table.get_device_tensor()
-            slot_mapping = blk_table.slot_mapping_cpu[:
-                                                      total_num_scheduled_tokens]
-            self.slot_mapping[:total_num_scheduled_tokens].copy_(
-                slot_mapping[:total_num_scheduled_tokens],
-                non_blocking=True,
-            )
+            if isinstance(kv_cache_group_spec.kv_cache_spec,
+                          EncoderOnlyAttentionSpec):
+                # Encoder-only layers do not have KV cache, so we need to
+                # create a dummy block table and slot mapping for them.
+                blk_table_tensor = torch.zeros(
+                    (num_reqs, 1),
+                    dtype=torch.int32,
+                    device=self.device,
+                )
+                slot_mapping = torch.zeros(
+                    (total_num_scheduled_tokens, ),
+                    dtype=torch.int64,
+                    device=self.device,
+                )
+            else:
+                blk_table = self.input_batch.block_table[kv_cache_group_id]
+                blk_table_tensor = blk_table.get_device_tensor()
+                slot_mapping = blk_table.slot_mapping_cpu[:
+                                                          total_num_scheduled_tokens]
+                self.slot_mapping[:total_num_scheduled_tokens].copy_(
+                    slot_mapping[:total_num_scheduled_tokens],
+                    non_blocking=True,
+                )
 
         # Make AscendCommonAttentionMetadata
         common_attn_metadata = AscendCommonAttentionMetadata(
@@ -1543,6 +1564,11 @@ def _prepare_inputs(
                     common_prefix_len=common_prefix_len,
                     common_attn_metadata=common_attn_metadata,
                     **extra_attn_metadata_args)
+            elif self.model_config.runner_type == "pooling":
+                attn_metadata_i = builder.build(
+                    common_prefix_len=common_prefix_len,
+                    common_attn_metadata=common_attn_metadata,
+                    **extra_attn_metadata_args)
             else:
                 attn_metadata_i = builder.build(
                     common_prefix_len=common_prefix_len,
@@ -2672,6 +2698,33 @@ def _convert_torch_format(self, tensor):
         tensor = torch_npu.npu_format_cast(tensor, ACL_FORMAT)
         return tensor
 
+    def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
+        """
+        Add encoder-only layers to the KV cache config.
+        """
+        block_size = self.vllm_config.cache_config.block_size
+        use_mla = self.vllm_config.model_config.use_mla
+        encoder_only_attn_specs: dict[AttentionSpec,
+                                      list[str]] = defaultdict(list)
+        attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
+        for layer_name, attn_module in attn_layers.items():
+            if attn_module.attn_type == AttentionType.ENCODER_ONLY:
+                attn_spec: AttentionSpec = EncoderOnlyAttentionSpec(
+                    block_size=block_size,
+                    num_kv_heads=attn_module.num_kv_heads,
+                    head_size=attn_module.head_size,
+                    dtype=self.kv_cache_dtype,
+                    use_mla=use_mla)
+                encoder_only_attn_specs[attn_spec].append(layer_name)
+                self.runner_only_attn_layers.add(layer_name)
+        if len(encoder_only_attn_specs) > 0:
+            assert len(
+                encoder_only_attn_specs
+            ) == 1, "Only support one encoder-only attention spec now"
+            spec, layer_names = encoder_only_attn_specs.popitem()
+            self.kv_cache_config.kv_cache_groups.append(
+                KVCacheGroupSpec(layer_names=layer_names, kv_cache_spec=spec))
+
     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """
         Initialize KV cache based on `kv_cache_config`.
@@ -2681,9 +2734,10 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """
         kv_cache_config = deepcopy(kv_cache_config)
         self.kv_cache_config = kv_cache_config
+        self.may_reinitialize_input_batch(kv_cache_config)
+        self.may_add_encoder_only_layers_to_kv_cache_config()
         self.initialize_attn_backend(kv_cache_config)
         self.use_hybrid_blocks = (len(self.attn_groups) > 1)
-        self.may_reinitialize_input_batch(kv_cache_config)
 
         if self.model_config.is_deepseek_mla:
             kv_caches = self.initialize_kv_cache_tensors_deepseek(
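Taken together with the previous hunk, KV-cache initialization now reinitializes the input batch first and registers encoder-only layers as a runner-only KV-cache group before the attention backends are set up. A condensed view of that ordering, with method bodies elided and names as in the hunks above:

def initialize_kv_cache(self, kv_cache_config):
    kv_cache_config = deepcopy(kv_cache_config)
    self.kv_cache_config = kv_cache_config
    self.may_reinitialize_input_batch(kv_cache_config)       # moved earlier
    self.may_add_encoder_only_layers_to_kv_cache_config()    # new step
    self.initialize_attn_backend(kv_cache_config)
    self.use_hybrid_blocks = (len(self.attn_groups) > 1)
    # ... remainder unchanged ...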
