
Conversation

@sywangyi (Contributor) commented Nov 5, 2025

…y::ShieldGemma2IntegrationTest::test_model

@ydshieh

Signed-off-by: Wang, Yi <[email protected]>
@sywangyi (Contributor Author) commented Nov 5, 2025

export RUN_SLOW=true
pytest -rA tests/models/shieldgemma2/test_modeling_shieldgemma2.py::ShieldGemma2IntegrationTest::test_model


self = <tests.models.shieldgemma2.test_modeling_shieldgemma2.ShieldGemma2IntegrationTest testMethod=test_model>

    def test_model(self):
        model_id = "google/shieldgemma-2-4b-it"

        processor = ShieldGemma2Processor.from_pretrained(model_id, padding_side="left")
        url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png"
        response = requests.get(url)
        image = Image.open(BytesIO(response.content))

        model = ShieldGemma2ForImageClassification.from_pretrained(
            model_id, quantization_config=BitsAndBytesConfig(load_in_4bit=True)
        )

        inputs = processor(images=[image], return_tensors="pt").to(torch_device)
>       output = model(**inputs)
                 ^^^^^^^^^^^^^^^

tests/models/shieldgemma2/test_modeling_shieldgemma2.py:56:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
/mnt/disk0/wangyi/miniforge3/envs/transformers/lib/python3.11/site-packages/torch/nn/modules/module.py:1775: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
/mnt/disk0/wangyi/miniforge3/envs/transformers/lib/python3.11/site-packages/torch/nn/modules/module.py:1786: in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
src/transformers/models/shieldgemma2/modeling_shieldgemma2.py:121: in forward
    outputs = self.model(
/mnt/disk0/wangyi/miniforge3/envs/transformers/lib/python3.11/site-packages/torch/nn/modules/module.py:1775: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
/mnt/disk0/wangyi/miniforge3/envs/transformers/lib/python3.11/site-packages/torch/nn/modules/module.py:1786: in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
src/transformers/models/gemma3/modeling_gemma3.py:1156: in forward
    outputs = self.model(
/mnt/disk0/wangyi/miniforge3/envs/transformers/lib/python3.11/site-packages/torch/nn/modules/module.py:1775: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
/mnt/disk0/wangyi/miniforge3/envs/transformers/lib/python3.11/site-packages/torch/nn/modules/module.py:1786: in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
src/transformers/utils/generic.py:757: in wrapper
    output = func(self, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^
src/transformers/models/gemma3/modeling_gemma3.py:982: in forward
    inputs_embeds = self.get_input_embeddings()(llm_input_ids)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
/mnt/disk0/wangyi/miniforge3/envs/transformers/lib/python3.11/site-packages/torch/nn/modules/module.py:1775: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
/mnt/disk0/wangyi/miniforge3/envs/transformers/lib/python3.11/site-packages/torch/nn/modules/module.py:1786: in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
src/transformers/models/gemma3/modeling_gemma3.py:105: in forward
    return super().forward(input_ids) * self.embed_scale.to(self.weight.dtype)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
/mnt/disk0/wangyi/miniforge3/envs/transformers/lib/python3.11/site-packages/torch/nn/modules/sparse.py:192: in forward
    return F.embedding(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

input = [[0, 0, 0, 0, 0, 2, ...], [0, 0, 0, 0, 0, 0, ...], [2, 2, 105, 2364, 109, 255999, ...]]
weight = Parameter containing:
tensor([[ 0.0114,  0.0023, -0.0023,  ..., -0.0085,  0.0020, -0.0078],
        [-0.0260,  0.0030,...,  0.0014, -0.0002,  ..., -0.0009, -0.0008, -0.0010]],
       device='cuda:0', dtype=torch.float16, requires_grad=True)
padding_idx = 0, max_norm = None, norm_type = 2.0, scale_grad_by_freq = False, sparse = False

    def embedding(
        input: Tensor,
        weight: Tensor,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
    ) -> Tensor:
        r"""Generate a simple lookup table that looks up embeddings in a fixed dictionary and size.

        This module is often used to retrieve word embeddings using indices.
        The input to the module is a list of indices, and the embedding matrix,
        and the output is the corresponding word embeddings.

        See :class:`torch.nn.Embedding` for more details.

        .. note::
            Note that the analytical gradients of this function with respect to
            entries in :attr:`weight` at the row specified by :attr:`padding_idx`
            are expected to differ from the numerical ones.

        .. note::
            Note that `:class:`torch.nn.Embedding` differs from this function in
            that it initializes the row of :attr:`weight` specified by
            :attr:`padding_idx` to all zeros on construction.

        Args:
            input (LongTensor): Tensor containing indices into the embedding matrix
            weight (Tensor): The embedding matrix with number of rows equal to the maximum possible index + 1,
                and number of columns equal to the embedding size
            padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient;
                                         therefore, the embedding vector at :attr:`padding_idx` is not updated during training,
                                         i.e. it remains as a fixed "pad".
            max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
                                        is renormalized to have norm :attr:`max_norm`.
                                        Note: this will modify :attr:`weight` in-place.
            norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
            scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse of frequency of
                                                    the words in the mini-batch. Default ``False``.
            sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` will be a sparse tensor. See Notes under
                                     :class:`torch.nn.Embedding` for more details regarding sparse gradients.

        Shape:
            - Input: LongTensor of arbitrary shape containing the indices to extract
            - Weight: Embedding matrix of floating point type with shape `(V, embedding_dim)`,
              where V = maximum index + 1 and embedding_dim = the embedding size
            - Output: `(*, embedding_dim)`, where `*` is the input shape

        Examples::

            >>> # a batch of 2 samples of 4 indices each
            >>> input = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]])
            >>> # an embedding matrix containing 10 tensors of size 3
            >>> embedding_matrix = torch.rand(10, 3)
            >>> # xdoctest: +IGNORE_WANT("non-deterministic")
            >>> F.embedding(input, embedding_matrix)
            tensor([[[ 0.8490,  0.9625,  0.6753],
                     [ 0.9666,  0.7761,  0.6108],
                     [ 0.6246,  0.9751,  0.3618],
                     [ 0.4161,  0.2419,  0.7383]],

                    [[ 0.6246,  0.9751,  0.3618],
                     [ 0.0237,  0.7794,  0.0528],
                     [ 0.9666,  0.7761,  0.6108],
                     [ 0.3385,  0.8612,  0.1867]]])

            >>> # example with padding_idx
            >>> weights = torch.rand(10, 3)
            >>> weights[0, :].zero_()
            >>> embedding_matrix = weights
            >>> input = torch.tensor([[0, 2, 0, 5]])
            >>> F.embedding(input, embedding_matrix, padding_idx=0)
            tensor([[[ 0.0000,  0.0000,  0.0000],
                     [ 0.5609,  0.5384,  0.8720],
                     [ 0.0000,  0.0000,  0.0000],
                     [ 0.6262,  0.2438,  0.7471]]])
        """
        if has_torch_function_variadic(input, weight):
            return handle_torch_function(
                embedding,
                (input, weight),
                input,
                weight,
                padding_idx=padding_idx,
                max_norm=max_norm,
                norm_type=norm_type,
                scale_grad_by_freq=scale_grad_by_freq,
                sparse=sparse,
            )
        if padding_idx is not None:
            if padding_idx > 0:
                assert padding_idx < weight.size(0), (
                    "Padding_idx must be within num_embeddings"
                )
            elif padding_idx < 0:
                assert padding_idx >= -weight.size(0), (
                    "Padding_idx must be within num_embeddings"
                )
                padding_idx = weight.size(0) + padding_idx
        else:
            padding_idx = -1
        if max_norm is not None:
            # Note [embedding_renorm contiguous]
            # `embedding_renorm_` will call .contiguous() on input anyways, so we
            # call it here and take advantage of the improved locality in the
            # `embedding` call below too.
            input = input.contiguous()
            # Note [embedding_renorm set_grad_enabled]
            # XXX: equivalent to
            # with torch.no_grad():
            #   torch.embedding_renorm_
            # remove once script supports set_grad_enabled
            _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
>       return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E       TypeError: embedding(): argument 'indices' (position 2) must be Tensor, not list

/mnt/disk0/wangyi/miniforge3/envs/transformers/lib/python3.11/site-packages/torch/nn/functional.py:2542: TypeError
-------------------------------------------------------------- Captured stderr call --------------------------------------------------------------
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
Fetching 2 files: 100%|██████████| 2/2 [00:00<00:00, 42153.81it/s]
Loading checkpoint shards: 100%|██████████| 2/2 [00:08<00:00,  4.12s/it]
Some weights of ShieldGemma2ForImageClassification were not initialized from the model checkpoint at google/shieldgemma-2-4b-it and are newly initialized: ['model.lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
================================================================ warnings summary ================================================================
../../../disk0/wangyi/miniforge3/envs/transformers/lib/python3.11/site-packages/torch/backends/__init__.py:46
  /mnt/disk0/wangyi/miniforge3/envs/transformers/lib/python3.11/site-packages/torch/backends/__init__.py:46: UserWarning: Please use the new API settings to control TF32 behavior, such as torch.backends.cudnn.conv.fp32_precision = 'tf32' or torch.backends.cuda.matmul.fp32_precision = 'ieee'. Old settings, e.g, torch.backends.cuda.matmul.allow_tf32 = True, torch.backends.cudnn.allow_tf32 = True, allowTF32CuDNN() and allowTF32CuBLAS() will be deprecated after Pytorch 2.9. Please see https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices (Triggered internally at /pytorch/aten/src/ATen/Context.cpp:80.)
    self.setter(val)

tests/models/shieldgemma2/test_modeling_shieldgemma2.py::ShieldGemma2IntegrationTest::test_model
  <frozen importlib._bootstrap>:241: DeprecationWarning: builtin type SwigPyPacked has no __module__ attribute

tests/models/shieldgemma2/test_modeling_shieldgemma2.py::ShieldGemma2IntegrationTest::test_model
  <frozen importlib._bootstrap>:241: DeprecationWarning: builtin type SwigPyObject has no __module__ attribute

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
============================================================ short test summary info =============================================================
FAILED tests/models/shieldgemma2/test_modeling_shieldgemma2.py::ShieldGemma2IntegrationTest::test_model - TypeError: embedding(): argument 'indices' (position 2) must be Tensor, not list
========================================================= 1 failed, 3 warnings in 52.17s =========================================================
sys:1: DeprecationWarning: builtin type swigvarlink has no __module__ attribute

@sywangyi (Contributor Author) commented Nov 5, 2025

The crash is caused by the output of _merge_kwargs. The input to _merge_kwargs for this case:

kwargs = {'return_tensors': 'pt', 'images_kwargs': {}, 'text_kwargs': {'padding': True, 'padding_side': 'left'}}

Output without the PR:
output_kwargs = {'text_kwargs': {'padding': True, 'return_mm_token_type_ids': True, 'padding_side': 'left'}, 'images_kwargs': {'do_convert_rgb': True, 'do_pan_and_scan': False, 'pan_and_scan_min_crop_size': 256, 'pan_and_scan_max_num_crops': 4, 'pan_and_scan_min_ratio_to_activate': 1.2}, 'audio_kwargs': {'return_tensors': 'pt'}, 'videos_kwargs': {'return_tensors': 'pt'}}

the "return_tensors":"pt" is missing in value of 'text_kwargs' and ''images_kwargs'.

output w the PR

output_kwargs = {'text_kwargs': {'padding': True, 'return_mm_token_type_ids': True, 'padding_side': 'left', 'return_tensors': 'pt'}, 'images_kwargs': {'do_convert_rgb': True, 'do_pan_and_scan': False, 'pan_and_scan_min_crop_size': 256, 'pan_and_scan_max_num_crops': 4, 'pan_and_scan_min_ratio_to_activate': 1.2, 'return_tensors': 'pt'}, 'audio_kwargs': {'return_tensors': 'pt'}, 'videos_kwargs': {'return_tensors': 'pt'}}
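
For illustration, here is a minimal sketch of the propagation this PR restores. It is simplified, and propagate_return_tensors is a hypothetical helper, not the actual _merge_kwargs implementation:

# Hypothetical sketch: a top-level `return_tensors` should end up in every
# per-modality kwargs dict so each sub-processor returns tensors.
def propagate_return_tensors(kwargs: dict) -> dict:
    return_tensors = kwargs.pop("return_tensors", None)
    if return_tensors is not None:
        for modality in ("text_kwargs", "images_kwargs", "audio_kwargs", "videos_kwargs"):
            # setdefault keeps any value the caller already set explicitly
            kwargs.setdefault(modality, {}).setdefault("return_tensors", return_tensors)
    return kwargs

kwargs = {"return_tensors": "pt", "images_kwargs": {}, "text_kwargs": {"padding": True, "padding_side": "left"}}
print(propagate_return_tensors(kwargs))
# every *_kwargs dict now carries "return_tensors": "pt", matching the output with the PR above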

@ydshieh (Collaborator) commented Nov 5, 2025

Thank you @sywangyi. Confirmed the issue, and the PR makes the test pass.

I will take a closer look at the changes 🙏

@ydshieh (Collaborator) commented Nov 6, 2025

@zucchini-nlp This issue was introduced in

🚨 [unbloating] unify TypedDict usage in processing (#40931) - commit on main is 5339f72

With its parent commit 42bcc81, the test passes.

Could you take a look and see if the changes in this PR by @sywangyi are the best/correct way to handle it?

Thank you.

@sywangyi (Contributor Author) commented Nov 6, 2025

> @zucchini-nlp This issue was introduced in
>
> 🚨 [unbloating] unify TypedDict usage in processing (#40931) - commit on main is 5339f72
>
> With its parent commit 42bcc81, the test passes.
>
> Could you take a look and see if the changes in this PR are the best/correct way to handle it?
>
> Thank you.

Sorry, which PR do you mean? All of these PRs seem to have been merged in October.

@ydshieh (Collaborator) commented Nov 6, 2025

@sywangyi

It's from PR #40931

@zucchini-nlp (Member)

> kwargs = {'return_tensors': 'pt', 'images_kwargs': {}, 'text_kwargs': {'padding': True, 'padding_side': 'left'}}

I don't think this is a valid input to the processor, since some of the kwargs are structured per modality (text/images) and the rest are not. The common kwargs can be structured as {'common_kwargs': {'return_tensors': 'pt'}}.
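
For example, a sketch of the structured call, assuming the processor accepts per-modality and common kwargs in this form (processor and image as in the test above):

inputs = processor(
    images=[image],
    text_kwargs={"padding": True, "padding_side": "left"},
    common_kwargs={"return_tensors": "pt"},
)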

@ydshieh (Collaborator) commented Nov 6, 2025

@zucchini-nlp So do you have a suggestion for changing the code block below to make it work?

        model_id = "google/shieldgemma-2-4b-it"

        processor = ShieldGemma2Processor.from_pretrained(model_id, padding_side="left")
        url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png"
        response = requests.get(url)
        image = Image.open(BytesIO(response.content))

        inputs = processor(images=[image], return_tensors="pt").to(torch_device)
>       output = model(**inputs)

@zucchini-nlp (Member)

Yep, we can instead fix it in ShieldGemma's processor class by passing return_tensors under "common_kwargs":

        common_kwargs = kwargs.setdefault("common_kwargs", {})
        if "return_tensors" in kwargs:
            common_kwargs["return_tensors"] = kwargs.pop("return_tensors")

@sywangyi (Contributor Author) commented Nov 7, 2025

Thanks, updated the PR.

github-actions bot commented Nov 7, 2025

[For maintainers] Suggested jobs to run (before merge)

run-slow: shieldgemma2

@zucchini-nlp (Member) left a comment
LGTM! Thanks, can be merged after Yih-Dar confirms that CI is ✔️
