Commit 4efd413

Merge branch 'main' into main
2 parents: 217f79a + a649767

14 files changed: +1663 lines, -0 lines

docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions
@@ -1030,6 +1030,8 @@
       title: Emu3
     - local: model_doc/evolla
       title: Evolla
+    - local: model_doc/fast_vlm
+      title: FastVLM
     - local: model_doc/flava
       title: FLAVA
     - local: model_doc/florence2
docs/source/en/model_doc/fast_vlm.md

Lines changed: 175 additions & 0 deletions

@@ -0,0 +1,175 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->
18+
19+
# FastVLM
20+
21+
<div class="flex flex-wrap space-x-1">
22+
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
23+
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
24+
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
25+
</div>
26+
27+
## Overview
28+
29+
FastVLM is an open-source vision-language model featuring a novel hybrid vision encoder, FastViTHD. Leveraging reparameterizable convolutional layers, scaled input resolution, and a reduced number of visual tokens, FastVLM delivers high accuracy with exceptional efficiency. Its optimized architecture enables deployment even on edge devices, achieving ultra-low TTFT (time to first token) without sacrificing performance.
30+
31+
The model was proposed in [FastVLM: Efficient Vision Encoding for Vision Language Models](https://huggingface.co/papers/2412.13303) by Pavan Kumar Anasosalu Vasu, Fartash Faghri, Chun-Liang Li, Cem Koc, Nate True, Albert Antony, Gokul Santhanam, James Gabriel, Peter Grasch, Oncel Tuzel and Hadi Pouransari.
32+
33+
The abstract from the paper is the following:
34+
35+
*Scaling the input image resolution is essential for enhancing the performance of Vision Language Models (VLMs), particularly in text-rich image understanding tasks. However, popular visual encoders such as ViTs become inefficient at high resolutions due to the large number of tokens and high encoding latency. At different operational resolutions, the vision encoder of a VLM can be optimized along two axes: reducing encoding latency and minimizing the number of visual tokens passed to the LLM, thereby lowering overall latency. Based on a comprehensive efficiency analysis of the interplay between image resolution, vision latency, token count, and LLM size, we introduce FastVLM—a model that achieves an optimized trade-off between resolution, latency, and accuracy. FastVLM incorporates FastViTHD, a novel hybrid vision encoder designed to output fewer tokens and significantly reduce encoding time for high-resolution images. Unlike previous methods, FastVLM achieves the optimal balance between visual token count and image resolution solely by scaling the input image, eliminating the need for additional token pruning and simplifying the model design. In the LLaVA-1.5 setup, FastVLM achieves 3.2× improvement in time-to-first-token (TTFT) while maintaining similar performance on VLM benchmarks compared to prior works. Compared to LLaVa-OneVision at the highest resolution (1152×1152), FastVLM achieves better performance on key benchmarks like SeedBench, MMMU and DocVQA, using the same 0.5B LLM, but with 85× faster TTFT and a vision encoder that is 3.4× smaller.*
36+
37+
This model was contributed by [Kamila](https://github.com/kamila-chay).
38+
The original code can be found [here](https://github.com/apple/ml-fastvlm).
39+
40+
## Usage tips

- We advise users to use `padding_side="left"` when running batched generation, as it leads to more accurate results. Simply make sure to set `processor.tokenizer.padding_side = "left"` before generating, as shown in the short sketch after these tips.

- Note that the model has not been explicitly trained to process multiple images in the same prompt. Although this is technically possible, you may experience inaccurate results.
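A minimal sketch of the left-padding setup (the batched example further below shows the complete generation flow):

```python
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("KamilaMila/FastVLM-0.5B")
# Left padding gives more accurate results for batched generation
processor.tokenizer.padding_side = "left"
```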
**Important:**

Hugging Face models use SDPA by default; however, this model's visual backbone supports only eager attention, so it automatically falls back to `"eager"`.

If you want to use a different attention implementation in the language decoder, make sure to set it explicitly, for example:

`model = FastVlmForConditionalGeneration.from_pretrained("KamilaMila/FastVLM-0.5B", attn_implementation={"text_config": "flash_attention_2"})`

Setting it for the entire model, e.g.

`model = FastVlmForConditionalGeneration.from_pretrained("KamilaMila/FastVLM-0.5B", attn_implementation="flash_attention_2")`

will result in an error.
### Formatting Prompts with Chat Templates

Each **checkpoint** is trained with a specific prompt format, depending on the underlying large language model backbone. To ensure correct formatting, use the processor's `apply_chat_template` method.

**Important:**
- You must construct a conversation history — passing a plain string won't work.
- Each message should be a dictionary with `"role"` and `"content"` keys.
- The `"content"` should be a list of dictionaries for different modalities like `"text"` and `"image"`.
## Usage examples

### Single input inference

```python
import torch
from transformers import AutoProcessor, FastVlmForConditionalGeneration

# Load the model in half-precision
model = FastVlmForConditionalGeneration.from_pretrained("KamilaMila/FastVLM-0.5B", dtype=torch.bfloat16, device_map="auto")
processor = AutoProcessor.from_pretrained("KamilaMila/FastVLM-0.5B")

conversation = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
            {"type": "text", "text": "What is shown in this image?"},
        ],
    },
]

inputs = processor.apply_chat_template(
    conversation,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt"
).to(model.device, torch.bfloat16)

# Generate
generate_ids = model.generate(**inputs, max_new_tokens=30)
processor.batch_decode(generate_ids, skip_special_tokens=True)
```
### Batched inference

FastVLM also supports batched inference. Here is how you can do it:

```python
import torch
from transformers import AutoProcessor, FastVlmForConditionalGeneration

# Load the model in half-precision
model = FastVlmForConditionalGeneration.from_pretrained("KamilaMila/FastVLM-0.5B", dtype=torch.bfloat16, device_map="auto")
processor = AutoProcessor.from_pretrained("KamilaMila/FastVLM-0.5B")


# Prepare a batch of two prompts
conversation_1 = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
            {"type": "text", "text": "What is shown in this image?"},
        ],
    },
]

conversation_2 = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "http://images.cocodataset.org/val2017/000000039769.jpg"},
            {"type": "text", "text": "What is shown in this image?"},
        ],
    },
]

inputs = processor.apply_chat_template(
    [conversation_1, conversation_2],
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    padding=True,
    return_tensors="pt"
).to(model.device, torch.bfloat16)


# Generate
generate_ids = model.generate(**inputs, max_new_tokens=30)
processor.batch_decode(generate_ids, skip_special_tokens=True)
```
## Note on reproducing the original implementation

To match the logits of the [original implementation](https://github.com/apple/ml-fastvlm), you need to run the model in float32. In half precision the logit difference is larger due to small differences in how some ops are implemented in timm.
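For instance, a minimal sketch of loading the model in full precision for such a comparison (reusing the checkpoint from the examples above):

```python
import torch
from transformers import FastVlmForConditionalGeneration

# float32 keeps the timm-based vision backbone numerically closest to the original code
model = FastVlmForConditionalGeneration.from_pretrained(
    "KamilaMila/FastVLM-0.5B",
    dtype=torch.float32,
)
```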
### Using Flash Attention 2

Flash Attention 2 is an even faster, more optimized attention implementation than SDPA; please refer to the [Flash Attention 2 section of the performance docs](https://huggingface.co/docs/transformers/perf_infer_gpu_one).
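As noted in the usage tips, the vision backbone only supports eager attention, so Flash Attention 2 has to be requested for the language decoder alone. A minimal sketch:

```python
import torch
from transformers import FastVlmForConditionalGeneration

# Enable Flash Attention 2 only for the text decoder; the vision backbone
# automatically falls back to eager attention.
model = FastVlmForConditionalGeneration.from_pretrained(
    "KamilaMila/FastVLM-0.5B",
    dtype=torch.bfloat16,
    attn_implementation={"text_config": "flash_attention_2"},
    device_map="auto",
)
```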
163+
164+
## FastVlmConfig
165+
166+
[[autodoc]] FastVlmConfig
167+
168+
## FastVlmModel
169+
170+
[[autodoc]] FastVlmModel
171+
172+
## FastVlmForConditionalGeneration
173+
174+
[[autodoc]] FastVlmForConditionalGeneration
175+
- forward

src/transformers/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -126,6 +126,7 @@
     from .falcon import *
     from .falcon_h1 import *
     from .falcon_mamba import *
+    from .fast_vlm import *
     from .fastspeech2_conformer import *
     from .flaubert import *
     from .flava import *

src/transformers/models/auto/configuration_auto.py

Lines changed: 2 additions & 0 deletions
@@ -148,6 +148,7 @@
         ("falcon", "FalconConfig"),
         ("falcon_h1", "FalconH1Config"),
         ("falcon_mamba", "FalconMambaConfig"),
+        ("fast_vlm", "FastVlmConfig"),
         ("fastspeech2_conformer", "FastSpeech2ConformerConfig"),
         ("fastspeech2_conformer_with_hifigan", "FastSpeech2ConformerWithHifiGanConfig"),
         ("flaubert", "FlaubertConfig"),
@@ -585,6 +586,7 @@
         ("falcon3", "Falcon3"),
         ("falcon_h1", "FalconH1"),
         ("falcon_mamba", "FalconMamba"),
+        ("fast_vlm", "FastVlm"),
         ("fastspeech2_conformer", "FastSpeech2Conformer"),
         ("fastspeech2_conformer_with_hifigan", "FastSpeech2ConformerWithHifiGan"),
         ("flan-t5", "FLAN-T5"),

src/transformers/models/auto/modeling_auto.py

Lines changed: 2 additions & 0 deletions
@@ -151,6 +151,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
         ("falcon", "FalconModel"),
         ("falcon_h1", "FalconH1Model"),
         ("falcon_mamba", "FalconMambaModel"),
+        ("fast_vlm", "FastVlmModel"),
         ("fastspeech2_conformer", "FastSpeech2ConformerModel"),
         ("fastspeech2_conformer_with_hifigan", "FastSpeech2ConformerWithHifiGan"),
         ("flaubert", "FlaubertModel"),
@@ -996,6 +997,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
         ("deepseek_vl_hybrid", "DeepseekVLHybridForConditionalGeneration"),
         ("emu3", "Emu3ForConditionalGeneration"),
         ("evolla", "EvollaForProteinText2Text"),
+        ("fast_vlm", "FastVlmForConditionalGeneration"),
         ("florence2", "Florence2ForConditionalGeneration"),
         ("fuyu", "FuyuForCausalLM"),
         ("gemma3", "Gemma3ForConditionalGeneration"),
src/transformers/models/fast_vlm/__init__.py

Lines changed: 27 additions & 0 deletions

@@ -0,0 +1,27 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_fast_vlm import *
    from .modeling_fast_vlm import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
src/transformers/models/fast_vlm/configuration_fast_vlm.py

Lines changed: 137 additions & 0 deletions

@@ -0,0 +1,137 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/fast_vlm/modular_fast_vlm.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_fast_vlm.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ...configuration_utils import PreTrainedConfig
from ..auto import CONFIG_MAPPING, AutoConfig


class FastVlmConfig(PreTrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`FastVlmForConditionalGeneration`]. It is used to instantiate a
    FastVLM model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield the same configuration as the one of FastVLM-7B.

    e.g. [KamilaMila/FastVLM-7B](https://huggingface.co/KamilaMila/FastVLM-7B)

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `TimmWrapperConfig` for `fastvit_mci3`):
            The config object or dictionary of the vision backbone.
        text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `Qwen2Config`):
            The config object or dictionary of the text backbone.
        image_token_id (`int`, *optional*, defaults to 151646):
            The image token index to encode the image prompt.
        projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
            The activation function used by the multimodal projector.
        vision_feature_select_strategy (`str`, *optional*, defaults to `"full"`):
            The feature selection strategy used to select the vision feature from the vision backbone.
            Only "full" is supported.
        vision_feature_layer (`Union[int, list[int]]`, *optional*, defaults to -1):
            The index of the layer to select the vision feature. If multiple indices are provided,
            the vision feature of the corresponding indices will be concatenated to form the
            vision features. Only -1 is supported.
        multimodal_projector_bias (`bool`, *optional*, defaults to `True`):
            Whether to use bias in the multimodal projector.

    Example:

    ```python
    >>> from transformers import FastVlmForConditionalGeneration, FastVlmConfig

    >>> # Initializing a FastVLM-7B style configuration
    >>> configuration = FastVlmConfig()

    >>> # Initializing a model from the FastVLM-7B style configuration
    >>> model = FastVlmForConditionalGeneration(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "fast_vlm"
    attribute_map = {
        "image_token_id": "image_token_index",
    }
    sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig}

    def __init__(
        self,
        vision_config=None,
        text_config=None,
        image_token_id=151646,
        projector_hidden_act="gelu",
        vision_feature_select_strategy="full",
        vision_feature_layer=-1,
        multimodal_projector_bias=True,
        **kwargs,
    ):
        self.image_token_id = image_token_id
        self.projector_hidden_act = projector_hidden_act

        if vision_feature_select_strategy != "full":
            raise ValueError(
                f"Unexpected select feature strategy: {vision_feature_select_strategy}. Only 'full' is supported in FastVLM."
            )

        if vision_feature_layer != -1:
            raise ValueError(
                f"Unexpected vision feature layer: {vision_feature_layer}. Only -1 is supported in FastVLM."
            )

        self.vision_feature_select_strategy = vision_feature_select_strategy
        self.vision_feature_layer = vision_feature_layer

        if isinstance(vision_config, dict):
            vision_config["model_type"] = vision_config.get("model_type", "timm_wrapper")
            vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
        elif vision_config is None:
            vision_config = CONFIG_MAPPING["timm_wrapper"](
                architecture="fastvit_mci3",
                do_pooling=True,
                global_pool="avg",
                hidden_size=3072,
                initializer_range=0.02,
                model_args={"inference_mode": True},
            )

        self.vision_config = vision_config

        if isinstance(text_config, dict):
            text_config["model_type"] = text_config.get("model_type", "qwen2")
            text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
        elif text_config is None:
            text_config = CONFIG_MAPPING["qwen2"](
                hidden_size=3584,
                vocab_size=152128,
                intermediate_size=18944,
                num_attention_heads=28,
                num_key_value_heads=4,
                num_hidden_layers=28,
            )

        self.text_config = text_config
        self.multimodal_projector_bias = multimodal_projector_bias

        super().__init__(**kwargs)


__all__ = ["FastVlmConfig"]
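To illustrate the sub-config handling above, here is a small sketch (the reduced text-config sizes are arbitrary, chosen only to keep the example light):

```python
from transformers import FastVlmConfig

# A plain dict is converted into a config object via CONFIG_MAPPING; a missing
# "model_type" key defaults to "qwen2" for the text backbone and "timm_wrapper"
# for the vision backbone.
config = FastVlmConfig(
    text_config={"hidden_size": 512, "num_hidden_layers": 4, "num_attention_heads": 8, "num_key_value_heads": 2},
)
print(type(config.text_config).__name__)    # Qwen2Config
print(type(config.vision_config).__name__)  # TimmWrapperConfig (default fastvit_mci3 backbone)
```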
