Commit 9985234

Merge branch 'main' into dataset

2 parents: 64e8468 + 56d0408

File tree

129 files changed: +3712 −805 lines


.buildkite/test-pipeline.yaml

Lines changed: 6 additions & 2 deletions

```diff
@@ -566,8 +566,7 @@ steps:
   - tests/models/multimodal
   commands:
     - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
-    - pytest -v -s models/multimodal/processing --ignore models/multimodal/processing/test_tensor_schema.py
-    - pytest -v -s models/multimodal/processing/test_tensor_schema.py
+    - pytest -v -s models/multimodal/processing

 - label: Multi-Modal Models Test (Standard)
   mirror_hardwares: [amdexperimental]
@@ -770,6 +769,11 @@ steps:
     - pytest -v -s plugins_tests/test_platform_plugins.py
     - pip uninstall vllm_add_dummy_platform -y
     # end platform plugin tests
+    # begin io_processor plugins test, all the code in between uses the prithvi_io_processor plugin
+    - pip install -e ./plugins/prithvi_io_processor_plugin
+    - pytest -v -s plugins_tests/test_io_processor_plugins.py
+    - pip uninstall prithvi_io_processor_plugin -y
+    # end io_processor plugins test
     # other tests continue here:
     - pytest -v -s plugins_tests/test_scheduler_plugins.py
     - pip install -e ./plugins/vllm_add_dummy_model
```

csrc/mamba/mamba_ssm/selective_scan_fwd.cu

Lines changed: 47 additions & 25 deletions

```diff
@@ -27,11 +27,12 @@

 template<int kNThreads_, int kNItems_, int kNRows_, bool kIsEvenLen_,
          bool kIsVariableB_, bool kIsVariableC_,
-         bool kHasZ_, bool kVarlen_, typename input_t_, typename weight_t_>
+         bool kHasZ_, bool kVarlen_, typename input_t_, typename weight_t_, typename state_t_>
 struct Selective_Scan_fwd_kernel_traits {
     static_assert(kNItems_ % 4 == 0);
     using input_t = input_t_;
     using weight_t = weight_t_;
+    using state_t = state_t_;
     static constexpr int kNThreads = kNThreads_;
     // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy.
     static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3;
@@ -132,7 +133,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
     input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + sequence_start_index * params.B_batch_stride + group_id * params.B_group_stride;
     weight_t *C = reinterpret_cast<weight_t *>(params.C_ptr) + dim_id * kNRows * params.C_d_stride;
     input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + sequence_start_index * params.C_batch_stride + group_id * params.C_group_stride;
-    input_t *ssm_states = reinterpret_cast<input_t *>(params.ssm_states_ptr) +
+    typename Ktraits::state_t *ssm_states = reinterpret_cast<typename Ktraits::state_t *>(params.ssm_states_ptr) +
         cache_index * params.ssm_states_batch_stride +
         dim_id * kNRows * params.ssm_states_dim_stride;

@@ -261,7 +262,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
     if (threadIdx.x == 0) {
         smem_running_prefix[state_idx] = prefix_op.running_prefix;
         if (chunk == n_chunks - 1) {
-            ssm_states[state_idx * params.ssm_states_dstate_stride] = input_t(prefix_op.running_prefix.y);
+            ssm_states[state_idx * params.ssm_states_dstate_stride] = typename Ktraits::state_t(prefix_op.running_prefix.y);
         }
     }
     #pragma unroll
@@ -310,7 +311,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
     }
 }

-template<int kNThreads, int kNItems, typename input_t, typename weight_t>
+template<int kNThreads, int kNItems, typename input_t, typename weight_t, typename state_t>
 void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
     // Only kNRows == 1 is tested for now, which ofc doesn't differ from previously when we had each block
     // processing 1 row.
@@ -321,7 +322,7 @@ void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
     BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
         BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] {
             BOOL_SWITCH(params.query_start_loc_ptr != nullptr , kVarlen, [&] {
-                using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, kVarlen, input_t, weight_t>;
+                using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, kVarlen, input_t, weight_t, state_t>;
                 constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
                 dim3 grid(params.batch, params.dim / kNRows);
                 auto kernel = &selective_scan_fwd_kernel<Ktraits>;
@@ -341,59 +342,78 @@ void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
     });
 }

-template<typename input_t, typename weight_t>
+template<typename input_t, typename weight_t, typename state_t>
 void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream) {

 #ifndef USE_ROCM
     if (params.seqlen <= 128) {
-        selective_scan_fwd_launch<32, 4, input_t, weight_t>(params, stream);
+        selective_scan_fwd_launch<32, 4, input_t, weight_t, state_t>(params, stream);
     } else if (params.seqlen <= 256) {
-        selective_scan_fwd_launch<32, 8, input_t, weight_t>(params, stream);
+        selective_scan_fwd_launch<32, 8, input_t, weight_t, state_t>(params, stream);
     } else if (params.seqlen <= 512) {
-        selective_scan_fwd_launch<32, 16, input_t, weight_t>(params, stream);
+        selective_scan_fwd_launch<32, 16, input_t, weight_t, state_t>(params, stream);
     } else if (params.seqlen <= 1024) {
-        selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream);
+        selective_scan_fwd_launch<64, 16, input_t, weight_t, state_t>(params, stream);
     } else {
-        selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream);
+        selective_scan_fwd_launch<128, 16, input_t, weight_t, state_t>(params, stream);
     }
 #else
     if (params.seqlen <= 256) {
-        selective_scan_fwd_launch<64, 4, input_t, weight_t>(params, stream);
+        selective_scan_fwd_launch<64, 4, input_t, weight_t, state_t>(params, stream);
     } else if (params.seqlen <= 512) {
-        selective_scan_fwd_launch<64, 8, input_t, weight_t>(params, stream);
+        selective_scan_fwd_launch<64, 8, input_t, weight_t, state_t>(params, stream);
     } else if (params.seqlen <= 1024) {
-        selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream);
+        selective_scan_fwd_launch<64, 16, input_t, weight_t, state_t>(params, stream);
     } else {
-        selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream);
+        selective_scan_fwd_launch<128, 16, input_t, weight_t, state_t>(params, stream);
     }
 #endif
 }

-template void selective_scan_fwd_cuda<at::BFloat16, float>(SSMParamsBase &params, cudaStream_t stream);
-template void selective_scan_fwd_cuda<at::Half, float>(SSMParamsBase &params, cudaStream_t stream);
-template void selective_scan_fwd_cuda<float, float>(SSMParamsBase &params, cudaStream_t stream);
+template void selective_scan_fwd_cuda<at::BFloat16, float, at::BFloat16>(SSMParamsBase &params, cudaStream_t stream);
+template void selective_scan_fwd_cuda<at::BFloat16, float, float>(SSMParamsBase &params, cudaStream_t stream);
+template void selective_scan_fwd_cuda<at::Half, float, at::Half>(SSMParamsBase &params, cudaStream_t stream);
+template void selective_scan_fwd_cuda<at::Half, float, float>(SSMParamsBase &params, cudaStream_t stream);
+template void selective_scan_fwd_cuda<float, float, float>(SSMParamsBase &params, cudaStream_t stream);

 #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")

-#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...)              \
+#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, STYPE, NAME, ...)       \
     if (ITYPE == at::ScalarType::Half) {                                            \
         using input_t = at::Half;                                                   \
         using weight_t = float;                                                     \
-        __VA_ARGS__();                                                              \
+        if (STYPE == at::ScalarType::Half) {                                        \
+            using state_t = at::Half;                                               \
+            __VA_ARGS__();                                                          \
+        } else if (STYPE == at::ScalarType::Float) {                                \
+            using state_t = float;                                                  \
+            __VA_ARGS__();                                                          \
+        } else {                                                                    \
+            AT_ERROR(#NAME, " not implemented for state type '", toString(STYPE), "'"); \
+        }                                                                           \
     } else if (ITYPE == at::ScalarType::BFloat16) {                                 \
         using input_t = at::BFloat16;                                               \
         using weight_t = float;                                                     \
-        __VA_ARGS__();                                                              \
+        if (STYPE == at::ScalarType::BFloat16) {                                    \
+            using state_t = at::BFloat16;                                           \
+            __VA_ARGS__();                                                          \
+        } else if (STYPE == at::ScalarType::Float) {                                \
+            using state_t = float;                                                  \
+            __VA_ARGS__();                                                          \
+        } else {                                                                    \
+            AT_ERROR(#NAME, " not implemented for state type '", toString(STYPE), "'"); \
+        }                                                                           \
     } else if (ITYPE == at::ScalarType::Float) {                                    \
         using input_t = float;                                                      \
         using weight_t = float;                                                     \
+        using state_t = float;                                                      \
         __VA_ARGS__();                                                              \
     } else {                                                                        \
         AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
     }

-template<typename input_t, typename weight_t>
+template<typename input_t, typename weight_t, typename state_t>
 void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream);

 void set_ssm_params_fwd(SSMParamsBase &params,
@@ -648,7 +668,9 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,

     // Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
     at::Tensor out = delta;
-    TORCH_CHECK(ssm_states.scalar_type() == input_type);
+    // ssm_states can now be either the same as input_type or float32
+    auto state_type = ssm_states.scalar_type();
+    TORCH_CHECK(state_type == input_type || state_type == at::ScalarType::Float);
     TORCH_CHECK(ssm_states.is_cuda());
     TORCH_CHECK(ssm_states.stride(-1) == 1);

@@ -670,7 +692,7 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,

     const at::cuda::OptionalCUDAGuard device_guard(device_of(u));
     auto stream = at::cuda::getCurrentCUDAStream().stream();
-    DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] {
-        selective_scan_fwd_cuda<input_t, weight_t>(params, stream);
+    DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), ssm_states.scalar_type(), "selective_scan_fwd", [&] {
+        selective_scan_fwd_cuda<input_t, weight_t, state_t>(params, stream);
     });
 }
```
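The upshot of this change is that the SSM state cache no longer has to share a dtype with the activations: the dispatch now covers both the input dtype and the state dtype, so states may be kept in float32 while inputs stay in bf16/fp16. A minimal sketch of the relaxed contract at the tensor level (shapes and sizes are illustrative, not taken from the kernel):

```python
import torch

batch, dim, dstate, seqlen = 2, 64, 16, 512  # illustrative sizes only

# Activations stay in half precision, as before.
u = torch.randn(batch, dim, seqlen, dtype=torch.bfloat16, device="cuda")

# Before this commit the kernel required ssm_states.dtype == u.dtype.
# The new check is state_type == input_type || state_type == float32,
# so a float32 state cache alongside bf16 activations is now accepted.
ssm_states = torch.zeros(batch, dim, dstate, dtype=torch.float32, device="cuda")
```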

docker/Dockerfile

Lines changed: 2 additions & 3 deletions

```diff
@@ -432,11 +432,10 @@ RUN --mount=type=cache,target=/root/.cache/uv \
     --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')

 # Install DeepGEMM from source
-ARG DEEPGEMM_GIT_REF="7b6b5563b9d4c1ae07ffbce7f78ad3ac9204827c"
+ARG DEEPGEMM_GIT_REF
 COPY tools/install_deepgemm.sh /tmp/install_deepgemm.sh
 RUN --mount=type=cache,target=/root/.cache/uv \
-    VLLM_DOCKER_BUILD_CONTEXT=1 /tmp/install_deepgemm.sh --cuda-version "${CUDA_VERSION}" --ref "${DEEPGEMM_GIT_REF}" \
-    && rm /tmp/install_deepgemm.sh
+    VLLM_DOCKER_BUILD_CONTEXT=1 /tmp/install_deepgemm.sh --cuda-version "${CUDA_VERSION}" ${DEEPGEMM_GIT_REF:+--ref "$DEEPGEMM_GIT_REF"}

 # Install EP kernels(pplx-kernels and DeepEP), NixL
 COPY tools/ep_kernels/install_python_libraries.sh install_python_libraries.sh
```
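Note that `${DEEPGEMM_GIT_REF:+--ref "$DEEPGEMM_GIT_REF"}` expands to `--ref "<value>"` only when the build arg is set and non-empty; an unpinned build therefore falls back to whatever default ref `install_deepgemm.sh` provides.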

docs/configuration/optimization.md

Lines changed: 1 addition & 0 deletions

```diff
@@ -174,6 +174,7 @@ Regardless, you need to set `mm_encoder_tp_mode="data"` in engine arguments to u

 Known supported models:

+- Kimi-VL (<gh-pr:23817>)
 - Llama4 (<gh-pr:18368>)
 - MiniCPM-V-2.5 or above (<gh-pr:23327>, <gh-pr:23948>)
 - Qwen2.5-VL (<gh-pr:22742>)
```
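For context, enabling the encoder data-parallel mode mentioned in the hunk header might look like the following sketch; the model name and parallel size are illustrative placeholders, only `mm_encoder_tp_mode="data"` comes from the surrounding doc:

```python
from vllm import LLM

# Sketch only: mm_encoder_tp_mode="data" shards multimodal encoder work
# data-parallel across ranks; model and TP size are placeholders.
llm = LLM(
    model="Qwen/Qwen2.5-VL-3B-Instruct",
    tensor_parallel_size=2,
    mm_encoder_tp_mode="data",
)
```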

docs/contributing/profiling.md

Lines changed: 2 additions & 0 deletions

```diff
@@ -73,6 +73,8 @@ apt install nsight-systems-cli

 ### Example commands and usage

+When profiling with `nsys`, it is advisable to set the environment variable `VLLM_WORKER_MULTIPROC_METHOD=spawn`, since the default is the `fork` method. More information on the topic can be found in the [Nsight Systems release notes](https://docs.nvidia.com/nsight-systems/ReleaseNotes/index.html#general-issues).
+
 #### Offline Inference

 For basic usage, you can just append `nsys profile -o report.nsys-rep --trace-fork-before-exec=true --cuda-graph-trace=node` before any existing script you would run for offline inference.
```
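A minimal offline script for use under the `nsys profile` command above might look like this sketch (model and prompt are placeholders):

```python
# Sketch: set the worker multiprocessing method before vLLM starts any
# workers, so that nsys can follow the spawned processes.
import os
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

from vllm import LLM  # imported after the environment variable is set

llm = LLM(model="facebook/opt-125m")  # placeholder model
print(llm.generate("The future of AI is"))
```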
Lines changed: 78 additions & 0 deletions (new file: IO Processor plugins documentation)

# IO Processor Plugins

IO Processor plugins allow pre- and post-processing of the model input and output for pooling models. Users can pass a custom input to vLLM that is converted into one or more model prompts and fed to the model's `encode` method. One potential use case for such plugins is generating multi-modal data with vLLM: say, a user feeds an image to vLLM and gets an image back as output.

When performing inference with an IO Processor plugin, the prompt type is defined by the plugin, and the same applies to the final request output. vLLM does not perform any validation of input/output data; it is up to the plugin to ensure the correct data is fed to the model and returned to the user. As of now, these plugins support only pooling models and can be triggered via the `encode` method in `LLM` and `AsyncLLM`, or in online serving mode via the `/pooling` endpoint.

## Writing an IO Processor Plugin

IO Processor plugins implement the `IOProcessor` interface (<gh-file:vllm/plugins/io_processors/interface.py>):

```python
IOProcessorInput = TypeVar('IOProcessorInput')
IOProcessorOutput = TypeVar('IOProcessorOutput')

class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]):

    def __init__(self, vllm_config: VllmConfig):
        self.vllm_config = vllm_config

    @abstractmethod
    def pre_process(
        self,
        prompt: IOProcessorInput,
        request_id: Optional[str] = None,
        **kwargs,
    ) -> Union[PromptType, Sequence[PromptType]]:
        raise NotImplementedError

    async def pre_process_async(
        self,
        prompt: IOProcessorInput,
        request_id: Optional[str] = None,
        **kwargs,
    ) -> Union[PromptType, Sequence[PromptType]]:
        return self.pre_process(prompt, request_id, **kwargs)

    @abstractmethod
    def post_process(self,
                     model_output: Sequence[PoolingRequestOutput],
                     request_id: Optional[str] = None,
                     **kwargs) -> IOProcessorOutput:
        raise NotImplementedError

    async def post_process_async(
        self,
        model_output: AsyncGenerator[tuple[int, PoolingRequestOutput]],
        request_id: Optional[str] = None,
        **kwargs,
    ) -> IOProcessorOutput:
        collected_output = [item async for i, item in model_output]
        return self.post_process(collected_output, request_id, **kwargs)

    @abstractmethod
    def parse_request(self, request: Any) -> IOProcessorInput:
        raise NotImplementedError

    @abstractmethod
    def output_to_response(
        self, plugin_output: IOProcessorOutput) -> IOProcessorResponse:
        raise NotImplementedError
```

The `parse_request` method validates the user prompt and converts it into the input expected by the `pre_process`/`pre_process_async` methods.
The `pre_process*` methods take the validated plugin input and generate vLLM's model prompts for regular inference.
The `post_process*` methods take `PoolingRequestOutput` objects as input and generate a custom plugin output.

The `output_to_response` method is used only for online serving and converts the plugin output to the `IOProcessorResponse` type that is then returned by the API Server. The implementation of the `/io_processor_pooling` serving endpoint is available at <gh-file:vllm/entrypoints/openai/serving_pooling_with_io_plugin.py>.

An example implementation of a plugin that enables generating geotiff images with the PrithviGeospatialMAE model is available [here](https://github.com/christian-pinto/prithvi_io_processor_plugin). Please also refer to our online (<gh-file:examples/online_serving/prithvi_geospatial_mae.py>) and offline (<gh-file:examples/offline_inference/prithvi_geospatial_mae_io_processor.py>) inference examples.
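To make the interface concrete, here is a minimal, purely illustrative sketch of a plugin. `EchoInput`, `EchoOutput`, and the request format are hypothetical, and it assumes the types shown in the interface above (`IOProcessor`, `PromptType`, `PoolingRequestOutput`, `IOProcessorResponse`) are importable from vLLM:

```python
from dataclasses import dataclass
from typing import Any, Optional, Sequence, Union

@dataclass
class EchoInput:            # hypothetical plugin input type
    text: str

@dataclass
class EchoOutput:           # hypothetical plugin output type
    embedding: list[float]

class EchoIOProcessor(IOProcessor[EchoInput, EchoOutput]):

    def parse_request(self, request: Any) -> EchoInput:
        # Validate the raw user request and convert it to the plugin input.
        if not isinstance(request, dict) or "text" not in request:
            raise ValueError("request must be a dict with a 'text' field")
        return EchoInput(text=request["text"])

    def pre_process(
        self,
        prompt: EchoInput,
        request_id: Optional[str] = None,
        **kwargs,
    ) -> Union[PromptType, Sequence[PromptType]]:
        # Turn the custom input into a regular vLLM prompt.
        return prompt.text

    def post_process(self,
                     model_output: Sequence[PoolingRequestOutput],
                     request_id: Optional[str] = None,
                     **kwargs) -> EchoOutput:
        # Convert the pooling output into the plugin's output type.
        return EchoOutput(embedding=model_output[0].outputs.data.tolist())

    def output_to_response(
            self, plugin_output: EchoOutput) -> IOProcessorResponse:
        # Wrap the plugin output for the API server (online serving only);
        # the exact IOProcessorResponse constructor is assumed here.
        return IOProcessorResponse(data=plugin_output)
```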
## Using an IO Processor plugin

IO Processor plugins are loaded at engine startup, and there are two ways to specify the name of the plugin to be loaded:

1. Via vLLM's `EngineArgs`: setting the `io_processor_plugin` argument in the `EngineArgs` used to initialize the `AsyncLLM`. The same can be achieved by passing the `io_processor_plugin` argument to `LLM` in offline mode, or by passing the `--io-processor-plugin` argument in serving mode.
2. Via the model HF configuration: adding an `io_processor_plugin` field to the model config (config.json).

This order also determines priority: a plugin name set via `EngineArgs` overrides any plugin name specified in the model HF config (config.json).
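For illustration, offline usage might look like the sketch below; the plugin name and the structure of the input are defined by the plugin and are purely hypothetical here:

```python
from vllm import LLM

llm = LLM(
    model="my-org/my-pooling-model",        # placeholder model
    io_processor_plugin="my_io_processor",  # hypothetical plugin name
)

# The input format is whatever the plugin's parse_request() accepts.
request = {"text": "some plugin-defined input"}
outputs = llm.encode(request)
```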

docs/design/plugin_system.md

Lines changed: 2 additions & 0 deletions

```diff
@@ -49,6 +49,8 @@ Every plugin has three parts:

 - **Platform plugins** (with group name `vllm.platform_plugins`): The primary use case for these plugins is to register custom, out-of-the-tree platforms into vLLM. The plugin function should return `None` when the platform is not supported in the current environment, or the platform class's fully qualified name when the platform is supported.

+- **IO Processor plugins** (with group name `vllm.io_processor_plugins`): The primary use case for these plugins is to register custom pre/post-processing of the model prompt and model output for pooling models. The plugin function returns the IOProcessor class's fully qualified name.
+
 ## Guidelines for Writing Plugins

 - **Being re-entrant**: The function specified in the entry point should be re-entrant, meaning it can be called multiple times without causing issues. This is necessary because the function might be called multiple times in some processes.
```
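As a hedged sketch (package and module names are hypothetical; only the group name `vllm.io_processor_plugins` comes from the list above), registering an IO Processor plugin follows the same entry-point pattern as the other plugin types:

```python
# Hypothetical setup.py for an out-of-tree IO Processor plugin.
from setuptools import setup

setup(
    name="my_io_processor_plugin",
    version="0.1.0",
    packages=["my_io_processor"],
    entry_points={
        "vllm.io_processor_plugins": [
            # The entry-point function should return the fully qualified
            # name of the IOProcessor class, mirroring the other plugin types.
            "my_io_processor = my_io_processor:register",
        ],
    },
)
```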

docs/getting_started/installation/cpu/build.inc.md

Lines changed: 2 additions & 2 deletions

````diff
@@ -16,8 +16,8 @@ cd vllm_source
 Third, install required dependencies:

 ```bash
-uv pip install -r requirements/cpu-build.txt --torch-backend auto
-uv pip install -r requirements/cpu.txt --torch-backend auto
+uv pip install -r requirements/cpu-build.txt --torch-backend cpu
+uv pip install -r requirements/cpu.txt --torch-backend cpu
 ```

 ??? console "pip"
````
