
Commit 06ecf32

noooop authored and Lucaskabela committed
[Model] Automatic conversion of TokenClassification model (vllm-project#30666)
Signed-off-by: wang.yuqi <[email protected]>
1 parent 0e91808 commit 06ecf32

File tree

4 files changed: 45 additions, 0 deletions


tests/models/language/pooling/test_token_classification.py

Lines changed: 31 additions & 0 deletions
@@ -68,3 +68,34 @@ def test_modernbert_models(
         hf_output = torch.tensor(hf_output).cpu().float()
         vllm_output = torch.tensor(vllm_output).cpu().float()
         assert torch.allclose(hf_output, vllm_output, atol=1e-2)
+
+
+@pytest.mark.parametrize("model", ["bd2lcco/Qwen3-0.6B-finetuned"])
+@pytest.mark.parametrize("dtype", ["float"])
+@torch.inference_mode
+def test_auto_conversion(
+    hf_runner,
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+) -> None:
+    with vllm_runner(model, max_model_len=1024, dtype=dtype) as vllm_model:
+        vllm_outputs = vllm_model.token_classify(example_prompts)
+
+    with hf_runner(
+        model, dtype=dtype, auto_cls=AutoModelForTokenClassification
+    ) as hf_model:
+        tokenizer = hf_model.tokenizer
+        hf_outputs = []
+        for prompt in example_prompts:
+            inputs = tokenizer([prompt], return_tensors="pt")
+            inputs = hf_model.wrap_device(inputs)
+            output = hf_model.model(**inputs)
+            hf_outputs.append(softmax(output.logits[0]))
+
+    # check logits difference
+    for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
+        hf_output = torch.tensor(hf_output).cpu().float()
+        vllm_output = torch.tensor(vllm_output).cpu().float()
+        assert torch.allclose(hf_output, vllm_output, atol=1e-2)
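For readers who want to reproduce the Hugging Face reference path outside the vLLM test harness, here is a minimal standalone sketch. The checkpoint name comes from the test above; the prompt is an arbitrary example, and everything else is plain transformers/torch API, not part of this commit:

import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer

# Checkpoint taken from the test above.
model_id = "bd2lcco/Qwen3-0.6B-finetuned"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForTokenClassification.from_pretrained(model_id)
model.eval()

inputs = tokenizer(["vLLM is a fast inference engine."], return_tensors="pt")
with torch.inference_mode():
    logits = model(**inputs).logits  # shape: (1, seq_len, num_labels)

# Per-token class probabilities, matching softmax(output.logits[0]) in the test.
probs = torch.softmax(logits[0], dim=-1)
print(probs.shape)

The test then asserts that vLLM's per-token probabilities match this reference within atol=1e-2.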

tests/models/registry.py

Lines changed: 1 addition & 0 deletions
@@ -573,6 +573,7 @@ def check_available_online(
     "Qwen3ForSequenceClassification": _HfExamplesInfo(
         "tomaarsen/Qwen3-Reranker-0.6B-seq-cls"
     ),
+    "Qwen3ForTokenClassification": _HfExamplesInfo("bd2lcco/Qwen3-0.6B-finetuned"),
 }

 _MULTIMODAL_EXAMPLE_MODELS = {
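This registry maps each supported architecture class name to a known checkpoint that the availability tests can exercise. A toy sketch of that lookup pattern, using only the two entries visible in the hunk (the helper function is hypothetical, not vLLM internals):

# Hypothetical illustration of an architecture -> example-checkpoint lookup.
_TEXT_EXAMPLE_MODELS: dict[str, str] = {
    "Qwen3ForSequenceClassification": "tomaarsen/Qwen3-Reranker-0.6B-seq-cls",
    "Qwen3ForTokenClassification": "bd2lcco/Qwen3-0.6B-finetuned",
}

def example_checkpoint(architecture: str) -> str:
    """Return the checkpoint registered for an architecture class name."""
    return _TEXT_EXAMPLE_MODELS[architecture]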

vllm/config/model.py

Lines changed: 1 addition & 0 deletions
@@ -1796,6 +1796,7 @@ def get_served_model_name(model: str, served_model_name: str | list[str] | None)
     ("ForTextEncoding", ("pooling", "embed")),
     ("EmbeddingModel", ("pooling", "embed")),
     ("ForSequenceClassification", ("pooling", "classify")),
+    ("ForTokenClassification", ("pooling", "classify")),
     ("ForAudioClassification", ("pooling", "classify")),
     ("ForImageClassification", ("pooling", "classify")),
     ("ForVideoClassification", ("pooling", "classify")),

vllm/model_executor/models/adapters.py

Lines changed: 12 additions & 0 deletions
@@ -337,6 +337,18 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         tokens = getattr(text_config, "classifier_from_token", None)
         method = getattr(text_config, "method", None)

+        def auto_set_score_bias(weights):
+            for name, weight in weights:
+                if name == "score.bias":
+                    device = self.score.weight.device
+                    dtype = self.score.weight.dtype
+                    bias = weight.to(device).to(dtype)
+                    self.score.bias = torch.nn.Parameter(bias)
+                    self.score.skip_bias_add = False
+                else:
+                    yield name, weight
+
+        weights = auto_set_score_bias(weights)
         if tokens is None and method is None:
             return super().load_weights(weights)
         else:
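auto_set_score_bias is a lazy filter over the incoming weight stream: when it sees score.bias it attaches the tensor directly to the score layer (matching the weight's device and dtype, and clearing skip_bias_add so the bias is applied in the forward pass) instead of yielding it to the normal loader, while every other weight passes through untouched. A self-contained toy showing the same generator-interception pattern (names and the plain nn.Linear here are hypothetical stand-ins, not vLLM internals):

import torch

def intercept_bias(weights, linear: torch.nn.Linear):
    """Yield (name, tensor) pairs, consuming "score.bias" as a side effect."""
    for name, weight in weights:
        if name == "score.bias":
            # Attach the bias to the layer instead of yielding it downstream,
            # converting to the existing weight's device/dtype.
            bias = weight.to(linear.weight.device).to(linear.weight.dtype)
            linear.bias = torch.nn.Parameter(bias)
        else:
            yield name, weight

layer = torch.nn.Linear(4, 2, bias=False)
stream = [("score.bias", torch.zeros(2)), ("score.weight", torch.ones(2, 4))]
remaining = list(intercept_bias(iter(stream), layer))

assert [name for name, _ in remaining] == ["score.weight"]
assert layer.bias is not None  # bias was installed on the layer

Because the interceptor is a generator, the weight stream is never materialized; the bias is consumed exactly once as the loader iterates past it, and the remaining weights flow on to super().load_weights() unchanged.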
