diff --git a/tests/models/language/pooling/test_token_classification.py b/tests/models/language/pooling/test_token_classification.py
index 2dfc0072126b..64d42432c74b 100644
--- a/tests/models/language/pooling/test_token_classification.py
+++ b/tests/models/language/pooling/test_token_classification.py
@@ -68,3 +68,34 @@ def test_modernbert_models(
         hf_output = torch.tensor(hf_output).cpu().float()
         vllm_output = torch.tensor(vllm_output).cpu().float()
         assert torch.allclose(hf_output, vllm_output, atol=1e-2)
+
+
+@pytest.mark.parametrize("model", ["bd2lcco/Qwen3-0.6B-finetuned"])
+@pytest.mark.parametrize("dtype", ["float"])
+@torch.inference_mode
+def test_auto_conversion(
+    hf_runner,
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+) -> None:
+    with vllm_runner(model, max_model_len=1024, dtype=dtype) as vllm_model:
+        vllm_outputs = vllm_model.token_classify(example_prompts)
+
+    with hf_runner(
+        model, dtype=dtype, auto_cls=AutoModelForTokenClassification
+    ) as hf_model:
+        tokenizer = hf_model.tokenizer
+        hf_outputs = []
+        for prompt in example_prompts:
+            inputs = tokenizer([prompt], return_tensors="pt")
+            inputs = hf_model.wrap_device(inputs)
+            output = hf_model.model(**inputs)
+            hf_outputs.append(softmax(output.logits[0]))
+
+    # check logits difference
+    for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
+        hf_output = torch.tensor(hf_output).cpu().float()
+        vllm_output = torch.tensor(vllm_output).cpu().float()
+        assert torch.allclose(hf_output, vllm_output, atol=1e-2)
diff --git a/tests/models/registry.py b/tests/models/registry.py
index 3f835a8b88e3..4e7c3faa1d34 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -573,6 +573,7 @@ def check_available_online(
     "Qwen3ForSequenceClassification": _HfExamplesInfo(
         "tomaarsen/Qwen3-Reranker-0.6B-seq-cls"
     ),
+    "Qwen3ForTokenClassification": _HfExamplesInfo("bd2lcco/Qwen3-0.6B-finetuned"),
 }
 
 _MULTIMODAL_EXAMPLE_MODELS = {
diff --git a/vllm/config/model.py b/vllm/config/model.py
index 10e4d653c825..7ff095bcb9cc 100644
--- a/vllm/config/model.py
+++ b/vllm/config/model.py
@@ -1796,6 +1796,7 @@ def get_served_model_name(model: str, served_model_name: str | list[str] | None)
     ("ForTextEncoding", ("pooling", "embed")),
     ("EmbeddingModel", ("pooling", "embed")),
     ("ForSequenceClassification", ("pooling", "classify")),
+    ("ForTokenClassification", ("pooling", "classify")),
     ("ForAudioClassification", ("pooling", "classify")),
     ("ForImageClassification", ("pooling", "classify")),
     ("ForVideoClassification", ("pooling", "classify")),
diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py
index 9ba76f312eda..504de9fe1087 100644
--- a/vllm/model_executor/models/adapters.py
+++ b/vllm/model_executor/models/adapters.py
@@ -337,6 +337,18 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         tokens = getattr(text_config, "classifier_from_token", None)
         method = getattr(text_config, "method", None)
 
+        def auto_set_score_bias(weights):
+            for name, weight in weights:
+                if name == "score.bias":
+                    device = self.score.weight.device
+                    dtype = self.score.weight.dtype
+                    bias = weight.to(device).to(dtype)
+                    self.score.bias = torch.nn.Parameter(bias)
+                    self.score.skip_bias_add = False
+                else:
+                    yield name, weight
+
+        weights = auto_set_score_bias(weights)
         if tokens is None and method is None:
             return super().load_weights(weights)
         else:
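For context on the adapters.py change: `auto_set_score_bias` is a generator that filters the incoming `(name, tensor)` weight stream, consuming `score.bias` itself (attaching it to the classifier head and re-enabling bias addition via `skip_bias_add`) while yielding every other weight through to the normal loading path. Below is a minimal standalone sketch of the same filter-and-yield pattern using a plain `torch.nn.Linear` rather than vLLM's linear layers; the class and helper names are illustrative only, not part of this PR.

```python
import torch


class TinyTokenClassifier(torch.nn.Module):
    """Stand-in for the adapter's classifier head (illustrative only)."""

    def __init__(self, hidden_size: int, num_labels: int):
        super().__init__()
        # Mirrors the adapter: the head starts without a bias, and one is
        # attached only if the checkpoint actually provides "score.bias".
        self.score = torch.nn.Linear(hidden_size, num_labels, bias=False)

    def load_weights(self, weights):
        def intercept_score_bias(stream):
            # Consume "score.bias" here; yield every other weight through
            # unchanged so the default loading path still sees it.
            for name, tensor in stream:
                if name == "score.bias":
                    self.score.bias = torch.nn.Parameter(
                        tensor.to(self.score.weight.device, self.score.weight.dtype)
                    )
                else:
                    yield name, tensor

        for name, tensor in intercept_score_bias(weights):
            # Default path: copy checkpoint tensors into matching parameters.
            self.get_parameter(name).data.copy_(tensor)


# A checkpoint that carries a bias the filter should pick up.
checkpoint = {
    "score.weight": torch.randn(3, 8),
    "score.bias": torch.randn(3),
}
model = TinyTokenClassifier(hidden_size=8, num_labels=3)
model.load_weights(iter(checkpoint.items()))
assert isinstance(model.score.bias, torch.nn.Parameter)
```

Because the filter is a generator, the bias is intercepted lazily as the loader iterates; checkpoints without a `score.bias` entry pass through untouched and the head simply keeps no bias.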