|
80 | 80 | from vllm.v1.sample.metadata import SamplingMetadata |
81 | 81 | from vllm.v1.spec_decode.metadata import SpecDecodeMetadata |
82 | 82 | from vllm.v1.spec_decode.ngram_proposer import NgramProposer |
| 83 | +from vllm.v1.structured_output.utils import apply_grammar_bitmask |
83 | 84 | from vllm.v1.utils import CpuGpuBuffer |
84 | 85 | from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput |
85 | 86 | from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin |
@@ -1559,70 +1560,6 @@ def _calc_spec_decode_metadata( |
1559 | 1560 | ) |
1560 | 1561 | return metadata |
1561 | 1562 |
|
def apply_grammar_bitmask(
    self,
    scheduler_output: "SchedulerOutput",
    logits: torch.Tensor,
) -> torch.Tensor:
    """Mask ``logits`` with the scheduler's structured-output bitmask.

    The scheduler ships a compacted bitmask that only has rows for
    structured output requests, and the row order is not guaranteed to
    match the order of requests in this runner's batch. The rows are
    therefore scattered into a bitmask aligned with the rows of ``logits``
    before XGrammar applies it.

    Args:
        scheduler_output: Supplies ``grammar_bitmask``,
            ``structured_output_request_ids`` and
            ``scheduled_spec_decode_tokens`` for the current step.
        logits: Logits to mask; its first dimension sizes the aligned
            bitmask.

    Returns:
        The masked logits, moved back to ``self.device`` and cast back to
        the dtype ``logits`` had on entry.
    """
    compact_bitmask = scheduler_output.grammar_bitmask
    spec_tokens = scheduler_output.scheduled_spec_decode_tokens
    structured_ids = scheduler_output.structured_output_request_ids

    # Map each structured output request to its first row in ``logits``.
    # Every speculative token scheduled for an earlier request in the
    # batch shifts the logit rows of the requests that follow it.
    first_logit_row: dict[str, int] = {}
    spec_offset = 0
    batch_order = sorted(self.input_batch.req_id_to_index.items(),
                         key=lambda item: item[1])
    for req_id, batch_index in batch_order:
        if req_id in structured_ids:
            first_logit_row[req_id] = batch_index + spec_offset
        spec_offset += len(spec_tokens.get(req_id, []))

    # Scatter the compacted rows into a bitmask laid out like ``logits``,
    # recording which logit rows actually carry a mask.
    aligned_bitmask = np.zeros_like(
        compact_bitmask,
        shape=(logits.shape[0], compact_bitmask.shape[1]),
    )
    out_indices = []
    src_row = 0
    mask_order = sorted(structured_ids.items(), key=lambda item: item[1])
    for req_id, _ in mask_order:
        dst_row = first_logit_row[req_id]
        # One row for the regular token plus one per speculative token.
        rows_for_req = 1 + len(spec_tokens.get(req_id, []))
        for step in range(rows_for_req):
            aligned_bitmask[dst_row + step] = compact_bitmask[src_row + step]
            out_indices.append(dst_row + step)
        src_row += rows_for_req

    # Serialization of np.ndarray is much more efficient than a tensor,
    # so the bitmask arrives as an array and is converted only here.
    bitmask_tensor = torch.from_numpy(aligned_bitmask)

    # NOTE:
    # 1. XGrammar bitmask applying only supports CPU and GPU.
    # 2. The logits and bitmask should be on the same device.
    # 3. XGrammar logits on CPU only supports float32 dtype.
    original_dtype = logits.dtype
    cpu_logits = logits.to("cpu").float()
    xgr.apply_token_bitmask_inplace(
        cpu_logits,
        bitmask_tensor,
        indices=out_indices,
    )
    return cpu_logits.to(self.device).to(original_dtype)
1626 | 1563 | def propose_draft_token_ids( |
1627 | 1564 | self, |
1628 | 1565 | valid_sampled_token_ids: list[list[int]], |
@@ -1851,7 +1788,15 @@ def execute_model( |
1851 | 1788 |
|
1852 | 1789 | # Apply structured output bitmasks if present |
1853 | 1790 | if scheduler_output.grammar_bitmask is not None: |
1854 | | - logits = self.apply_grammar_bitmask(scheduler_output, logits) |
| 1791 | + # NOTE: |
| 1792 | + # 1. XGrammar bitmask applying only supports CPU and GPU. |
| 1793 | + # 2. The logits and bitmask should be on the same device. |
| 1794 | + # 3. XGrammar logits on CPU only supports float32 dtype. |
| 1795 | + logits_dtype = logits.dtype |
| 1796 | + logits = logits.to("cpu").float() |
| 1797 | + apply_grammar_bitmask(scheduler_output, self.input_batch, |
| 1798 | + logits, self.device) |
| 1799 | + logits = logits.to(self.device).to(logits_dtype) |
1855 | 1800 |
|
1856 | 1801 | # Sample the next token and get logprobs if needed. |
1857 | 1802 | sampling_metadata = self.input_batch.sampling_metadata |
|
0 commit comments