|
17 | 17 | from transformers.utils import torch_int |
18 | 18 |
|
19 | 19 | from vllm.attention.backends.registry import _Backend |
20 | | -from vllm.attention.layer import check_upstream_fa_availability |
| 20 | +from vllm.attention.layer import ( |
| 21 | + maybe_get_vit_flash_attn_backend, |
| 22 | +) |
21 | 23 | from vllm.config import VllmConfig |
22 | 24 | from vllm.config.multimodal import BaseDummyOptions |
23 | 25 | from vllm.distributed import get_tensor_model_parallel_world_size |
|
56 | 58 | PromptUpdate, |
57 | 59 | ) |
58 | 60 | from vllm.multimodal.profiling import BaseDummyInputsBuilder |
| 61 | +from vllm.platforms import current_platform |
59 | 62 | from vllm.sequence import IntermediateTensors |
60 | 63 | from vllm.utils.tensor_schema import TensorSchema, TensorShape |
61 | 64 |
|
62 | 65 | from .interfaces import ( |
63 | 66 | MultiModalEmbeddings, |
64 | 67 | SupportsLoRA, |
| 68 | + SupportsMRoPE, |
65 | 69 | SupportsMultiModal, |
66 | 70 | SupportsPP, |
67 | 71 | ) |
@@ -337,7 +341,10 @@ def apply_rotary_pos_emb_flashatt( |
337 | 341 | cos = cos.chunk(2, dim=-1)[0].contiguous() |
338 | 342 | sin = sin.chunk(2, dim=-1)[0].contiguous() |
339 | 343 |
|
340 | | - from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb |
| 344 | + if current_platform.is_cuda(): |
| 345 | + from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb |
| 346 | + elif current_platform.is_rocm(): |
| 347 | + from flash_attn.ops.triton.rotary import apply_rotary as apply_rotary_emb |
341 | 348 |
|
342 | 349 | q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q) |
343 | 350 | k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k) |
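
For reference, both kernels imported above apply the same non-interleaved rotary transform to q and k. Below is a minimal pure-PyTorch sketch of that transform (illustrative only, not the code path used here), assuming cos/sin of shape (seqlen, rotary_dim // 2) and inputs of shape (batch, seqlen, heads, head_dim):

import torch

def rotary_ref(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: (batch, seqlen, heads, head_dim); cos/sin: (seqlen, rotary_dim // 2)
    ro_dim = cos.shape[-1] * 2
    cos = cos[None, :, None, :]  # broadcast over batch and heads
    sin = sin[None, :, None, :]
    x1 = x[..., : ro_dim // 2]
    x2 = x[..., ro_dim // 2 : ro_dim]
    rotated = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
    # Dimensions beyond rotary_dim pass through unchanged.
    return torch.cat([rotated, x[..., ro_dim:]], dim=-1)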
@@ -398,18 +405,28 @@ def __init__( |
398 | 405 | attn_backend_override=attn_backend_override, |
399 | 406 | ) |
400 | 407 |
|
401 | | - self.use_upstream_fa = False |
402 | | - if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( |
403 | | - torch.get_default_dtype() |
404 | | - ): |
405 | | - self.attn_backend = _Backend.FLASH_ATTN |
406 | | - self.use_upstream_fa = True |
| 408 | + self.attn_backend, self.flash_attn_varlen_func = ( |
| 409 | + maybe_get_vit_flash_attn_backend( |
| 410 | + self.attn_backend, |
| 411 | + use_upstream_fa=False, |
| 412 | + attn_backend_override=attn_backend_override, |
| 413 | + ) |
| 414 | + ) |
407 | 415 |
|
408 | | - if self.attn_backend not in {_Backend.FLASH_ATTN, _Backend.XFORMERS}: |
| 416 | + if self.attn_backend not in { |
| 417 | + _Backend.FLASH_ATTN, |
| 418 | + _Backend.XFORMERS, |
| 419 | + _Backend.ROCM_AITER_FA, |
| 420 | + }: |
409 | 421 | raise RuntimeError( |
410 | 422 | f"Keye-VL does not support {self.attn_backend} backend now." |
411 | 423 | ) |
412 | 424 |
|
| 425 | + self.is_flash_attn_backend = self.attn_backend in { |
| 426 | + _Backend.FLASH_ATTN, |
| 427 | + _Backend.ROCM_AITER_FA, |
| 428 | + } |
| 429 | + |
413 | 430 | def forward( |
414 | 431 | self, |
415 | 432 | hidden_states: torch.Tensor, |
@@ -457,15 +474,10 @@ def forward( |
457 | 474 | self.head_dim, |
458 | 475 | ) |
459 | 476 |
|
460 | | - if self.attn_backend == _Backend.FLASH_ATTN: |
461 | | - if self.use_upstream_fa: |
462 | | - from flash_attn import flash_attn_varlen_func |
463 | | - else: |
464 | | - from vllm.vllm_flash_attn import flash_attn_varlen_func |
465 | | - |
| 477 | + if self.is_flash_attn_backend: |
466 | 478 | q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) |
467 | 479 |
|
468 | | - output = flash_attn_varlen_func( |
| 480 | + output = self.flash_attn_varlen_func( |
469 | 481 | q, |
470 | 482 | k, |
471 | 483 | v, |
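
The remaining arguments of that call (the cu_seqlens and max_seqlen values) are not shown in this hunk. For orientation only, here is a hedged sketch of the standard varlen calling convention shared by these flash-attention backends; the cu_seqlens construction is illustrative and not copied from this PR:

import torch

def varlen_attend(q, k, v, seqlens, flash_attn_varlen_func):
    # q, k, v: (total_tokens, num_heads, head_dim), i.e. after the
    # "b s ... -> (b s) ..." rearrange shown above.
    cu_seqlens = torch.zeros(len(seqlens) + 1, dtype=torch.int32, device=q.device)
    cu_seqlens[1:] = torch.as_tensor(seqlens, device=q.device).cumsum(0)
    max_seqlen = int(max(seqlens))
    return flash_attn_varlen_func(
        q, k, v,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max_seqlen,
        max_seqlen_k=max_seqlen,
    )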
@@ -1542,7 +1554,7 @@ def get_mm_mapping(self) -> MultiModelKeys: |
1542 | 1554 | dummy_inputs=KeyeDummyInputsBuilder, |
1543 | 1555 | ) |
1544 | 1556 | class KeyeForConditionalGeneration( |
1545 | | - BaseKeyeModule, SupportsMultiModal, SupportsLoRA, SupportsPP |
| 1557 | + BaseKeyeModule, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE |
1546 | 1558 | ): |
1547 | 1559 | def _build_projector( |
1548 | 1560 | self, |
@@ -1611,3 +1623,142 @@ def _process_video_input( |
1611 | 1623 | return tuple( |
1612 | 1624 | self._process_video_embeds(video_type, video_grid_thw, pixel_values_videos) |
1613 | 1625 | ) |
| 1626 | + |
| 1627 | + def get_mrope_input_positions( |
| 1628 | + self, |
| 1629 | + input_tokens: list[int], |
| 1630 | + hf_config: PretrainedConfig, |
| 1631 | + image_grid_thw: list[list[int]] | torch.Tensor, |
| 1632 | + video_grid_thw: list[list[int]] | torch.Tensor, |
| 1633 | + context_len: int = 0, |
| 1634 | + seq_len: int | None = None, |
| 1635 | + second_per_grid_ts: list[float] | None = None, |
| 1636 | + audio_feature_lengths: torch.Tensor | None = None, |
| 1637 | + use_audio_in_video: bool = False, |
| 1638 | + ) -> tuple[torch.Tensor, int]: |
| 1639 | + """Get mrope input positions and delta value (Keye series)."""
| 1640 | + if isinstance(video_grid_thw, list) and len(video_grid_thw) > 0:
| 1641 | + video_grid_thw = video_grid_thw[0]
| 1642 | + |
| 1643 | + def split_thw(grid_thw: torch.Tensor | list[int]) -> list[list[int]]: |
| 1644 | + """ |
| 1645 | + Split grid_thw along the t dimension. |
| 1646 | +
|
| 1647 | + Args: |
| 1648 | + grid_thw: shape [N, 3] tensor or nested list of [t, h, w]. |
| 1649 | +
|
| 1650 | + Returns: |
| 1651 | + List of [1, h, w] rows, repeated t times for each original row. |
| 1652 | + """ |
| 1653 | + |
| 1654 | + if isinstance(grid_thw, list): |
| 1655 | + grid_thw = torch.tensor(grid_thw, dtype=torch.long) |
| 1656 | + |
| 1657 | + if grid_thw.numel() == 0: |
| 1658 | + return [] |
| 1659 | + |
| 1660 | + t, hw = grid_thw[:, 0], grid_thw[:, 1:] |
| 1661 | + ones = torch.ones_like(hw[:, :1]) # [N,1] |
| 1662 | + out = torch.cat([ones, hw], dim=1).repeat_interleave(t, dim=0) |
| 1663 | + return out.tolist() |
| 1664 | + |
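# (Illustrative, not part of this diff.) Example of split_thw behavior:
#   split_thw([[2, 4, 6]]) -> [[1, 4, 6], [1, 4, 6]]
# A video whose grid has t=2 frames is expanded into two per-frame
# [1, h, w] rows, so each frame later receives its own temporal index.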
| 1665 | + video_grid_thw = split_thw(video_grid_thw) |
| 1666 | + |
| 1667 | + image_token_id = hf_config.image_token_id |
| 1668 | + video_token_id = hf_config.video_token_id |
| 1669 | + spatial_merge_size = hf_config.vision_config.spatial_merge_size |
| 1670 | + |
| 1671 | + image_nums = len(image_grid_thw) |
| 1672 | + frame_nums = len(video_grid_thw) |
| 1673 | + llm_pos_ids_list: list = [] |
| 1674 | + |
| 1675 | + st = 0 |
| 1676 | + remain_images, remain_frames = image_nums, frame_nums |
| 1677 | + |
| 1678 | + image_index, video_index = 0, 0 |
| 1679 | + for _ in range(image_nums + frame_nums): |
| 1680 | + if remain_images > 0: |
| 1681 | + try: |
| 1682 | + ed_image = input_tokens.index(image_token_id, st) |
| 1683 | + except ValueError: |
| 1684 | + ed_image = len(input_tokens) + 1 |
| 1685 | + else: |
| 1686 | + ed_image = len(input_tokens) + 1 |
| 1687 | + if remain_frames > 0: |
| 1688 | + try: |
| 1689 | + ed_video = input_tokens.index(video_token_id, st) |
| 1690 | + except ValueError: |
| 1691 | + ed_video = len(input_tokens) + 1 |
| 1692 | + else: |
| 1693 | + ed_video = len(input_tokens) + 1 |
| 1694 | + |
| 1695 | + if ed_image < ed_video: |
| 1696 | + t, h, w = ( |
| 1697 | + image_grid_thw[image_index][0], |
| 1698 | + image_grid_thw[image_index][1], |
| 1699 | + image_grid_thw[image_index][2], |
| 1700 | + ) |
| 1701 | + image_index += 1 |
| 1702 | + remain_images -= 1 |
| 1703 | + ed = ed_image |
| 1704 | + else: |
| 1705 | + t, h, w = ( |
| 1706 | + video_grid_thw[video_index][0], |
| 1707 | + video_grid_thw[video_index][1], |
| 1708 | + video_grid_thw[video_index][2], |
| 1709 | + ) |
| 1710 | + video_index += 1 |
| 1711 | + remain_frames -= 1 |
| 1712 | + ed = ed_video |
| 1713 | + |
| 1714 | + llm_grid_t, llm_grid_h, llm_grid_w = ( |
| 1715 | + t, |
| 1716 | + h // spatial_merge_size, |
| 1717 | + w // spatial_merge_size, |
| 1718 | + ) |
| 1719 | + text_len = ed - st |
| 1720 | + |
| 1721 | + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 |
| 1722 | + llm_pos_ids_list.append( |
| 1723 | + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx |
| 1724 | + ) |
| 1725 | + |
| 1726 | + t_index = ( |
| 1727 | + ( |
| 1728 | + torch.arange(llm_grid_t) |
| 1729 | + .view(-1, 1) |
| 1730 | + .expand(-1, llm_grid_h * llm_grid_w) |
| 1731 | + ) |
| 1732 | + .long() |
| 1733 | + .flatten() |
| 1734 | + ) |
| 1735 | + |
| 1736 | + h_index = ( |
| 1737 | + torch.arange(llm_grid_h) |
| 1738 | + .view(1, -1, 1) |
| 1739 | + .expand(llm_grid_t, -1, llm_grid_w) |
| 1740 | + .flatten() |
| 1741 | + ) |
| 1742 | + w_index = ( |
| 1743 | + torch.arange(llm_grid_w) |
| 1744 | + .view(1, 1, -1) |
| 1745 | + .expand(llm_grid_t, llm_grid_h, -1) |
| 1746 | + .flatten() |
| 1747 | + ) |
| 1748 | + llm_pos_ids_list.append( |
| 1749 | + torch.stack([t_index, h_index, w_index]) + text_len + st_idx |
| 1750 | + ) |
| 1751 | + st = ed + llm_grid_t * llm_grid_h * llm_grid_w |
| 1752 | + |
| 1753 | + if st < len(input_tokens): |
| 1754 | + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 |
| 1755 | + text_len = len(input_tokens) - st |
| 1756 | + llm_pos_ids_list.append( |
| 1757 | + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx |
| 1758 | + ) |
| 1759 | + |
| 1760 | + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) |
| 1761 | + mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() |
| 1762 | + llm_positions = llm_positions[:, context_len:seq_len] |
| 1763 | + |
| 1764 | + return llm_positions, mrope_position_delta |
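
To make the resulting layout concrete, a small worked example (constructed for illustration, not taken from this PR): with spatial_merge_size = 2, two leading text tokens, one image with grid_thw = [1, 4, 4] (merged to a 1x2x2 grid, i.e. four placeholder tokens), and one trailing text token, the three M-RoPE axes come out as:

# tokens:  [text, text, img, img, img, img, text]
# t axis:  [0, 1, 2, 2, 2, 2, 4]
# h axis:  [0, 1, 2, 2, 3, 3, 4]
# w axis:  [0, 1, 2, 3, 2, 3, 4]

Text tokens advance all three axes in lockstep; image tokens share a temporal index while enumerating the merged (h, w) grid, both offset by the length of the preceding text. The returned mrope_position_delta is max(positions) + 1 - len(input_tokens) = 5 - 7 = -2.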