Commit a889309

Merge branch 'main' into lint
2 parents: e5526be + 38ba061


41 files changed: +656 / -296 lines

csrc/mamba/mamba_ssm/selective_scan_fwd.cu

Lines changed: 47 additions & 25 deletions
```diff
@@ -27,11 +27,12 @@

 template<int kNThreads_, int kNItems_, int kNRows_, bool kIsEvenLen_,
          bool kIsVariableB_, bool kIsVariableC_,
-         bool kHasZ_, bool kVarlen_, typename input_t_, typename weight_t_>
+         bool kHasZ_, bool kVarlen_, typename input_t_, typename weight_t_, typename state_t_>
 struct Selective_Scan_fwd_kernel_traits {
     static_assert(kNItems_ % 4 == 0);
     using input_t = input_t_;
     using weight_t = weight_t_;
+    using state_t = state_t_;
     static constexpr int kNThreads = kNThreads_;
     // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy.
     static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3;
@@ -132,7 +133,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
     input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + sequence_start_index * params.B_batch_stride + group_id * params.B_group_stride;
     weight_t *C = reinterpret_cast<weight_t *>(params.C_ptr) + dim_id * kNRows * params.C_d_stride;
     input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + sequence_start_index * params.C_batch_stride + group_id * params.C_group_stride;
-    input_t *ssm_states = reinterpret_cast<input_t *>(params.ssm_states_ptr) +
+    typename Ktraits::state_t *ssm_states = reinterpret_cast<typename Ktraits::state_t *>(params.ssm_states_ptr) +
         cache_index * params.ssm_states_batch_stride +
         dim_id * kNRows * params.ssm_states_dim_stride;

@@ -261,7 +262,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
     if (threadIdx.x == 0) {
         smem_running_prefix[state_idx] = prefix_op.running_prefix;
         if (chunk == n_chunks - 1) {
-            ssm_states[state_idx * params.ssm_states_dstate_stride] = input_t(prefix_op.running_prefix.y);
+            ssm_states[state_idx * params.ssm_states_dstate_stride] = typename Ktraits::state_t(prefix_op.running_prefix.y);
         }
     }
     #pragma unroll
@@ -310,7 +311,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
     }
 }

-template<int kNThreads, int kNItems, typename input_t, typename weight_t>
+template<int kNThreads, int kNItems, typename input_t, typename weight_t, typename state_t>
 void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
     // Only kNRows == 1 is tested for now, which ofc doesn't differ from previously when we had each block
     // processing 1 row.
@@ -321,7 +322,7 @@ void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
     BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
         BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] {
             BOOL_SWITCH(params.query_start_loc_ptr != nullptr , kVarlen, [&] {
-                using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, kVarlen, input_t, weight_t>;
+                using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, kVarlen, input_t, weight_t, state_t>;
                 constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
                 dim3 grid(params.batch, params.dim / kNRows);
                 auto kernel = &selective_scan_fwd_kernel<Ktraits>;
@@ -341,59 +342,78 @@ void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
     });
 }

-template<typename input_t, typename weight_t>
+template<typename input_t, typename weight_t, typename state_t>
 void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream) {

 #ifndef USE_ROCM
     if (params.seqlen <= 128) {
-        selective_scan_fwd_launch<32, 4, input_t, weight_t>(params, stream);
+        selective_scan_fwd_launch<32, 4, input_t, weight_t, state_t>(params, stream);
     } else if (params.seqlen <= 256) {
-        selective_scan_fwd_launch<32, 8, input_t, weight_t>(params, stream);
+        selective_scan_fwd_launch<32, 8, input_t, weight_t, state_t>(params, stream);
     } else if (params.seqlen <= 512) {
-        selective_scan_fwd_launch<32, 16, input_t, weight_t>(params, stream);
+        selective_scan_fwd_launch<32, 16, input_t, weight_t, state_t>(params, stream);
     } else if (params.seqlen <= 1024) {
-        selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream);
+        selective_scan_fwd_launch<64, 16, input_t, weight_t, state_t>(params, stream);
     } else {
-        selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream);
+        selective_scan_fwd_launch<128, 16, input_t, weight_t, state_t>(params, stream);
     }
 #else
     if (params.seqlen <= 256) {
-        selective_scan_fwd_launch<64, 4, input_t, weight_t>(params, stream);
+        selective_scan_fwd_launch<64, 4, input_t, weight_t, state_t>(params, stream);
     } else if (params.seqlen <= 512) {
-        selective_scan_fwd_launch<64, 8, input_t, weight_t>(params, stream);
+        selective_scan_fwd_launch<64, 8, input_t, weight_t, state_t>(params, stream);
     } else if (params.seqlen <= 1024) {
-        selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream);
+        selective_scan_fwd_launch<64, 16, input_t, weight_t, state_t>(params, stream);
     } else {
-        selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream);
+        selective_scan_fwd_launch<128, 16, input_t, weight_t, state_t>(params, stream);
    }
 #endif
 }

-template void selective_scan_fwd_cuda<at::BFloat16, float>(SSMParamsBase &params, cudaStream_t stream);
-template void selective_scan_fwd_cuda<at::Half, float>(SSMParamsBase &params, cudaStream_t stream);
-template void selective_scan_fwd_cuda<float, float>(SSMParamsBase &params, cudaStream_t stream);
+template void selective_scan_fwd_cuda<at::BFloat16, float, at::BFloat16>(SSMParamsBase &params, cudaStream_t stream);
+template void selective_scan_fwd_cuda<at::BFloat16, float, float>(SSMParamsBase &params, cudaStream_t stream);
+template void selective_scan_fwd_cuda<at::Half, float, at::Half>(SSMParamsBase &params, cudaStream_t stream);
+template void selective_scan_fwd_cuda<at::Half, float, float>(SSMParamsBase &params, cudaStream_t stream);
+template void selective_scan_fwd_cuda<float, float, float>(SSMParamsBase &params, cudaStream_t stream);

 #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")

-#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \
+#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, STYPE, NAME, ...) \
     if (ITYPE == at::ScalarType::Half) { \
         using input_t = at::Half; \
         using weight_t = float; \
-        __VA_ARGS__(); \
+        if (STYPE == at::ScalarType::Half) { \
+            using state_t = at::Half; \
+            __VA_ARGS__(); \
+        } else if (STYPE == at::ScalarType::Float) { \
+            using state_t = float; \
+            __VA_ARGS__(); \
+        } else { \
+            AT_ERROR(#NAME, " not implemented for state type '", toString(STYPE), "'"); \
+        } \
     } else if (ITYPE == at::ScalarType::BFloat16) { \
         using input_t = at::BFloat16; \
         using weight_t = float; \
-        __VA_ARGS__(); \
+        if (STYPE == at::ScalarType::BFloat16) { \
+            using state_t = at::BFloat16; \
+            __VA_ARGS__(); \
+        } else if (STYPE == at::ScalarType::Float) { \
+            using state_t = float; \
+            __VA_ARGS__(); \
+        } else { \
+            AT_ERROR(#NAME, " not implemented for state type '", toString(STYPE), "'"); \
+        } \
     } else if (ITYPE == at::ScalarType::Float) { \
         using input_t = float; \
         using weight_t = float; \
+        using state_t = float; \
         __VA_ARGS__(); \
     } else { \
         AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
     }


-template<typename input_t, typename weight_t>
+template<typename input_t, typename weight_t, typename state_t>
 void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream);

 void set_ssm_params_fwd(SSMParamsBase &params,
@@ -648,7 +668,9 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,

     // Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
     at::Tensor out = delta;
-    TORCH_CHECK(ssm_states.scalar_type() == input_type);
+    // ssm_states can now be either the same as input_type or float32
+    auto state_type = ssm_states.scalar_type();
+    TORCH_CHECK(state_type == input_type || state_type == at::ScalarType::Float);
     TORCH_CHECK(ssm_states.is_cuda());
     TORCH_CHECK(ssm_states.stride(-1) == 1);

@@ -670,7 +692,7 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,

     const at::cuda::OptionalCUDAGuard device_guard(device_of(u));
     auto stream = at::cuda::getCurrentCUDAStream().stream();
-    DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] {
-        selective_scan_fwd_cuda<input_t, weight_t>(params, stream);
+    DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), ssm_states.scalar_type(), "selective_scan_fwd", [&] {
+        selective_scan_fwd_cuda<input_t, weight_t, state_t>(params, stream);
     });
 }
```
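The net effect of this change is that the SSM state cache may be kept in float32 even when activations are half precision. A minimal, illustrative Python sketch (not part of the commit; the helper name is hypothetical) of the input/state dtype pairs the widened dispatch now accepts:

```python
# Illustrative only: mirrors the (input dtype, state dtype) combinations accepted
# by the updated DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16 macro above.
import torch

_SUPPORTED_STATE_DTYPES = {
    torch.float16: {torch.float16, torch.float32},
    torch.bfloat16: {torch.bfloat16, torch.float32},
    torch.float32: {torch.float32},
}


def check_selective_scan_dtypes(u: torch.Tensor, ssm_states: torch.Tensor) -> None:
    """Hypothetical helper: the state dtype must match the input dtype or be fp32."""
    allowed = _SUPPORTED_STATE_DTYPES.get(u.dtype)
    if allowed is None or ssm_states.dtype not in allowed:
        raise TypeError(
            f"selective_scan_fwd not implemented for input {u.dtype} "
            f"with state {ssm_states.dtype}")
```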

docs/configuration/optimization.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -174,6 +174,7 @@ Regardless, you need to set `mm_encoder_tp_mode="data"` in engine arguments to u

 Known supported models:

+- GLM-4.5V GLM-4.1V (<gh-pr:23168>)
 - Kimi-VL (<gh-pr:23817>)
 - Llama4 (<gh-pr:18368>)
 - MiniCPM-V-2.5 or above (<gh-pr:23327>, <gh-pr:23948>)
```

docs/deployment/frameworks/lws.md

Lines changed: 2 additions & 4 deletions
````diff
@@ -22,7 +22,7 @@ Deploy the following yaml file `lws.yaml`
 metadata:
   name: vllm
 spec:
-  replicas: 2
+  replicas: 1
   leaderWorkerTemplate:
     size: 2
     restartPolicy: RecreateGroupOnPodRestart
@@ -41,7 +41,7 @@ Deploy the following yaml file `lws.yaml`
           - sh
           - -c
           - "bash /vllm-workspace/examples/online_serving/multi-node-serving.sh leader --ray_cluster_size=$(LWS_GROUP_SIZE);
-            python3 -m vllm.entrypoints.openai.api_server --port 8080 --model meta-llama/Meta-Llama-3.1-405B-Instruct --tensor-parallel-size 8 --pipeline_parallel_size 2"
+            vllm serve meta-llama/Meta-Llama-3.1-405B-Instruct --port 8080 --tensor-parallel-size 8 --pipeline_parallel_size 2"
         resources:
           limits:
             nvidia.com/gpu: "8"
@@ -126,8 +126,6 @@ Should get an output similar to this:
 NAME     READY   STATUS    RESTARTS   AGE
 vllm-0   1/1     Running   0          2s
 vllm-0-1 1/1     Running   0          2s
-vllm-1   1/1     Running   0          2s
-vllm-1-1 1/1     Running   0          2s
 ```

 Verify that the distributed tensor-parallel inference works:
````
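The updated manifest launches the OpenAI-compatible server with `vllm serve` on port 8080. As a quick sanity check (not part of the doc's diff; assumes the leader pod's port 8080 has been forwarded to localhost, e.g. with `kubectl port-forward`), a minimal client sketch:

```python
# Assumes: kubectl port-forward <leader-pod> 8080:8080 is running locally.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8080/v1", api_key="EMPTY")
completion = client.completions.create(
    model="meta-llama/Meta-Llama-3.1-405B-Instruct",
    prompt="San Francisco is a",
    max_tokens=16,
)
print(completion.choices[0].text)
```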

examples/online_serving/multi-node-serving.sh

Lines changed: 1 addition & 1 deletion
```diff
@@ -11,7 +11,7 @@
 # Example usage:
 # On the head node machine, start the Ray head node process and run a vLLM server.
 # ./multi-node-serving.sh leader --ray_port=6379 --ray_cluster_size=<SIZE> [<extra ray args>] && \
-#   python3 -m vllm.entrypoints.openai.api_server --port 8080 --model meta-llama/Meta-Llama-3.1-405B-Instruct --tensor-parallel-size 8 --pipeline_parallel_size 2
+#   vllm serve meta-llama/Meta-Llama-3.1-405B-Instruct --port 8080 --tensor-parallel-size 8 --pipeline_parallel_size 2
 #
 # On each worker node, start the Ray worker node process.
 # ./multi-node-serving.sh worker --ray_address=<HEAD_NODE_IP> --ray_port=6379 [<extra ray args>]
```

examples/online_serving/openai_chat_completion_client_for_multimodal.py

Lines changed: 42 additions & 0 deletions
```diff
@@ -266,10 +266,52 @@ def run_audio(model: str) -> None:
     print("Chat completion output from base64 encoded audio:", result)


+def run_multi_audio(model: str) -> None:
+    from vllm.assets.audio import AudioAsset
+
+    # Two different audios to showcase batched inference.
+    audio_url = AudioAsset("winning_call").url
+    audio_base64 = encode_base64_content_from_url(audio_url)
+    audio_url2 = AudioAsset("azacinto_foscolo").url
+    audio_base64_2 = encode_base64_content_from_url(audio_url2)
+
+    # OpenAI-compatible schema (`input_audio`)
+    chat_completion_from_base64 = client.chat.completions.create(
+        messages=[
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": "Are these two audios the same?"},
+                    {
+                        "type": "input_audio",
+                        "input_audio": {
+                            "data": audio_base64,
+                            "format": "wav",
+                        },
+                    },
+                    {
+                        "type": "input_audio",
+                        "input_audio": {
+                            "data": audio_base64_2,
+                            "format": "wav",
+                        },
+                    },
+                ],
+            }
+        ],
+        model=model,
+        max_completion_tokens=64,
+    )
+
+    result = chat_completion_from_base64.choices[0].message.content
+    print("Chat completion output from input audio:", result)
+
+
 example_function_map = {
     "text-only": run_text_only,
     "single-image": run_single_image,
     "multi-image": run_multi_image,
+    "multi-audio": run_multi_audio,
     "video": run_video,
     "audio": run_audio,
 }
```
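The new `run_multi_audio` example reuses `encode_base64_content_from_url`, which is defined earlier in this script. For context, a minimal sketch of what such a helper looks like (an assumption, not the file's exact code):

```python
import base64

import requests


def encode_base64_content_from_url(content_url: str) -> str:
    """Fetch remote content and return it base64-encoded (sketch, not the original)."""
    response = requests.get(content_url)
    response.raise_for_status()
    return base64.b64encode(response.content).decode("utf-8")
```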

tests/async_engine/test_api_server.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -98,7 +98,7 @@ def test_api_server(api_server, distributed_executor_backend: str):
     pool.join()

     # check cancellation stats
-    # give it some times to update the stats
+    # give it some time to update the stats
     time.sleep(1)

     num_aborted_requests = requests.get(
```

tests/core/block/e2e/test_correctness.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -439,10 +439,10 @@ def test_auto_prefix_caching_with_preemption(baseline_llm_generator,
 @pytest.mark.parametrize("seed", [1])
 def test_auto_prefix_caching_after_eviction_start(baseline_llm_generator,
                                                   test_llm_generator):
-    """Verify block manager v2 with auto prefix caching could works normal
+    """Verify block manager v2 with auto prefix caching could work normally
     even when eviction started.
     With APC enabled, all blocks are held by native block at the beginning.
-    Then blocks are managed by evictor instead. If cache hit at the evitor's
+    Then blocks are managed by evictor instead. If cache hit at the evictor's
     block, then it could be reused, or we need to recompute its kv cache.
     """
     output_len = 10
```

tests/engine/test_arg_utils.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -167,7 +167,7 @@ def test_get_kwargs():
     # dict should have json tip in help
     json_tip = "Should either be a valid JSON string or JSON keys"
     assert json_tip in kwargs["json_tip"]["help"]
-    # nested config should should construct the nested config
+    # nested config should construct the nested config
     assert kwargs["nested_config"]["type"]('{"field": 2}') == NestedConfig(2)

```
tests/entrypoints/conftest.py

Lines changed: 29 additions & 0 deletions
```diff
@@ -201,3 +201,32 @@ def sample_sql_statements():
     condition: column "=" number
     number: "1" | "2"
 """)
+
+
+@pytest.fixture(scope="session")
+def zephyr_lora_files():
+    """Download zephyr LoRA files once per test session."""
+    from huggingface_hub import snapshot_download
+    return snapshot_download(repo_id="typeof/zephyr-7b-beta-lora")
+
+
+@pytest.fixture(scope="session")
+def zephyr_lora_added_tokens_files(zephyr_lora_files):
+    """Create zephyr LoRA files with added tokens once per test session."""
+    import shutil
+    from tempfile import TemporaryDirectory
+
+    from transformers import AutoTokenizer
+
+    tmp_dir = TemporaryDirectory()
+    tmp_model_dir = f"{tmp_dir.name}/zephyr"
+    shutil.copytree(zephyr_lora_files, tmp_model_dir)
+    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
+    # Copy tokenizer to adapter and add some unique tokens
+    # 32000, 32001, 32002
+    added = tokenizer.add_tokens(["vllm1", "vllm2", "vllm3"],
+                                 special_tokens=True)
+    assert added == 3
+    tokenizer.save_pretrained(tmp_model_dir)
+    yield tmp_model_dir
+    tmp_dir.cleanup()
```
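With these session-scoped fixtures in `conftest.py`, any test module under `tests/entrypoints/` can request them by name and pytest wires them up without cross-module imports, which is why the explicit imports are dropped in `tests/entrypoints/openai/test_chat.py` below. A hypothetical usage sketch (not part of the commit):

```python
from transformers import AutoTokenizer


def test_lora_adapter_exposes_added_tokens(zephyr_lora_added_tokens_files):
    # The fixture adds vllm1/vllm2/vllm3 as special tokens (ids 32000-32002).
    tokenizer = AutoTokenizer.from_pretrained(zephyr_lora_added_tokens_files)
    assert tokenizer.convert_tokens_to_ids("vllm1") == 32000
```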

tests/entrypoints/openai/test_chat.py

Lines changed: 0 additions & 2 deletions
```diff
@@ -15,8 +15,6 @@
 from openai import BadRequestError, OpenAI

 from ...utils import RemoteOpenAIServer
-from .test_completion import zephyr_lora_added_tokens_files  # noqa: F401
-from .test_completion import zephyr_lora_files  # noqa: F401

 # any model with a chat template should work here
 MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
```
