
Conversation

@tomaarsen
Member

What does this PR do?

  • Add return_dict to get_text_features & get_image_features methods to allow returning 'BaseModelOutputWithPooling'

Fixes #42401

Well, the architectures that support get_image_features are all extremely different, with wildly different outputs from their get_image_features methods:

  • 2d outputs,
  • 3d outputs,
  • lists of 2d outputs (due to non-matching shapes),
  • an existing 'return_attentions' argument resulting in a 2-tuple being returned,
  • an existing 'return_dict' argument resulting in 3-tuples being returned (???),
  • high quality image embeddings,
  • low quality image embeddings,
  • deepstack image embeddings,
  • etc. etc. etc.

And I only went through like 70-80% of all architectures with get_image_features before I gave up.

Standardisation of all of these sounds like a lost cause. cc @zucchini-nlp, I'm curious about your thoughts here. When I did some preliminary research I only ran into a handful of cases, and I figured we'd be able to bring them all into a single format, but I'm not sure anymore. I added # NOTE: @Tom ... wherever I figured standardisation might run into big problems.

For get_text_features it's a lot simpler: only one architecture (blip-2) differs from all the others.

I haven't started on get_audio_features and get_video_features yet, but there's not much point if we can't get get_image_features normalized.
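For reference, here is a minimal sketch of the intended usage once the flag is in, using a CLIP-style checkpoint purely as an example (the exact kwargs a given architecture accepts may differ):

from transformers import AutoModel, AutoTokenizer

model = AutoModel.from_pretrained("openai/clip-vit-base-patch32")
tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
inputs = tokenizer(["a photo of a cat"], return_tensors="pt")

# Current behaviour: a plain tensor of pooled/projected text embeddings
text_embeds = model.get_text_features(**inputs)

# With this PR: return_dict=True yields a BaseModelOutputWithPooling instead
out = model.get_text_features(**inputs, return_dict=True)
print(out.pooler_output.shape)      # pooled/projected embeddings, as before
print(out.last_hidden_state.shape)  # token-level hidden states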

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@zucchini-nlp @ArthurZucker @Cyrilvallez

  • Tom Aarsen

…ModelOutputWithPooling'

Added to all architectures except blip-2, which has a rather different structure here: it uses 'Blip2TextModelWithProjection' to produce these embeddings/features, but that class isn't as simple to use.
…eModelOutputWithPooling'

@github-actions
Contributor

github-actions bot commented Dec 2, 2025

[For maintainers] Suggested jobs to run (before merge)

run-slow: aimv2, align, altclip, aria, aya_vision, blip, blip_2, chameleon, chinese_clip, clap, clip, clipseg, clvp, cohere2_vision, deepseek_vl, deepseek_vl_hybrid

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Member

@zucchini-nlp zucchini-nlp left a comment


We discussed this internally and decided to add last_hidden_states to all models, as the last state from the vision block. The pooled embeddings will keep their different shapes as they are.

For the last hidden state, the shapes are already more standardized, with a few major options. The only special cases might be qwen-like models, where each image encoding has a different sequence length and the outputs are therefore concatenated as length*dim.
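A rough sketch of that contract (illustrative pseudocode, not any specific model; the pooling/projection step stays model-specific):

from transformers.modeling_outputs import BaseModelOutputWithPooling

def get_image_features(self, pixel_values, return_dict=False, **kwargs):
    vision_outputs = self.vision_model(pixel_values, **kwargs)  # vision block, name varies per model
    image_features = self.pool_or_project(vision_outputs)  # model-specific, shapes keep differing
    if not return_dict:
        return image_features
    return BaseModelOutputWithPooling(
        last_hidden_state=vision_outputs.last_hidden_state,  # standardized: last state of the vision block
        pooler_output=image_features,  # stays in its current, model-specific shape
    )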

vision_embeddings = self.get_input_embeddings()(image_tokens)
return vision_embeddings
image_embeddings = self.get_input_embeddings()(image_tokens)
image_features = image_embeddings.mean(dim=1)
Member

Taking the mean is not what we want for VLMs. They are supposed to return image embeddings in a format that can be concatenated with the text embeddings.
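For context on why a mean-pooled (batch, dim) shape can't work here: in these VLMs the image features get merged into the text embedding sequence at the image-placeholder positions, so the number of returned vectors has to match the number of image tokens. A generic sketch of that merge pattern (names are illustrative, not taken from a specific model):

import torch

def merge_image_features(inputs_embeds, image_features, special_image_mask):
    # inputs_embeds:      (batch, seq_len, hidden_dim) text embeddings
    # image_features:     (num_image_tokens, hidden_dim) output of get_image_features
    # special_image_mask: (batch, seq_len) True at image-placeholder positions
    mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds)
    return inputs_embeds.masked_scatter(mask, image_features.to(inputs_embeds.dtype))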

Member Author

Ah, yes. This was just a quick test to experiment with the shapes. I didn't realise I kept it in.


if return_dict:
return BaseModelOutputWithPooling(
last_hidden_state=image_embeddings,
Member

With chameleon it is a bit vague. The vision quantizer could return hidden_states before quantizing them, which I believe is the last hidden state we want.
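As a generic illustration of what "return hidden_states before quantizing" could mean (toy vector quantizer, not chameleon's actual module):

import torch
import torch.nn as nn

class ToyVectorQuantizer(nn.Module):
    # Toy example only; chameleon's real quantizer is more involved.
    def __init__(self, num_codes=16, dim=8):
        super().__init__()
        self.codebook = nn.Embedding(num_codes, dim)

    def forward(self, hidden_states, return_pre_quant=False):
        # hidden_states: (batch, seq_len, dim) activations from the vision encoder
        flat = hidden_states.reshape(-1, hidden_states.shape[-1])
        indices = torch.cdist(flat, self.codebook.weight).argmin(-1)
        quantized = self.codebook(indices).reshape(hidden_states.shape)
        if return_pre_quant:
            # The pre-quantization activations are the candidate last_hidden_state.
            return quantized, hidden_states
        return quantized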

]
image_features = self.get_input_embeddings()(image_tokens)
image_features = torch.split(image_features, split_sizes)
# NOTE: @Tom Not easily converted to the standard format
Member

Yeah, same as chameleon. We would first need to start returning hidden states from the VQ module.

image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
image_embeds = torch.split(image_embeds, split_sizes)
# NOTE: @Tom Not easily converted to the standard format
Member

This is the same as the qwen-vl models, with the last hidden state being of shape (bs, len*pooled_dim). The visual block returns only pooled outputs IIRC, so we might need to also change the vision block.
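For illustration, one way the concatenated qwen-style output could be coerced into a padded (num_images, max_len, dim) tensor plus a validity mask, should we ever want a batched format (dummy values stand in for the real image_embeds / split_sizes from the diff above):

import torch
from torch.nn.utils.rnn import pad_sequence

split_sizes = [4, 7, 3]  # per-image token counts, e.g. derived from image_grid_thw
image_embeds = torch.randn(sum(split_sizes), 8)  # (total_len, dim) concatenation over all images

per_image = list(torch.split(image_embeds, split_sizes))  # list of (len_i, dim) tensors
padded = pad_sequence(per_image, batch_first=True)  # (num_images, max_len, dim)
lengths = torch.tensor(split_sizes)
mask = torch.arange(padded.shape[1])[None, :] < lengths[:, None]  # True at non-padded positions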

pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
The tensors corresponding to the input images.
"""
# NOTE: @Tom perhaps we should just raise an error here instead?
Member

This fn should simply be removed, because the model doesn't work with images. It was a bad copy from modular at the time 🫠

return BaseModelOutputWithPooling(
last_hidden_state=vision_model_output.last_hidden_state,
pooler_output=image_embeds,
attentions=projection_attentions, # TODO: @Tom does this match expectations?
Member

I'd say not really, since these look like the attentions of the vision-pooling module. Very model-specific; most poolers I've seen aren't attention-based.

if return_dict:
return BaseModelOutputWithPooling(
last_hidden_state=image_embeds,
# pooler_output=image_features, # NOTE: @Tom no pooled embeddings here
Member

Same thing here: image_embeds are actually pooled embeddings, and the last hidden state is not returned from visual.



Development

Successfully merging this pull request may close these issues.

The get_(text|image|audio|video)_features methods have inconsistent output formats, needs aligning for Sentence Transformers

3 participants