diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index c3036b8a3973..98afe8be7339 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -1072,6 +1072,8 @@ title: LayoutXLM - local: model_doc/lfm2_vl title: LFM2-VL + - local: model_doc/lightonocr + title: LightOnOCR - local: model_doc/lilt title: LiLT - local: model_doc/llama4 diff --git a/docs/source/en/model_doc/lightonocr.md b/docs/source/en/model_doc/lightonocr.md new file mode 100644 index 000000000000..7f61c872517b --- /dev/null +++ b/docs/source/en/model_doc/lightonocr.md @@ -0,0 +1,66 @@ + +*This model was released on {release_date} and added to Hugging Face Transformers on 2025-11-18.* + +# LightOnOCR + + +**LightOnOCR** is a compact, end-to-end vision–language model for Optical Character Recognition (OCR) and document understanding. It achieves state-of-the-art accuracy in its weight class while being several times faster and cheaper than larger general-purpose VLMs. + +πŸ“ **[Read the full blog post](https://huggingface.co/blog/lightonai/lightonocr/)** | πŸ““ **[Finetuning notebook](https://colab.research.google.com/drive/1WjbsFJZ4vOAAlKtcCauFLn_evo5UBRNa?usp=sharing)** + +**Model Overview** + +LightOnOCR combines a Vision Transformer encoder(Pixtral-based) with a lightweight text decoder(Qwen3-based) distilled from high-quality open VLMs. It is optimized for document parsing tasks, producing accurate, layout-aware text extraction from high-resolution pages. + + + + +## LightOnOCRConfig + +[[autodoc]] LightOnOCRConfig + +## LightOnOCRTextConfig + +[[autodoc]] LightOnOCRTextConfig + +## LightOnOCRVisionConfig + +[[autodoc]] LightOnOCRVisionConfig + +## LightOnOCRProcessor + +[[autodoc]] LightOnOCRProcessor + - __call__ + +## LightOnOCRTextModel + +[[autodoc]] LightOnOCRTextModel + - forward + +## LightOnOCRVisionModel + +[[autodoc]] LightOnOCRVisionModel + - forward + +## LightOnOCRModel + +[[autodoc]] LightOnOCRModel + - forward + +## LightOnOCRForConditionalGeneration + +[[autodoc]] LightOnOCRForConditionalGeneration + - forward diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index c55980e471c7..cf0b514d5d7a 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -229,6 +229,7 @@ ("lfm2_moe", "Lfm2MoeConfig"), ("lfm2_vl", "Lfm2VlConfig"), ("lightglue", "LightGlueConfig"), + ("lightonocr", "LightOnOCRConfig"), ("lilt", "LiltConfig"), ("llama", "LlamaConfig"), ("llama4", "Llama4Config"), @@ -665,6 +666,7 @@ ("lfm2_moe", "Lfm2Moe"), ("lfm2_vl", "Lfm2Vl"), ("lightglue", "LightGlue"), + ("lightonocr", "LightOnOCR"), ("lilt", "LiLT"), ("llama", "LLaMA"), ("llama2", "Llama2"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 22985f413341..c6f6cc7701be 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -229,6 +229,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("lfm2_moe", "Lfm2MoeModel"), ("lfm2_vl", "Lfm2VlModel"), ("lightglue", "LightGlueForKeypointMatching"), + ("lightonocr", "LightOnOCRModel"), ("lilt", "LiltModel"), ("llama", "LlamaModel"), ("llama4", "Llama4ForConditionalGeneration"), @@ -1004,6 +1005,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("kosmos-2", "Kosmos2ForConditionalGeneration"), ("kosmos-2.5", "Kosmos2_5ForConditionalGeneration"), ("lfm2_vl", 
"Lfm2VlForConditionalGeneration"), + ("lightonocr", "LightOnOCRForConditionalGeneration"), ("llama4", "Llama4ForConditionalGeneration"), ("llava", "LlavaForConditionalGeneration"), ("llava_next", "LlavaNextForConditionalGeneration"), diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index 9e6f4e66ff4d..8f07bc96d37d 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -96,6 +96,7 @@ ("layoutlmv2", "LayoutLMv2Processor"), ("layoutlmv3", "LayoutLMv3Processor"), ("lfm2_vl", "Lfm2VlProcessor"), + ("lightonocr", "LightOnOCRProcessor"), ("llama4", "Llama4Processor"), ("llava", "LlavaProcessor"), ("llava_next", "LlavaNextProcessor"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index be0a1f0dd754..cde17ec7b62c 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -368,6 +368,13 @@ ("led", ("LEDTokenizer", "LEDTokenizerFast" if is_tokenizers_available() else None)), ("lfm2", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), ("lfm2_vl", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), + ( + "lightonocr", + ( + "Qwen2Tokenizer", + "Qwen2TokenizerFast" if is_tokenizers_available() else None, + ), + ), ("lilt", ("LayoutLMv3Tokenizer", "LayoutLMv3TokenizerFast" if is_tokenizers_available() else None)), ( "llama", diff --git a/src/transformers/models/lightonocr/__init__.py b/src/transformers/models/lightonocr/__init__.py new file mode 100644 index 000000000000..fc59bbafb058 --- /dev/null +++ b/src/transformers/models/lightonocr/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2024 The Qwen Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_lightonocr import * + from .modeling_lightonocr import * + from .processing_lightonocr import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/lightonocr/configuration_lightonocr.py b/src/transformers/models/lightonocr/configuration_lightonocr.py new file mode 100644 index 000000000000..85c0a3e7377c --- /dev/null +++ b/src/transformers/models/lightonocr/configuration_lightonocr.py @@ -0,0 +1,367 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/lightonocr/modular_lightonocr.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_lightonocr.py file directly. One of our CI enforces this. 
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +from typing import Any, Optional + +from ...configuration_utils import PreTrainedConfig, PretrainedConfig, layer_type_validation +from ...modeling_rope_utils import RopeParameters, rope_config_validation, standardize_rope_params + + +class LightOnOCRVisionConfig(PreTrainedConfig): + r""" + This is the configuration class to store the configuration of a [`LightOnOCRVisionModel`]. It is used to instantiate an + LightOnOCR vision encoder according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to the vision encoder used by LightOnOCR-12B. + + e.g. [lightonocr-hf/lightonocr-9b](https://huggingface.co/lightonocr-hf/lightonocr-9b) + + Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PreTrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 1024): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 4096): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 24): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads in the Transformer encoder. + num_channels (`int`, *optional*, defaults to 3): + Number of input channels in the input images. + image_size (`int`, *optional*, defaults to 1024): + Max dimension of the input images. + patch_size (`int`, *optional*, defaults to 16): + Size of the image patches. + hidden_act (`str`, *optional*, defaults to `"gelu"`): + Activation function used in the hidden layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + Dropout probability for the attention layers. + rope_parameters (`RopeParameters`, *optional*): + The RopeParameters + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 
+ + Example: + + ```python + >>> from transformers import LightOnOCRVisionModel, LightOnOCRVisionConfig + + >>> # Initializing a LightOnOCR-12B style configuration + >>> config = LightOnOCRVisionConfig() + + >>> # Initializing a model (with randomly initialized weights) from the configuration + >>> model = LightOnOCRVisionModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "lightonocr_vision" + + def __init__( + self, + hidden_size: Optional[int] = 1024, + intermediate_size: Optional[int] = 4096, + num_hidden_layers: Optional[int] = 24, + num_attention_heads: Optional[int] = 16, + num_channels: Optional[int] = 3, + image_size: Optional[int] = 1024, + patch_size: Optional[int] = 16, + hidden_act: Optional[str] = "gelu", + attention_dropout: Optional[float] = 0.0, + rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None, + initializer_range: Optional[float] = 0.02, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.patch_size = patch_size + self.image_size = image_size + self.attention_dropout = attention_dropout + self.hidden_act = hidden_act + self.head_dim = hidden_size // num_attention_heads + self.initializer_range = initializer_range + # Try to set `rope_scaling` if available, otherwise use `rope_parameters` + rope_scaling = kwargs.pop("rope_scaling", None) + self.rope_parameters = rope_scaling or rope_parameters + + # Validate the correctness of rotary position embeddings parameters + rope_theta = kwargs.get("rope_theta", 10000.0) + standardize_rope_params(self, rope_theta=rope_theta) + rope_config_validation(self) + + +class LightOnOCRTextConfig(PreTrainedConfig): + r""" + This is the configuration class to store the configuration of a [`LightOnOCRTextModel`]. It is used to instantiate a + LightOnOCRText model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of + LightOnOCRText-8B [Qwen/LightOnOCRText-8B](https://huggingface.co/Qwen/LightOnOCRText-8B). + + Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PreTrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 151936): + Vocabulary size of the LightOnOCRText model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`LightOnOCRTextModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 22016): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 32): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. 
When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details, check out [this + paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `32`. + head_dim (`int`, *optional*, defaults to 128): + The attention head dimension. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 32768): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_parameters (`RopeParameters`, *optional*): + Dictionary containing the configuration parameters for the RoPE embeddings. The dictionaty should contain + a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE + with longer `max_position_embeddings`. + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + use_sliding_window (`bool`, *optional*, defaults to `False`): + Whether to use sliding window attention. + sliding_window (`int`, *optional*, defaults to 4096): + Sliding window attention (SWA) window size. If not specified, will default to `4096`. + max_window_layers (`int`, *optional*, defaults to 28): + The number of layers using full attention. The first `max_window_layers` layers will use full attention, while any + additional layer afterwards will use SWA (Sliding Window Attention). + layer_types (`list`, *optional*): + Attention pattern for each layer. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. 
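+
+    Example: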
+ + ```python + >>> from transformers import LightOnOCRTextModel, LightOnOCRTextConfig + + >>> # Initializing a LightOnOCRText style configuration + >>> configuration = LightOnOCRTextConfig() + + >>> # Initializing a model from the LightOnOCRText-8B style configuration + >>> model = LightOnOCRTextModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "lightonocr_text" + keys_to_ignore_at_inference = ["past_key_values"] + + # Default tensor parallel plan for base model `LightOnOCRText` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size: Optional[int] = 151936, + hidden_size: Optional[int] = 4096, + intermediate_size: Optional[int] = 22016, + num_hidden_layers: Optional[int] = 32, + num_attention_heads: Optional[int] = 32, + num_key_value_heads: Optional[int] = 32, + head_dim: Optional[int] = 128, + hidden_act: Optional[str] = "silu", + max_position_embeddings: Optional[int] = 32768, + initializer_range: Optional[float] = 0.02, + rms_norm_eps: Optional[int] = 1e-6, + use_cache: Optional[bool] = True, + tie_word_embeddings: Optional[bool] = False, + rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None, + attention_bias: Optional[bool] = False, + use_sliding_window: Optional[bool] = False, + sliding_window: Optional[int] = 4096, + max_window_layers: Optional[int] = 28, + layer_types: Optional[list[str]] = None, + attention_dropout: Optional[float] = 0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.use_sliding_window = use_sliding_window + self.sliding_window = sliding_window if self.use_sliding_window else None + self.max_window_layers = max_window_layers + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.head_dim = head_dim + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + # Try to set `rope_scaling` if available, otherwise use `rope_parameters` + rope_scaling = kwargs.pop("rope_scaling", None) + self.rope_parameters = rope_scaling or rope_parameters + + self.layer_types = layer_types + if self.layer_types is None: + self.layer_types = [ + "sliding_attention" + if self.sliding_window is not None and i >= self.max_window_layers + else "full_attention" + for i in range(self.num_hidden_layers) + ] + layer_type_validation(self.layer_types, self.num_hidden_layers) + + # Validate the correctness of rotary position embeddings parameters + rope_theta = kwargs.get("rope_theta", 10000.0) + standardize_rope_params(self, rope_theta=rope_theta) + rope_config_validation(self) + + 
super().__init__( + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +class LightOnOCRConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`LightOnOCRForConditionalGeneration`]. It is used to instantiate a + LightOnOCR model according to the specified arguments, defining the model architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. Instantiating a configuration with the defaults will yield + a similar configuration to that of the LightOnOCR [lightonocr-hf/lightonocr-9b](https://huggingface.co/lightonocr-hf/lightonocr-9b) architecture. + + Args: + spatial_merge_size (`int`, *optional*, defaults to 2): + The size of spatial merging for image patches. + image_token_id (`int`, *optional*, defaults to 151655): + The id of the image token in the vocabulary. + vision_config (`dict` or `LightOnOCRVisionConfig`, *optional*): + Custom vision configuration or dictionary with vision configuration values. + text_config (`dict` or `LightOnOCRTextConfig`, *optional*): + Custom text configuration or dictionary with text configuration values. + + Example: + + ```python + >>> from transformers import LightOnOCRConfig, LightOnOCRForConditionalGeneration + + >>> # Initializing a LightOnOCR configuration + >>> configuration = LightOnOCRConfig() + + >>> # Initializing a model from the configuration + >>> model = LightOnOCRForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "lightonocr" + sub_configs = {"text_config": LightOnOCRTextConfig, "vision_config": LightOnOCRVisionConfig} + + def __init__( + self, + spatial_merge_size: int = 2, + image_token_id: int = 151655, + vision_config: Optional[dict[str, Any]] = None, + text_config: Optional[dict[str, Any]] = None, + **kwargs, + ): + self.spatial_merge_size = spatial_merge_size + self.image_token_id = image_token_id + + if vision_config is None: + self.vision_config = LightOnOCRVisionConfig( + attention_dropout=0, + head_dim=64, + hidden_act="silu", + hidden_size=1024, + image_size=1540, + initializer_range=0.02, + intermediate_size=4096, + model_type="pixtral", + num_attention_heads=16, + num_channels=3, + num_hidden_layers=24, + patch_size=14, + rope_theta=10000, + ) + elif isinstance(vision_config, PretrainedConfig): + self.vision_config = vision_config + else: + self.vision_config = LightOnOCRVisionConfig(**vision_config) + + if text_config is None: + self.text_config = LightOnOCRTextConfig( + attention_dropout=0, + head_dim=128, + hidden_act="silu", + hidden_size=1024, + initializer_range=0.02, + intermediate_size=3072, + max_position_embeddings=40960, + model_type="qwen3", + num_attention_heads=16, + num_hidden_layers=28, + num_key_value_heads=8, + rms_norm_eps=1e-6, + rope_theta=1000000, + sliding_window=None, + use_cache=True, + use_sliding_window=False, + vocab_size=151936, + ) + elif isinstance(text_config, PretrainedConfig): + self.text_config = text_config + else: + self.text_config = LightOnOCRTextConfig(**text_config) + + super().__init__(**kwargs) + + +__all__ = ["LightOnOCRConfig", "LightOnOCRTextConfig", "LightOnOCRVisionConfig"] diff --git a/src/transformers/models/lightonocr/modeling_lightonocr.py b/src/transformers/models/lightonocr/modeling_lightonocr.py new file mode 100644 index 000000000000..2a361d4ea0de --- /dev/null +++ 
b/src/transformers/models/lightonocr/modeling_lightonocr.py @@ -0,0 +1,1278 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/lightonocr/modular_lightonocr.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_lightonocr.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +from collections.abc import Callable +from typing import Optional, Union + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub +from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import auto_docstring, can_return_tuple, is_torch_available +from ...utils.generic import TransformersKwargs, check_model_inputs +from .configuration_lightonocr import LightOnOCRConfig, LightOnOCRTextConfig, LightOnOCRVisionConfig + + +if is_torch_available(): + import torch + from torch import nn + + +@use_kernel_forward_from_hub("RMSNorm") +class LightOnOCRTextRMSNorm(nn.Module): + def __init__(self, hidden_size, eps: float = 1e-6) -> None: + """ + LightOnOCRTextRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class LightOnOCRPatchMerger(nn.Module): + """ + Learned merging of spatial_merge_size ** 2 patches. 
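+
+    For illustration (arbitrary shapes, assuming `spatial_merge_size=2` and `hidden_size=1024`), a 28 x 42
+    patch grid is folded into a 14 x 21 grid of tokens of dimension `4 * hidden_size`, which `merging_layer`
+    then projects back to `hidden_size`:
+
+    ```python
+    import torch
+
+    hidden_size, sms = 1024, 2
+    grid = torch.randn(1, hidden_size, 28, 42)  # [1, hidden_size, patch_height, patch_width]
+    merged = torch.nn.functional.unfold(grid, kernel_size=sms, stride=sms)  # [1, hidden_size * sms**2, 14 * 21]
+    merged = merged.view(hidden_size * sms**2, -1).t()  # [14 * 21, hidden_size * sms**2]
+    ```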
+ """ + + def __init__(self, config: LightOnOCRConfig): + super().__init__() + self.config = config + self.hidden_size = config.vision_config.hidden_size + self.spatial_merge_size = config.spatial_merge_size + + self.patch_size = config.vision_config.patch_size + + self.merging_layer = nn.Linear(self.hidden_size * self.spatial_merge_size**2, self.hidden_size, bias=False) + + def forward(self, image_features: torch.Tensor, image_sizes: Union[torch.Tensor, list]) -> torch.Tensor: + image_sizes_in_patches = [ + (image_size[0] // self.patch_size, image_size[1] // self.patch_size) for image_size in image_sizes + ] + + tokens_per_image = [patch_height * patch_width for patch_height, patch_width in image_sizes_in_patches] + hidden_dim = image_features.shape[-1] + + permuted_tensor = [] + for image_index, image_tokens in enumerate(image_features.split(tokens_per_image)): + # reshape image_tokens into a 2D grid + patch_height, patch_width = image_sizes_in_patches[image_index] + # shape [num_patches, hidden_dim] -> [1, hidden_dim, patch_height, patch_width] + image_grid = image_tokens.view(patch_height, patch_width, hidden_dim).permute(2, 0, 1).unsqueeze(0) + # shape [1, hidden_dim, patch_height, patch_width] -> [patch_height // sms * patch_width // sms, hidden_dim * sms**2] + # sms = spatial_merge_size + # Note: patch_height and patch_width are guaranteed to be divisible by sms because the image processor + # resizes images to multiples of effective_patch_size (patch_size * spatial_merge_size) + grid = torch.nn.functional.unfold( + image_grid, + kernel_size=self.spatial_merge_size, + stride=self.spatial_merge_size, + ) + # shape [patch_height // sms * patch_width // sms, hidden_dim * sms**2] -> [patch_height // sms * patch_width // sms, hidden_dim * sms**2] + grid = grid.view(hidden_dim * self.spatial_merge_size**2, -1).t() + permuted_tensor.append(grid) + + image_features = torch.cat(permuted_tensor, dim=0) + image_features = self.merging_layer(image_features) + return image_features + + +class LightOnOCRVisionProjector(nn.Module): + def __init__(self, config: LightOnOCRConfig): + super().__init__() + self.config = config + + self.norm = LightOnOCRTextRMSNorm(config.vision_config.hidden_size, eps=1e-6) + self.patch_merger = LightOnOCRPatchMerger(config) + self.act = nn.GELU() + self.linear_1 = nn.Linear( + config.vision_config.hidden_size, + config.text_config.hidden_size, + bias=False, + ) + self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=False) + + def forward(self, image_features: torch.Tensor, image_sizes: Union[torch.Tensor, list]): + image_features = self.norm(image_features) + image_features = self.patch_merger(image_features, image_sizes) + hidden_states = self.linear_1(image_features) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +class LightOnOCRPreTrainedModel(PreTrainedModel): + config_class = LightOnOCRConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["LightOnOCRVisionProjector", "LightOnOCRPatchMerger"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _can_compile_fullgraph = True + _supports_flex_attn = True + _supports_attention_backend = True + + +# Vision model components - explicitly renamed from Pixtral +class LightOnOCRVisionPreTrainedModel(PreTrainedModel): + config_class = LightOnOCRVisionConfig + base_model_prefix = "model" + main_input_name = 
"pixel_values" + supports_gradient_checkpointing = True + _supports_attention_backend = True + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + _no_split_modules = ["LightOnOCRVisionAttentionLayer"] + + +# Copied from transformers.models.siglip.modeling_siglip.eager_attention_forward +def vision_eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +def vision_rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def vision_apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + + cos = cos.unsqueeze(unsqueeze_dim).to(q.device) + sin = sin.unsqueeze(unsqueeze_dim).to(q.device) + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (vision_rotate_half(q) * sin) + k_embed = (k * cos) + (vision_rotate_half(k) * sin) + return q_embed, k_embed + + +class LightOnOCRAttention(nn.Module): + """ + Multi-headed attention compatible with ALL_ATTENTION_FUNCTIONS. 
+ """ + + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + self.is_causal = False + + self.scaling = self.head_dim**-0.5 + self.is_causal = False + + self.dropout = config.attention_dropout + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) + self.o_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + """Input shape: Batch x Time x Channel""" + + batch_size, patches, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = vision_apply_rotary_pos_emb(query_states, key_states, cos, sin, unsqueeze_dim=0) + + attention_interface: Callable = vision_eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + # Since we use packing, if flash_attention_2 is selected we rely on position_ids + if self.config._attn_implementation == "flash_attention_2": + kwargs["position_ids"] = kwargs["position_ids"].to(hidden_states.device, non_blocking=True) + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(batch_size, patches, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + return attn_output, attn_weights + + +class LightOnOCRRotaryEmbedding(nn.Module): + """ + The key with lightonocr embedding is just that you have a frequency for each pixel positions. + If you have height x width pixels (or embedding pixels), then the frequency used for ROPE + is given by indexing the pre_computed frequency on the width and height. + + What you output is of dimension (batch, height * width, dim) with dim the embed dim. + + This simply means that for each image hidden state, you are going to add + a corresponding positional embedding, based on its index in the grid. 
+ """ + + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: LightOnOCRVisionConfig, device=None, layer_type=None): + super().__init__() + + self.config = config + + self.rope_type = self.config.rope_parameters["rope_type"] + rope_init_fn: Callable = self.compute_default_rope_parameters + if self.rope_type != "default": + raise ValueError( + f"{self.__class__.__name__} does not support non-default RoPE, but got `rope_type={self.rope_type}`" + ) + + inv_freq, attention_scaling = rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = inv_freq + + @staticmethod + def compute_default_rope_parameters( + config: Optional[LightOnOCRVisionConfig] = None, + device: Optional["torch.device"] = None, + seq_len: Optional[int] = None, + ) -> tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies according to the original RoPE implementation + Args: + config ([`~transformers.PreTrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + base = config.rope_parameters["rope_theta"] + dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + + attention_factor = 1.0 # Unused in this type of RoPE + + # Here is the diff from Llama RoPE + max_patches_per_side = config.image_size // config.patch_size + h = torch.arange(max_patches_per_side) + w = torch.arange(max_patches_per_side) + + freqs = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) + freqs_h = torch.outer(h, freqs[::2]).float() + freqs_w = torch.outer(w, freqs[1::2]).float() + inv_freq = torch.cat( + [ + freqs_h[:, None, :].repeat(1, max_patches_per_side, 1), + freqs_w[None, :, :].repeat(max_patches_per_side, 1, 1), + ], + dim=-1, + ).reshape(-1, dim // 2) # we reshape to only index on the position indexes, not tuple of indexes + # Different from paper, but it uses a different permutation in order to obtain the same calculation + + # TODO maybe make it torch compatible later on. We can also just slice + inv_freq = torch.cat((inv_freq, inv_freq), dim=-1) + return inv_freq, attention_factor + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. 
dynamic rope) + def forward(self, x, position_ids): + freqs = self.inv_freq[position_ids] + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + emb = freqs + cos = emb.cos() + sin = emb.sin() + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class LightOnOCRMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class LightOnOCRRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LightOnOCRRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class LightOnOCRAttentionLayer(GradientCheckpointingLayer): + def __init__(self, config): + super().__init__() + self.attention_norm = LightOnOCRRMSNorm(config.hidden_size, eps=1e-5) + self.feed_forward = LightOnOCRMLP(config) + self.attention = LightOnOCRAttention(config) + self.ffn_norm = LightOnOCRRMSNorm(config.hidden_size, eps=1e-5) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + output_attentions: Optional[bool] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.FloatTensor]: + """ + Args: + hidden_states (`torch.FloatTensor`): + Input to the layer of shape `(batch, seq_len, embed_dim)`. + attention_mask (`torch.FloatTensor`): + Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
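+            position_embeddings (`tuple[torch.Tensor, torch.Tensor]`, *optional*):
+                Precomputed cosine and sine rotary position embeddings for the packed patch sequence.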
+ """ + residual = hidden_states + + hidden_states = self.attention_norm(hidden_states) + hidden_states, attn_weights = self.attention( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_embeddings=position_embeddings, + output_attentions=output_attentions, + **kwargs, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.ffn_norm(hidden_states) + hidden_states = self.feed_forward(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + return outputs + + +class LightOnOCRTransformer(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layers = torch.nn.ModuleList() + for _ in range(config.num_hidden_layers): + self.layers.append(LightOnOCRAttentionLayer(config)) + self.gradient_checkpointing = False + + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[tuple, BaseModelOutput]: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Embeddings which serve as input to the Transformer. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for encoder_layer in self.layers: + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + position_embeddings=position_embeddings, + output_attentions=output_attentions, + **kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +def position_ids_in_meshgrid(patch_embeds_list, max_width): + positions = [] + for patch in patch_embeds_list: + height, width = patch.shape[-2:] + mesh = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij") + h_grid, v_grid = torch.stack(mesh, dim=-1).reshape(-1, 2).chunk(2, -1) + ids = h_grid * max_width + v_grid + positions.append(ids[:, 0]) + return torch.cat(positions) + + +def generate_block_attention_mask(patch_embeds_list, tensor): + dtype = tensor.dtype + device = tensor.device + seq_len = tensor.shape[1] + d_min = torch.finfo(dtype).min + causal_mask = torch.full((seq_len, seq_len), fill_value=d_min, dtype=dtype, device=device) + + block_end_idx = torch.tensor(patch_embeds_list).cumsum(-1) + block_start_idx = torch.tensor([0] + patch_embeds_list[:-1]).cumsum(-1) + for start, end in zip(block_start_idx, block_end_idx): + causal_mask[start:end, start:end] = 0 + + causal_mask = causal_mask[None, None, :, :].expand(tensor.shape[0], 1, -1, -1) + return causal_mask + + +@auto_docstring( + custom_intro=""" + The vision encoder of LightOnOCR, based on Pixtral vision architecture. 
+ """ +) +class LightOnOCRVisionModel(LightOnOCRPreTrainedModel): + base_model_prefix = "vision_encoder" + config_class = LightOnOCRVisionConfig + + def __init__(self, config): + super().__init__(config) + self.config = config + self.patch_conv = nn.Conv2d( + in_channels=config.num_channels, + out_channels=config.hidden_size, + kernel_size=config.patch_size, + stride=config.patch_size, + bias=False, + ) + self.patch_size = config.patch_size + self.ln_pre = LightOnOCRRMSNorm(config.hidden_size, eps=1e-5) + self.transformer = LightOnOCRTransformer(config) + self.patch_positional_embedding = LightOnOCRRotaryEmbedding(config) + + self.post_init() + + def get_input_embeddings(self): + return self.patch_conv + + @can_return_tuple + @auto_docstring + def forward( + self, + pixel_values: torch.Tensor, + image_sizes: Optional[torch.Tensor] = None, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + *args, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[tuple, BaseModelOutput]: + if image_sizes is None: + batch_size, _, height, width = pixel_values.shape + image_sizes = [(height, width)] * batch_size + + # pass images through initial convolution independently + patch_embeds = self.patch_conv(pixel_values) + patch_embeds_list = [ + embed[..., : (size[0] // self.patch_size), : (size[1] // self.patch_size)] + for embed, size in zip(patch_embeds, image_sizes) + ] + + # flatten to a single sequence + patch_embeds = torch.cat([p.flatten(1).T for p in patch_embeds_list], dim=0).unsqueeze(0) + patch_embeds = self.ln_pre(patch_embeds) + + # positional embeddings + position_ids = position_ids_in_meshgrid( + patch_embeds_list, max_width=self.config.image_size // self.config.patch_size + ) + kwargs["position_ids"] = position_ids + + position_embeddings = self.patch_positional_embedding(patch_embeds, position_ids) + + if self.config._attn_implementation == "flash_attention_2": + # We only rely on position_ids when using flash_attention_2 + attention_mask = None + else: + attention_mask = generate_block_attention_mask( + [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], patch_embeds + ) + + return self.transformer( + patch_embeds, + attention_mask=attention_mask, + position_embeddings=position_embeddings, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + return_dict=True, + **kwargs, + ) + + +class LightOnOCRTextMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class LightOnOCRTextRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: LightOnOCRTextConfig, device=None): + super().__init__() + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + + self.rope_type = self.config.rope_parameters["rope_type"] + rope_init_fn: Callable = self.compute_default_rope_parameters + if self.rope_type != 
"default": + rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + inv_freq, self.attention_scaling = rope_init_fn(self.config, device) + + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = inv_freq + + @staticmethod + def compute_default_rope_parameters( + config: Optional[LightOnOCRTextConfig] = None, + device: Optional["torch.device"] = None, + seq_len: Optional[int] = None, + ) -> tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies according to the original RoPE implementation + Args: + config ([`~transformers.PreTrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + base = config.rope_parameters["rope_theta"] + dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + + attention_factor = 1.0 # Unused in this type of RoPE + + # Compute the inverse frequencies + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) + ) + return inv_freq, attention_factor + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. 
+ Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class LightOnOCRTextAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: LightOnOCRTextConfig, layer_idx: int): + super().__init__() + self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + self.q_norm = LightOnOCRTextRMSNorm( + self.head_dim, eps=config.rms_norm_eps + ) # unlike olmo, only on the head dim! 
+ self.k_norm = LightOnOCRTextRMSNorm( + self.head_dim, eps=config.rms_norm_eps + ) # thus post q_norm does not need reshape + self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, # diff with Llama + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class LightOnOCRTextDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: LightOnOCRTextConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = LightOnOCRTextAttention(config=config, layer_idx=layer_idx) + + self.mlp = LightOnOCRTextMLP(config) + self.input_layernorm = LightOnOCRTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LightOnOCRTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.attention_type = config.layer_types[layer_idx] + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + 
hidden_states = residual + hidden_states + return hidden_states + + +@auto_docstring +class LightOnOCRTextPreTrainedModel(PreTrainedModel): + config: LightOnOCRTextConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["LightOnOCRTextDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + + _can_compile_fullgraph = True + _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": LightOnOCRTextDecoderLayer, + "attentions": LightOnOCRTextAttention, + } + + +@auto_docstring( + custom_intro=""" + The language model of LightOnOCR, based on Qwen3 architecture. + """ +) +class LightOnOCRTextModel(LightOnOCRTextPreTrainedModel): + config_class = LightOnOCRTextConfig + + def __init__(self, config: LightOnOCRTextConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [LightOnOCRTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = LightOnOCRTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = LightOnOCRTextRotaryEmbedding(config=config) + self.gradient_checkpointing = False + self.has_sliding_layers = "sliding_attention" in self.config.layer_types + + # Initialize weights and apply final processing + self.post_init() + + @check_model_inputs() + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + # It may already have been prepared by e.g. 
`generate` + if not isinstance(causal_mask_mapping := attention_mask, dict): + # Prepare mask arguments + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + "position_ids": position_ids, + } + # Create the masks + causal_mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + } + # The sliding window alternating layers are not always activated depending on the config + if self.has_sliding_layers: + causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], + position_embeddings=position_embeddings, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + ) + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + +class LightOnOCRModel(LightOnOCRPreTrainedModel): + base_model_prefix = "model" + _checkpoint_conversion_mapping = {} + # Reference: fix gemma3 grad acc #37208 + accepts_loss_kwargs = False + config: LightOnOCRConfig + + def __init__(self, config: LightOnOCRConfig): + super().__init__(config) + + self.vision_encoder = LightOnOCRVisionModel._from_config(config.vision_config) + + self.vision_projection = LightOnOCRVisionProjector(config) + + self.language_model = LightOnOCRTextModel._from_config(config.text_config) + + self.post_init() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + def get_image_features(self, pixel_values: torch.Tensor, image_sizes: Union[torch.Tensor, list]): + """ + Obtains image features from the vision encoder and projection. 
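+ The encoder output is RMS-normalized, spatially merged (`spatial_merge_size**2` patches per token) and projected to the text hidden size before being split into per-image feature tensors.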
+ + Args: + pixel_values: Image tensors + image_sizes: Tensor or list of (height, width) pairs for each image + + Returns: + List of image feature tensors, one per image + """ + visual_features = self.vision_encoder(pixel_values, image_sizes=image_sizes).last_hidden_state + + image_features = self.vision_projection(visual_features.squeeze(0), image_sizes) + + # Split features per image based on the effective patch size + downsample_ratio = self.config.vision_config.patch_size * self.config.spatial_merge_size + split_sizes = [(height // downsample_ratio) * (width // downsample_ratio) for height, width in image_sizes] + image_features = torch.split(image_features, split_sizes) + + return image_features + + def set_decoder(self, decoder): + self.language_model = decoder + + def get_decoder(self): + return self.language_model + + @property + def vision_model(self): + """Alias for vision_encoder to match standard composite model naming.""" + return self.vision_encoder + + def get_placeholder_mask( + self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor + ): + """ + Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + n_image_features = image_features.shape[0] + if inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + return special_image_mask + + @check_model_inputs() + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + pixel_values: Optional[torch.Tensor] = None, + image_sizes: Optional[Union[torch.Tensor, list]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[torch.Tensor] = None, + cache_position: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + if inputs_embeds is None: + if input_ids is None: + raise ValueError("Either input_ids or inputs_embeds must be provided") + inputs_embeds = self.language_model.get_input_embeddings()(input_ids) + + if pixel_values is not None: + # Note: image_sizes is automatically expanded by the generation framework during beam search + image_features_list = self.get_image_features(pixel_values, image_sizes) + image_features = torch.cat(image_features_list, dim=0) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + image_mask = self.get_placeholder_mask(input_ids, inputs_embeds, image_features) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_features) + outputs = self.language_model( + input_ids=None, + inputs_embeds=inputs_embeds, + position_ids=position_ids, + past_key_values=past_key_values, + cache_position=cache_position, + use_cache=use_cache, + **kwargs, + ) + + return outputs + + +class 
LightOnOCRForConditionalGeneration(LightOnOCRPreTrainedModel, GenerationMixin): + _checkpoint_conversion_mapping = {} + config_class = LightOnOCRConfig + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} + + def __init__(self, config: LightOnOCRConfig): + super().__init__(config) + self.model = LightOnOCRModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_decoder(self): + return self.model.language_model + + @check_model_inputs() + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + pixel_values: Optional[torch.Tensor] = None, + image_sizes: Optional[Union[torch.Tensor, list]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[torch.Tensor] = None, + cache_position: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + labels: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> CausalLMOutputWithPast: + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + image_sizes=image_sizes, + inputs_embeds=inputs_embeds, + position_ids=position_ids, + past_key_values=past_key_values, + cache_position=cache_position, + use_cache=use_cache, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state if hasattr(outputs, "last_hidden_state") else outputs[0] + logits: torch.Tensor = self.lm_head(hidden_states) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.config.text_config.vocab_size) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + inputs_embeds=None, + pixel_values=None, + attention_mask=None, + cache_position=None, + **kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + **kwargs, + ) + + if cache_position[0] == 0: + # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore + # Otherwise we need pixel values to be passed to model + model_inputs["pixel_values"] = pixel_values + model_inputs["image_sizes"] = kwargs.get("image_sizes") + + return model_inputs + + @property + def language_model(self): + return self.model.language_model + + @property + def vision_encoder(self): + return self.model.vision_encoder + + @property + def vision_model(self): + """Alias for vision_encoder to match standard composite model naming.""" + return self.model.vision_encoder + + +__all__ = [ + "LightOnOCRPreTrainedModel", + "LightOnOCRVisionModel", + "LightOnOCRVisionPreTrainedModel", + "LightOnOCRTextModel", + "LightOnOCRTextPreTrainedModel", + "LightOnOCRForConditionalGeneration", + "LightOnOCRModel", +] diff --git 
a/src/transformers/models/lightonocr/modular_lightonocr.py b/src/transformers/models/lightonocr/modular_lightonocr.py new file mode 100644 index 000000000000..c34f8e8daf8f --- /dev/null +++ b/src/transformers/models/lightonocr/modular_lightonocr.py @@ -0,0 +1,793 @@ +from collections.abc import Callable +from typing import Any, Optional, Union + +import numpy as np + +from ...configuration_utils import PretrainedConfig +from ...feature_extraction_utils import BatchFeature +from ...generation import GenerationMixin +from ...image_utils import ImageInput +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import ( + MultiModalData, + ProcessingKwargs, + ProcessorMixin, + Unpack, +) +from ...tokenization_utils_base import PreTokenizedInput, TextInput +from ...utils import auto_docstring, is_torch_available, is_vision_available +from ...utils.generic import TransformersKwargs, check_model_inputs +from ..pixtral.configuration_pixtral import PixtralVisionConfig +from ..pixtral.image_processing_pixtral import get_resize_output_image_size +from ..pixtral.modeling_pixtral import ( + PixtralAttention, + PixtralVisionModel, +) +from ..qwen3.configuration_qwen3 import Qwen3Config +from ..qwen3.modeling_qwen3 import ( + Qwen3Model, + Qwen3RMSNorm, +) + + +if is_torch_available(): + import torch + from torch import nn + +if is_vision_available(): + from ..pixtral.image_processing_pixtral import get_resize_output_image_size + + +class LightOnOCRVisionConfig(PixtralVisionConfig): + model_type = "lightonocr_vision" + pass + + +class LightOnOCRTextConfig(Qwen3Config): + model_type = "lightonocr_text" + + +class LightOnOCRConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`LightOnOCRForConditionalGeneration`]. It is used to instantiate a + LightOnOCR model according to the specified arguments, defining the model architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. Instantiating a configuration with the defaults will yield + a similar configuration to that of the LightOnOCR [lightonocr-hf/lightonocr-9b](https://huggingface.co/lightonocr-hf/lightonocr-9b) architecture. + + Args: + spatial_merge_size (`int`, *optional*, defaults to 2): + The size of spatial merging for image patches. + image_token_id (`int`, *optional*, defaults to 151655): + The id of the image token in the vocabulary. + vision_config (`dict` or `LightOnOCRVisionConfig`, *optional*): + Custom vision configuration or dictionary with vision configuration values. + text_config (`dict` or `LightOnOCRTextConfig`, *optional*): + Custom text configuration or dictionary with text configuration values. 
+ + Example: + + ```python + >>> from transformers import LightOnOCRConfig, LightOnOCRForConditionalGeneration + + >>> # Initializing a LightOnOCR configuration + >>> configuration = LightOnOCRConfig() + + >>> # Initializing a model from the configuration + >>> model = LightOnOCRForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "lightonocr" + sub_configs = {"text_config": LightOnOCRTextConfig, "vision_config": LightOnOCRVisionConfig} + + def __init__( + self, + spatial_merge_size: int = 2, + image_token_id: int = 151655, + vision_config: Optional[dict[str, Any]] = None, + text_config: Optional[dict[str, Any]] = None, + **kwargs, + ): + self.spatial_merge_size = spatial_merge_size + self.image_token_id = image_token_id + + if vision_config is None: + self.vision_config = LightOnOCRVisionConfig( + attention_dropout=0, + head_dim=64, + hidden_act="silu", + hidden_size=1024, + image_size=1540, + initializer_range=0.02, + intermediate_size=4096, + model_type="pixtral", + num_attention_heads=16, + num_channels=3, + num_hidden_layers=24, + patch_size=14, + rope_theta=10000, + ) + elif isinstance(vision_config, PretrainedConfig): + self.vision_config = vision_config + else: + self.vision_config = LightOnOCRVisionConfig(**vision_config) + + if text_config is None: + self.text_config = LightOnOCRTextConfig( + attention_dropout=0, + head_dim=128, + hidden_act="silu", + hidden_size=1024, + initializer_range=0.02, + intermediate_size=3072, + max_position_embeddings=40960, + model_type="qwen3", + num_attention_heads=16, + num_hidden_layers=28, + num_key_value_heads=8, + rms_norm_eps=1e-6, + rope_theta=1000000, + sliding_window=None, + use_cache=True, + use_sliding_window=False, + vocab_size=151936, + ) + elif isinstance(text_config, PretrainedConfig): + self.text_config = text_config + else: + self.text_config = LightOnOCRTextConfig(**text_config) + + super().__init__(**kwargs) + + +class LightOnOCRProcessorKwargs(ProcessingKwargs, total=False): + _defaults = { + "text_kwargs": { + "padding": False, + "return_mm_token_type_ids": False, + }, + "common_kwargs": { + "return_tensors": "pt", + }, + } + + +class LightOnOCRProcessor(ProcessorMixin): + attributes = ["image_processor", "tokenizer"] + image_processor_class = "AutoImageProcessor" + tokenizer_class = "AutoTokenizer" + + def __init__( + self, + image_processor=None, + tokenizer=None, + patch_size: int = 14, + spatial_merge_size: int = 2, + chat_template=None, + **kwargs, + ): + self.patch_size = patch_size + self.spatial_merge_size = spatial_merge_size + # Calculate effective patch size for image processing + self.effective_patch_size = patch_size * spatial_merge_size + + # Get special tokens from tokenizer attributes + # These should be set on the tokenizer before creating the processor + self.image_token = getattr(tokenizer, "image_token", "<|image_pad|>") + self.image_break_token = getattr(tokenizer, "image_break_token", "<|vision_pad|>") + self.image_end_token = getattr(tokenizer, "image_end_token", "<|vision_end|>") + + # Get token IDs from tokenizer special attributes or convert from token strings + if hasattr(tokenizer, "image_token_id"): + self.image_token_id = tokenizer.image_token_id + else: + self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token) + + if hasattr(tokenizer, "image_break_token_id"): + self.image_break_token_id = tokenizer.image_break_token_id + else: + self.image_break_token_id = 
tokenizer.convert_tokens_to_ids(self.image_break_token) + + if hasattr(tokenizer, "image_end_token_id"): + self.image_end_token_id = tokenizer.image_end_token_id + else: + self.image_end_token_id = tokenizer.convert_tokens_to_ids(self.image_end_token) + + self.image_ids = [self.image_token_id, self.image_break_token_id, self.image_end_token_id] + + super().__init__(image_processor, tokenizer, chat_template=chat_template) + + def __call__( + self, + images: Optional[ImageInput] = None, + text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None, + **kwargs: Unpack[LightOnOCRProcessorKwargs], + ) -> BatchFeature: + if images is None and text is None: + raise ValueError("You must provide either text or images") + output_kwargs = self._merge_kwargs( + LightOnOCRProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + + if images is not None: + # Like pixtral + output_kwargs["images_kwargs"]["patch_size"] = self.effective_patch_size + image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"]) + else: + image_inputs = {} + + if isinstance(text, str): + text = [text] + elif not isinstance(text, list) and not isinstance(text[0], str): + raise TypeError("Invalid input text. Please provide a string, or a list of strings") + + if image_inputs.get("pixel_values") is not None: + image_sizes_iter = iter(image_inputs["image_sizes"]) + prompt_strings = [] + + for sample in text: + replace_strings = [] + + while self.image_token in sample: + image_height, image_width = next(image_sizes_iter) + num_height_tokens = image_height // self.effective_patch_size + num_width_tokens = image_width // self.effective_patch_size + num_patches = num_height_tokens * num_width_tokens + + replace_str = self.image_token * num_patches + replace_strings.append(replace_str) + + # temporarily swap the consumed image token for a placeholder so the expanded sequence can be inserted afterwards + sample = sample.replace(self.image_token, "<placeholder>", 1) + + while "<placeholder>" in sample: + replace_str = replace_strings.pop(0) + sample = sample.replace("<placeholder>", replace_str, 1) + + prompt_strings.append(sample) + else: + prompt_strings = text + + return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) + return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False) + text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"], return_tensors=None) + self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image"]) + + if return_mm_token_type_ids: + array_ids = np.array(text_inputs["input_ids"]) + mm_token_type_ids = np.zeros_like(text_inputs["input_ids"]) + mm_token_type_ids[np.isin(array_ids, self.image_ids)] = 1 + text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist() + + return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors) + + def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs): + """ + Computes the number of placeholder tokens needed for multimodal inputs with the given sizes. + + Args: + image_sizes (`list[list[int]]`, *optional*): + The input sizes formatted as (height, width) per each image. + + Returns: + `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided + input modalities, along with other useful data.
+ """ + vision_data = {} + if image_sizes is not None: + images_kwargs = LightOnOCRProcessorKwargs._defaults.get("images_kwargs", {}) + images_kwargs.update(kwargs) + + size = images_kwargs.get("size", None) or self.image_processor.size + + num_image_tokens = [] + for height, width in image_sizes: + resized_height, resized_width = get_resize_output_image_size( + np.zeros((height, width, 3)), + size=(size["longest_edge"], size["longest_edge"]), + patch_size=(self.effective_patch_size, self.effective_patch_size), + ) + num_height_tokens = resized_height // self.effective_patch_size + num_width_tokens = resized_width // self.effective_patch_size + num_image_tokens.append(num_width_tokens * num_height_tokens) + + num_image_patches = [1] * len(image_sizes) + vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches}) + + return MultiModalData(**vision_data) + + +# Text model RMSNorm defined early for use in MultiModalProjector +class LightOnOCRTextRMSNorm(Qwen3RMSNorm): + pass + + +class LightOnOCRPatchMerger(nn.Module): + """ + Learned merging of spatial_merge_size ** 2 patches. + """ + + def __init__(self, config: LightOnOCRConfig): + super().__init__() + self.config = config + self.hidden_size = config.vision_config.hidden_size + self.spatial_merge_size = config.spatial_merge_size + + self.patch_size = config.vision_config.patch_size + + self.merging_layer = nn.Linear(self.hidden_size * self.spatial_merge_size**2, self.hidden_size, bias=False) + + def forward(self, image_features: torch.Tensor, image_sizes: Union[torch.Tensor, list]) -> torch.Tensor: + image_sizes_in_patches = [ + (image_size[0] // self.patch_size, image_size[1] // self.patch_size) for image_size in image_sizes + ] + + tokens_per_image = [patch_height * patch_width for patch_height, patch_width in image_sizes_in_patches] + hidden_dim = image_features.shape[-1] + + permuted_tensor = [] + for image_index, image_tokens in enumerate(image_features.split(tokens_per_image)): + # reshape image_tokens into a 2D grid + patch_height, patch_width = image_sizes_in_patches[image_index] + # shape [num_patches, hidden_dim] -> [1, hidden_dim, patch_height, patch_width] + image_grid = image_tokens.view(patch_height, patch_width, hidden_dim).permute(2, 0, 1).unsqueeze(0) + # shape [1, hidden_dim, patch_height, patch_width] -> [patch_height // sms * patch_width // sms, hidden_dim * sms**2] + # sms = spatial_merge_size + # Note: patch_height and patch_width are guaranteed to be divisible by sms because the image processor + # resizes images to multiples of effective_patch_size (patch_size * spatial_merge_size) + grid = torch.nn.functional.unfold( + image_grid, + kernel_size=self.spatial_merge_size, + stride=self.spatial_merge_size, + ) + # shape [patch_height // sms * patch_width // sms, hidden_dim * sms**2] -> [patch_height // sms * patch_width // sms, hidden_dim * sms**2] + grid = grid.view(hidden_dim * self.spatial_merge_size**2, -1).t() + permuted_tensor.append(grid) + + image_features = torch.cat(permuted_tensor, dim=0) + image_features = self.merging_layer(image_features) + return image_features + + +class LightOnOCRVisionProjector(nn.Module): + def __init__(self, config: LightOnOCRConfig): + super().__init__() + self.config = config + + self.norm = LightOnOCRTextRMSNorm(config.vision_config.hidden_size, eps=1e-6) + self.patch_merger = LightOnOCRPatchMerger(config) + self.act = nn.GELU() + self.linear_1 = nn.Linear( + config.vision_config.hidden_size, + config.text_config.hidden_size, + bias=False, + 
) + self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=False) + + def forward(self, image_features: torch.Tensor, image_sizes: Union[torch.Tensor, list]): + image_features = self.norm(image_features) + image_features = self.patch_merger(image_features, image_sizes) + hidden_states = self.linear_1(image_features) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +class LightOnOCRPreTrainedModel(PreTrainedModel): + config_class = LightOnOCRConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["LightOnOCRVisionProjector", "LightOnOCRPatchMerger"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _can_compile_fullgraph = True + _supports_flex_attn = True + _supports_attention_backend = True + + +# Vision model components - explicitly renamed from Pixtral +class LightOnOCRVisionPreTrainedModel(PreTrainedModel): + config_class = LightOnOCRVisionConfig + base_model_prefix = "model" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + _supports_attention_backend = True + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + _no_split_modules = ["LightOnOCRVisionAttentionLayer"] + + +# Copied from transformers.models.siglip.modeling_siglip.eager_attention_forward +def vision_eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +def vision_rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def vision_apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. 
+ Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + + cos = cos.unsqueeze(unsqueeze_dim).to(q.device) + sin = sin.unsqueeze(unsqueeze_dim).to(q.device) + q_embed = (q * cos) + (vision_rotate_half(q) * sin) + k_embed = (k * cos) + (vision_rotate_half(k) * sin) + return q_embed, k_embed + + +class LightOnOCRAttention(PixtralAttention): + """ + Multi-headed attention compatible with ALL_ATTENTION_FUNCTIONS. + """ + + def __init__(self, config): + super().__init__(config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + """Input shape: Batch x Time x Channel""" + + batch_size, patches, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = vision_apply_rotary_pos_emb(query_states, key_states, cos, sin, unsqueeze_dim=0) + + attention_interface: Callable = vision_eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + # Since we use packing, if flash_attention_2 is selected we rely on position_ids + if self.config._attn_implementation == "flash_attention_2": + kwargs["position_ids"] = kwargs["position_ids"].to(hidden_states.device, non_blocking=True) + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(batch_size, patches, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + return attn_output, attn_weights + + +@auto_docstring( + custom_intro=""" + The vision encoder of LightOnOCR, based on Pixtral vision architecture. + """ +) +class LightOnOCRVisionModel(PixtralVisionModel): + config_class = LightOnOCRVisionConfig + + +@auto_docstring( + custom_intro=""" + The language model of LightOnOCR, based on Qwen3 architecture.
+ """ +) +class LightOnOCRTextModel(Qwen3Model): + config_class = LightOnOCRTextConfig + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + +class LightOnOCRModel(LightOnOCRPreTrainedModel): + base_model_prefix = "model" + _checkpoint_conversion_mapping = {} + # Reference: fix gemma3 grad acc #37208 + accepts_loss_kwargs = False + config: LightOnOCRConfig + + def __init__(self, config: LightOnOCRConfig): + super().__init__(config) + + self.vision_encoder = LightOnOCRVisionModel._from_config(config.vision_config) + + self.vision_projection = LightOnOCRVisionProjector(config) + + self.language_model = LightOnOCRTextModel._from_config(config.text_config) + + self.post_init() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + def get_image_features(self, pixel_values: torch.Tensor, image_sizes: Union[torch.Tensor, list]): + """ + Obtains image features from the vision encoder and projection. + + Args: + pixel_values: Image tensors + image_sizes: Tensor or list of (height, width) pairs for each image + + Returns: + List of image feature tensors, one per image + """ + visual_features = self.vision_encoder(pixel_values, image_sizes=image_sizes).last_hidden_state + + image_features = self.vision_projection(visual_features.squeeze(0), image_sizes) + + # Split features per image based on the effective patch size + downsample_ratio = self.config.vision_config.patch_size * self.config.spatial_merge_size + split_sizes = [(height // downsample_ratio) * (width // downsample_ratio) for height, width in image_sizes] + image_features = torch.split(image_features, split_sizes) + + return image_features + + def set_decoder(self, decoder): + self.language_model = decoder + + def get_decoder(self): + return self.language_model + + @property + def vision_model(self): + """Alias for vision_encoder to match standard composite model naming.""" + return self.vision_encoder + + def get_placeholder_mask( + self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor + ): + """ + Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. 
+ """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + n_image_features = image_features.shape[0] + if inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + return special_image_mask + + @check_model_inputs() + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + pixel_values: Optional[torch.Tensor] = None, + image_sizes: Optional[Union[torch.Tensor, list]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[torch.Tensor] = None, + cache_position: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + if inputs_embeds is None: + if input_ids is None: + raise ValueError("Either input_ids or inputs_embeds must be provided") + inputs_embeds = self.language_model.get_input_embeddings()(input_ids) + + if pixel_values is not None: + # Note: image_sizes is automatically expanded by the generation framework during beam search + image_features_list = self.get_image_features(pixel_values, image_sizes) + image_features = torch.cat(image_features_list, dim=0) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + image_mask = self.get_placeholder_mask(input_ids, inputs_embeds, image_features) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_features) + outputs = self.language_model( + input_ids=None, + inputs_embeds=inputs_embeds, + position_ids=position_ids, + past_key_values=past_key_values, + cache_position=cache_position, + use_cache=use_cache, + **kwargs, + ) + + return outputs + + +class LightOnOCRForConditionalGeneration(LightOnOCRPreTrainedModel, GenerationMixin): + _checkpoint_conversion_mapping = {} + config_class = LightOnOCRConfig + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} + + def __init__(self, config: LightOnOCRConfig): + super().__init__(config) + self.model = LightOnOCRModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_decoder(self): + return self.model.language_model + + @check_model_inputs() + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + pixel_values: Optional[torch.Tensor] = None, + image_sizes: Optional[Union[torch.Tensor, list]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[torch.Tensor] = None, + cache_position: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + labels: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> 
CausalLMOutputWithPast: + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + image_sizes=image_sizes, + inputs_embeds=inputs_embeds, + position_ids=position_ids, + past_key_values=past_key_values, + cache_position=cache_position, + use_cache=use_cache, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state if hasattr(outputs, "last_hidden_state") else outputs[0] + logits: torch.Tensor = self.lm_head(hidden_states) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.config.text_config.vocab_size) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + inputs_embeds=None, + pixel_values=None, + attention_mask=None, + cache_position=None, + **kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + **kwargs, + ) + + if cache_position[0] == 0: + # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore + # Otherwise we need pixel values to be passed to model + model_inputs["pixel_values"] = pixel_values + model_inputs["image_sizes"] = kwargs.get("image_sizes") + + return model_inputs + + @property + def language_model(self): + return self.model.language_model + + @property + def vision_encoder(self): + return self.model.vision_encoder + + @property + def vision_model(self): + """Alias for vision_encoder to match standard composite model naming.""" + return self.model.vision_encoder + + +__all__ = [ + "LightOnOCRPreTrainedModel", + "LightOnOCRVisionModel", + "LightOnOCRVisionPreTrainedModel", + "LightOnOCRTextModel", + "LightOnOCRTextPreTrainedModel", # noqa: F822 + "LightOnOCRForConditionalGeneration", + "LightOnOCRModel", + "LightOnOCRConfig", + "LightOnOCRTextConfig", + "LightOnOCRVisionConfig", + "LightOnOCRProcessor", +] diff --git a/src/transformers/models/lightonocr/processing_lightonocr.py b/src/transformers/models/lightonocr/processing_lightonocr.py new file mode 100644 index 000000000000..01abeccfd4a9 --- /dev/null +++ b/src/transformers/models/lightonocr/processing_lightonocr.py @@ -0,0 +1,236 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/lightonocr/modular_lightonocr.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_lightonocr.py file directly. One of our CI enforces this. 
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +import math +from typing import Optional, Union + +import numpy as np + +from ...feature_extraction_utils import BatchFeature +from ...image_utils import ChannelDimension, ImageInput, get_image_size +from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack +from ...tokenization_utils_base import PreTokenizedInput, TextInput + + +class LightOnOCRProcessorKwargs(ProcessingKwargs, total=False): + _defaults = { + "text_kwargs": { + "padding": False, + "return_mm_token_type_ids": False, + }, + "common_kwargs": { + "return_tensors": "pt", + }, + } + + +def _num_image_tokens(image_size: tuple[int, int], patch_size: tuple[int, int]) -> tuple[int, int]: + """ + Calculate the number of image tokens along the height and width dimensions given the image size and patch size. + + Args: + image_size (`tuple[int, int]`): + The size of the image as `(height, width)`. + patch_size (`tuple[int, int]`): + The patch size as `(height, width)`. + + Returns: + `tuple[int, int]`: The number of patch tokens along the height and the width dimensions. + """ + height, width = image_size + patch_height, patch_width = patch_size if isinstance(patch_size, (tuple, list)) else (patch_size, patch_size) + num_width_tokens = (width - 1) // patch_width + 1 + num_height_tokens = (height - 1) // patch_height + 1 + return num_height_tokens, num_width_tokens + + +def get_resize_output_image_size( + input_image: ImageInput, + size: Union[int, tuple[int, int], list[int], tuple[int]], + patch_size: Union[int, tuple[int, int], list[int], tuple[int]], + input_data_format: Optional[Union[str, ChannelDimension]] = None, +) -> tuple: + """ + Find the target (height, width) dimension of the output image after resizing given the input image and the desired + size. + + Args: + input_image (`ImageInput`): + The image to resize. + size (`int` or `tuple[int, int]`): + Maximum `(height, width)` the output image can have. If an integer is provided, `(size, size)` will be used. + patch_size (`int` or `tuple[int, int]`): + The patch_size as `(height, width)` to use for resizing the image. If patch_size is an integer, `(patch_size, patch_size)` + will be used. + input_data_format (`ChannelDimension`, *optional*): + The channel dimension format of the input image. If unset, will use the inferred format from the input. + + Returns: + `tuple`: The target (height, width) dimension of the output image after resizing.
+ """ + max_height, max_width = size if isinstance(size, (tuple, list)) else (size, size) + patch_height, patch_width = patch_size if isinstance(patch_size, (tuple, list)) else (patch_size, patch_size) + height, width = get_image_size(input_image, input_data_format) + + ratio = max(height / max_height, width / max_width) + + if ratio > 1: + # Original implementation uses `round` which utilises bankers rounding, which can lead to surprising results + # Here we use floor to ensure the image is always smaller than the given "longest_edge" + height = int(math.floor(height / ratio)) + width = int(math.floor(width / ratio)) + + num_height_tokens, num_width_tokens = _num_image_tokens((height, width), (patch_height, patch_width)) + return num_height_tokens * patch_height, num_width_tokens * patch_width + + +class LightOnOCRProcessor(ProcessorMixin): + attributes = ["image_processor", "tokenizer"] + image_processor_class = "AutoImageProcessor" + tokenizer_class = "AutoTokenizer" + + def __init__( + self, + image_processor=None, + tokenizer=None, + patch_size: int = 14, + spatial_merge_size: int = 2, + chat_template=None, + **kwargs, + ): + self.patch_size = patch_size + self.spatial_merge_size = spatial_merge_size + # Calculate effective patch size for image processing + self.effective_patch_size = patch_size * spatial_merge_size + + # Get special tokens from tokenizer attributes + # These should be set on the tokenizer before creating the processor + self.image_token = getattr(tokenizer, "image_token", "<|image_pad|>") + self.image_break_token = getattr(tokenizer, "image_break_token", "<|vision_pad|>") + self.image_end_token = getattr(tokenizer, "image_end_token", "<|vision_end|>") + + # Get token IDs from tokenizer special attributes or convert from token strings + if hasattr(tokenizer, "image_token_id"): + self.image_token_id = tokenizer.image_token_id + else: + self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token) + + if hasattr(tokenizer, "image_break_token_id"): + self.image_break_token_id = tokenizer.image_break_token_id + else: + self.image_break_token_id = tokenizer.convert_tokens_to_ids(self.image_break_token) + + if hasattr(tokenizer, "image_end_token_id"): + self.image_end_token_id = tokenizer.image_end_token_id + else: + self.image_end_token_id = tokenizer.convert_tokens_to_ids(self.image_end_token) + + self.image_ids = [self.image_token_id, self.image_break_token_id, self.image_end_token_id] + + super().__init__(image_processor, tokenizer, chat_template=chat_template) + + def __call__( + self, + images: Optional[ImageInput] = None, + text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None, + **kwargs: Unpack[LightOnOCRProcessorKwargs], + ) -> BatchFeature: + if images is None and text is None: + raise ValueError("You must provide either text or images") + output_kwargs = self._merge_kwargs( + LightOnOCRProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + + if images is not None: + # Like pixtral + output_kwargs["images_kwargs"]["patch_size"] = self.effective_patch_size + image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"]) + else: + image_inputs = {} + + if isinstance(text, str): + text = [text] + elif not isinstance(text, list) and not isinstance(text[0], str): + raise TypeError("Invalid input text. 
Please provide a string, or a list of strings") + + if image_inputs.get("pixel_values") is not None: + image_sizes_iter = iter(image_inputs["image_sizes"]) + prompt_strings = [] + + for sample in text: + replace_strings = [] + + while self.image_token in sample: + image_height, image_width = next(image_sizes_iter) + num_height_tokens = image_height // self.effective_patch_size + num_width_tokens = image_width // self.effective_patch_size + num_patches = num_height_tokens * num_width_tokens + + replace_str = self.image_token * num_patches + replace_strings.append(replace_str) + + # temporarily swap the consumed image token for a placeholder so the expanded sequence can be inserted afterwards + sample = sample.replace(self.image_token, "<placeholder>", 1) + + while "<placeholder>" in sample: + replace_str = replace_strings.pop(0) + sample = sample.replace("<placeholder>", replace_str, 1) + + prompt_strings.append(sample) + else: + prompt_strings = text + + return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) + return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False) + text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"], return_tensors=None) + self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image"]) + + if return_mm_token_type_ids: + array_ids = np.array(text_inputs["input_ids"]) + mm_token_type_ids = np.zeros_like(text_inputs["input_ids"]) + mm_token_type_ids[np.isin(array_ids, self.image_ids)] = 1 + text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist() + + return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors) + + def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs): + """ + Computes the number of placeholder tokens needed for multimodal inputs with the given sizes. + + Args: + image_sizes (`list[list[int]]`, *optional*): + The input sizes formatted as (height, width) per each image. + + Returns: + `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided + input modalities, along with other useful data. + """ + vision_data = {} + if image_sizes is not None: + images_kwargs = LightOnOCRProcessorKwargs._defaults.get("images_kwargs", {}) + images_kwargs.update(kwargs) + + size = images_kwargs.get("size", None) or self.image_processor.size + + num_image_tokens = [] + for height, width in image_sizes: + resized_height, resized_width = get_resize_output_image_size( + np.zeros((height, width, 3)), + size=(size["longest_edge"], size["longest_edge"]), + patch_size=(self.effective_patch_size, self.effective_patch_size), + ) + num_height_tokens = resized_height // self.effective_patch_size + num_width_tokens = resized_width // self.effective_patch_size + num_image_tokens.append(num_width_tokens * num_height_tokens) + + num_image_patches = [1] * len(image_sizes) + vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches}) + + return MultiModalData(**vision_data) + + +__all__ = ["LightOnOCRProcessor"] diff --git a/tests/models/lightonocr/__init__.py b/tests/models/lightonocr/__init__.py new file mode 100644 index 000000000000..8568c82be1c6 --- /dev/null +++ b/tests/models/lightonocr/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/models/lightonocr/test_modeling_lightonocr.py b/tests/models/lightonocr/test_modeling_lightonocr.py new file mode 100644 index 000000000000..193cac8b151c --- /dev/null +++ b/tests/models/lightonocr/test_modeling_lightonocr.py @@ -0,0 +1,676 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch LightOnOCR model.""" + +import copy +import unittest +from difflib import SequenceMatcher + +from transformers import ( + AutoProcessor, + LightOnOCRConfig, + LightOnOCRForConditionalGeneration, + LightOnOCRModel, + is_torch_available, + is_vision_available, +) +from transformers.testing_utils import ( + cleanup, + require_torch, + slow, + torch_device, +) + +from ...generation.test_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor + + +if is_torch_available(): + import torch + + +if is_vision_available(): + from transformers.image_utils import load_image + + +class LightOnOCRVisionText2TextModelTester: + def __init__( + self, + parent, + image_token_index=10, + spatial_merge_size=2, + seq_length=7, + text_config={ + "model_type": "lightonocr_text", + "seq_length": 7, + "is_training": True, + "use_input_mask": True, + "use_token_type_ids": False, + "use_labels": True, + "vocab_size": 99, + "hidden_size": 32, + "num_hidden_layers": 2, + "num_attention_heads": 4, + "num_key_value_heads": 2, + "intermediate_size": 37, + "hidden_act": "silu", + "hidden_dropout_prob": 0.1, + "attention_probs_dropout_prob": 0.1, + "max_position_embeddings": 512, + "type_vocab_size": 16, + "type_sequence_label_size": 2, + "initializer_range": 0.02, + "num_labels": 3, + "num_choices": 4, + "pad_token_id": 1, + "bos_token_id": 0, + "eos_token_id": 2, + "rms_norm_eps": 1e-6, + "rope_theta": 10000.0, + "attention_bias": False, + "attention_dropout": 0.0, + "head_dim": 8, + }, + is_training=True, + vision_config={ + "image_size": 112, + "patch_size": 14, + "num_channels": 3, + "is_training": True, + "hidden_size": 32, + "num_hidden_layers": 2, + "num_attention_heads": 4, + "intermediate_size": 37, + "attention_dropout": 0.0, + "hidden_act": "silu", + "initializer_range": 0.02, + "rope_theta": 10000.0, + }, + ): + self.parent = parent + self.image_token_index = image_token_index + self.spatial_merge_size = spatial_merge_size + self.text_config = text_config + self.vision_config = vision_config + self.pad_token_id = text_config["pad_token_id"] + + self.num_hidden_layers = 
text_config["num_hidden_layers"] + self.vocab_size = text_config["vocab_size"] + self.hidden_size = text_config["hidden_size"] + self.num_attention_heads = text_config["num_attention_heads"] + self.is_training = is_training + + self.batch_size = 3 + self.num_channels = 3 + # Image size must be divisible by patch_size + self.image_size = vision_config["image_size"] + self.patch_size = vision_config["patch_size"] + # Number of patches after patch conv + num_patches = (self.image_size // self.patch_size) ** 2 + # After spatial merging, number of tokens is reduced by spatial_merge_size**2 + self.num_image_tokens = num_patches // (self.spatial_merge_size**2) + self.seq_length = seq_length + self.num_image_tokens + self.encoder_seq_length = self.seq_length + + def get_config(self): + return LightOnOCRConfig( + text_config=self.text_config, + vision_config=self.vision_config, + image_token_id=self.image_token_index, + spatial_merge_size=self.spatial_merge_size, + ) + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor( + [ + self.batch_size, + self.vision_config["num_channels"], + self.vision_config["image_size"], + self.vision_config["image_size"], + ] + ) + config = self.get_config() + + return config, pixel_values + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values = config_and_inputs + input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1 + + # Avoid placing image tokens on positions that would be the pad token + input_ids[input_ids == config.image_token_id] = self.pad_token_id + + # Place image tokens at the beginning + input_ids[:, : self.num_image_tokens] = config.image_token_id + + attention_mask = input_ids.ne(self.pad_token_id) + + # Create image_sizes as tensor - must match batch size + image_sizes = torch.tensor([[self.image_size, self.image_size]] * self.batch_size, dtype=torch.long) + + inputs_dict = { + "pixel_values": pixel_values, + "input_ids": input_ids, + "attention_mask": attention_mask, + "image_sizes": image_sizes, + } + return config, inputs_dict + + def prepare_config_and_inputs_for_generate(self, batch_size=None): + """Prepare config and inputs for generation tests.""" + if batch_size is None: + batch_size = self.batch_size + + # Get base config + config = self.get_config() + + # Create pixel_values with the specified batch size + pixel_values = floats_tensor( + [ + batch_size, + self.vision_config["num_channels"], + self.vision_config["image_size"], + self.vision_config["image_size"], + ] + ) + + # Create input_ids + input_ids = ids_tensor([batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1 + + # Avoid placing image tokens on positions that would be the pad token + input_ids[input_ids == config.image_token_id] = self.pad_token_id + + # Place image tokens at the beginning + input_ids[:, : self.num_image_tokens] = config.image_token_id + + attention_mask = input_ids.ne(self.pad_token_id) + + # Create image_sizes as tensor - must match batch size + image_sizes = torch.tensor([[self.image_size, self.image_size]] * batch_size, dtype=torch.long) + + inputs_dict = { + "pixel_values": pixel_values, + "input_ids": input_ids, + "attention_mask": attention_mask, + "image_sizes": image_sizes, + } + return config, inputs_dict + + +@require_torch +class LightOnOCRForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): + """ + Model tester for `LightOnOCRForConditionalGeneration`. 
+ """ + + all_model_classes = ( + ( + LightOnOCRModel, + LightOnOCRForConditionalGeneration, + ) + if is_torch_available() + else () + ) + pipeline_model_mapping = {"image-text-to-text": LightOnOCRForConditionalGeneration} if is_torch_available() else {} + + _is_composite = True + test_head_masking = False + test_pruning = False + test_torchscript = False + + def setUp(self): + self.model_tester = LightOnOCRVisionText2TextModelTester(self) + common_properties = ["image_token_id", "spatial_merge_size"] + self.config_tester = ConfigTester( + self, config_class=LightOnOCRConfig, has_text_modality=False, common_properties=common_properties + ) + + def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + """ + Prepare inputs for the model class, ensuring image_sizes matches the batch size. + """ + inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) + + # Ensure image_sizes matches the batch size of pixel_values or input_ids + if "pixel_values" in inputs_dict and "image_sizes" in inputs_dict: + batch_size = inputs_dict["pixel_values"].shape[0] + # If image_sizes doesn't match batch size, adjust it + if len(inputs_dict["image_sizes"]) != batch_size: + # Take only the first batch_size entries + inputs_dict["image_sizes"] = inputs_dict["image_sizes"][:batch_size] + + return inputs_dict + + def prepare_config_and_inputs_for_generate(self, batch_size=1): + """Override to use the model_tester's custom method.""" + return self.model_tester.prepare_config_and_inputs_for_generate(batch_size=batch_size) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_mismatching_num_image_tokens(self): + """ + Tests that VLMs throw an error with explicit message saying what is wrong + when number of images doesn't match number of image tokens in the text. + Also we need to test multi-image cases when one prompt has multiple image tokens. + """ + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + for model_class in self.all_model_classes: + model = model_class(config).to(torch_device) + model.eval() + curr_input_dict = copy.deepcopy(input_dict) # in-place modifications further + _ = model(**curr_input_dict) # successful forward with no modifications + + # remove one image but leave the image token in text + curr_input_dict["pixel_values"] = curr_input_dict["pixel_values"][-1:, ...] + curr_input_dict["image_sizes"] = curr_input_dict["image_sizes"][-1:] + with self.assertRaises(ValueError): + _ = model(**curr_input_dict) + + # simulate multi-image case by concatenating inputs where each has exactly one image/image-token + input_ids = curr_input_dict["input_ids"][:1] + pixel_values = curr_input_dict["pixel_values"][:1] + image_sizes = curr_input_dict["image_sizes"][:1] + input_ids = torch.cat([input_ids, input_ids], dim=0) + + # one image and two image tokens raise an error + with self.assertRaises(ValueError): + _ = model(input_ids=input_ids, pixel_values=pixel_values, image_sizes=image_sizes) + + # two images and two image tokens don't raise an error + pixel_values = torch.cat([pixel_values, pixel_values], dim=0) + image_sizes = torch.cat([image_sizes, image_sizes], dim=0) + _ = model(input_ids=input_ids, pixel_values=pixel_values, image_sizes=image_sizes) + + def test_spatial_merge_size(self): + """ + Test that models can be created and initialized with different spatial_merge_size values. 
+ """ + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + # Test that model can be created with different spatial_merge_size values + for spatial_merge_size in [1, 2, 4]: + curr_config = copy.deepcopy(config) + curr_config.spatial_merge_size = spatial_merge_size + + for model_class in self.all_model_classes: + # Build model with the new config - should not raise any errors + model = model_class(curr_config).to(torch_device) + model.eval() + + # Verify the spatial_merge_size is set correctly + self.assertEqual(model.config.spatial_merge_size, spatial_merge_size) + + # Verify the model has the expected components + if hasattr(model, "model"): + self.assertTrue(hasattr(model.model, "vision_projection")) + self.assertEqual(model.model.vision_projection.config.spatial_merge_size, spatial_merge_size) + elif hasattr(model, "vision_projection"): + self.assertEqual(model.vision_projection.config.spatial_merge_size, spatial_merge_size) + + def test_forward_pass_with_image_sizes(self): + """ + Test that the model correctly handles variable image sizes. + """ + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config).to(torch_device) + model.eval() + + # Test with different image sizes in the same batch + batch_size = 2 + pixel_values = floats_tensor( + [batch_size, 3, self.model_tester.image_size, self.model_tester.image_size] + ).to(torch_device) + + # Different image sizes (but still need to be divisible by patch_size) + image_sizes = torch.tensor( + [[self.model_tester.image_size, self.model_tester.image_size]] * batch_size, + dtype=torch.long, + device=torch_device, + ) + + num_patches = (self.model_tester.image_size // self.model_tester.patch_size) ** 2 + num_image_tokens = num_patches // (config.spatial_merge_size**2) + + input_ids = ids_tensor([batch_size, 10 + num_image_tokens], config.text_config.vocab_size - 1) + 1 + # Ensure no tokens accidentally equal image_token_id + input_ids[input_ids == config.image_token_id] = config.image_token_id + 1 + # Now place image tokens at the beginning + input_ids[:, :num_image_tokens] = config.image_token_id + input_ids = input_ids.to(torch_device) + + outputs = model( + pixel_values=pixel_values, + input_ids=input_ids, + image_sizes=image_sizes, + ) + + self.assertIsNotNone(outputs) + + def test_model_outputs_equivalence(self): + """ + Test that model outputs are consistent across different input configurations. 
+ """ + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config).to(torch_device) + model.eval() + + with torch.no_grad(): + outputs1 = model(**input_dict) + outputs2 = model(**input_dict) + + # Check that outputs are deterministic + if hasattr(outputs1, "last_hidden_state") and hasattr(outputs2, "last_hidden_state"): + self.assertTrue(torch.allclose(outputs1.last_hidden_state, outputs2.last_hidden_state, atol=1e-5)) + + @unittest.skip( + "LightOnOCR uses complex attention patterns with sliding windows, skipping gradient checkpointing test" + ) + def test_training_gradient_checkpointing(self): + pass + + @unittest.skip( + "LightOnOCR uses complex attention patterns with sliding windows, skipping gradient checkpointing test" + ) + def test_training_gradient_checkpointing_use_reentrant(self): + pass + + @unittest.skip( + "LightOnOCR uses complex attention patterns with sliding windows, skipping gradient checkpointing test" + ) + def test_training_gradient_checkpointing_use_reentrant_false(self): + pass + + @unittest.skip( + "VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test" + ) + def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self): + pass + + @unittest.skip("FlashAttention only support fp16 and bf16 data type") + def test_flash_attn_2_fp32_ln(self): + pass + + @unittest.skip("Pixtral does not support attention interfaces.") + def test_eager_matches_fa2_generate(self): + pass + + @unittest.skip("Pixtral does not support attention interfaces.") + def test_eager_matches_sdpa_generate(self): + pass + + @unittest.skip("Pixtral does not support attention interfaces.") + def test_flash_attn_2_from_config(self): + pass + + @unittest.skip("Pixtral does not support attention interfaces.") + def test_flash_attn_2_inference_equivalence(self): + pass + + @unittest.skip("Pixtral does not support attention interfaces.") + def test_flash_attn_2_inference_equivalence_right_padding(self): + pass + + @unittest.skip("Pixtral does not support attention interfaces.") + def test_sdpa_can_dispatch_on_flash(self): + pass + + @unittest.skip("Pixtral does not support attention interfaces.") + def test_flex_attention_with_grads(self): + pass + + def test_initialization(self): + """ + Test that model initializes correctly with proper weight initialization. + """ + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + + # Check that model has all expected components + if model_class == LightOnOCRForConditionalGeneration: + self.assertTrue(hasattr(model, "model")) + self.assertTrue(hasattr(model.model, "vision_encoder")) + self.assertTrue(hasattr(model.model, "language_model")) + self.assertTrue(hasattr(model.model, "vision_projection")) + self.assertTrue(hasattr(model, "lm_head")) + elif model_class == LightOnOCRModel: + self.assertTrue(hasattr(model, "vision_encoder")) + self.assertTrue(hasattr(model, "language_model")) + self.assertTrue(hasattr(model, "vision_projection")) + + def test_vision_projection(self): + """ + Test that the vision projection correctly transforms vision embeddings to text space. 
+ """ + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + + model = LightOnOCRModel(config).to(torch_device) + model.eval() + + # Convert image_sizes to list for vision_encoder + if isinstance(input_dict["image_sizes"], torch.Tensor): + image_sizes_list = [(int(h), int(w)) for h, w in input_dict["image_sizes"]] + else: + image_sizes_list = input_dict["image_sizes"] + + with torch.no_grad(): + # Get vision features + vision_outputs = model.vision_encoder( + pixel_values=input_dict["pixel_values"].to(torch_device), + image_sizes=image_sizes_list, + ) + + # Project vision features + projected = model.vision_projection( + vision_outputs.last_hidden_state.squeeze(0), + image_sizes_list, + ) + + # Check output dimensions - should match text hidden size + self.assertEqual(projected.shape[-1], config.text_config.hidden_size) + + def test_get_image_features(self): + """ + Test the get_image_features method. + """ + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + + model = LightOnOCRModel(config).to(torch_device) + model.eval() + + with torch.no_grad(): + image_features_list = model.get_image_features( + pixel_values=input_dict["pixel_values"].to(torch_device), + image_sizes=input_dict["image_sizes"], + ) + + # Check that features are returned as a list + self.assertIsNotNone(image_features_list) + self.assertIsInstance(image_features_list, (list, tuple)) + + # Concatenate features and check shape + image_features = torch.cat(image_features_list, dim=0) + self.assertEqual(image_features.shape[-1], config.text_config.hidden_size) + + +@slow +@require_torch +class LightOnOCRForConditionalGenerationIntegrationTest(unittest.TestCase): + def tearDown(self): + cleanup(torch_device, gc_collect=True) + + @slow + def test_lightonocr_ocr_integration(self): + """ + Integration test for LightOnOCR OCR capabilities. + Tests that the model can perform OCR on a real image and produce expected output. 
+ + """ + + model_id = "lightonai/LightOnOCR-1B-1025" + + # Load processor and model from Hub + processor = AutoProcessor.from_pretrained(model_id) + model = LightOnOCRForConditionalGeneration.from_pretrained(model_id, device_map=torch_device) + model.eval() + + # Load a test OCR image + # This is a standard OCR test image from HuggingFace fixtures + image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/fixtures_ocr/resolve/main/SROIE-receipt.jpeg" + ) + + # Process image and prepare inputs + # Using chat template as shown in the model's usage pattern + chat = [ + { + "role": "user", + "content": [ + {"type": "image", "url": image}, + ], + } + ] + + inputs = processor.apply_chat_template( + chat, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt" + ).to(torch_device) + + # Generate OCR output + with torch.no_grad(): + generated_ids = model.generate( + **inputs, + max_new_tokens=50, + do_sample=False, + num_beams=1, + ) + + # Decode output, excluding the input prompt + decoded_output = processor.decode(generated_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True) + + expected_output = "Document No : TD01167104\n\nDate : 25/12/2018 8:13:39 PM\n\nCashier : MANIS\n\nMember :\n\nCASH BILL\n\n| CODE" + + similarity = SequenceMatcher(None, decoded_output, expected_output).ratio() + + # Require at least 95% similarity to catch regressions while allowing minor variations + self.assertGreater( + similarity, + 0.95, + f"Model output differs too much from expected output (similarity: {similarity:.2%}).\n" + f"Expected:\n{expected_output}\n\nGot:\n{decoded_output}", + ) + + def test_model_can_generate_without_images(self): + """ + Test that the model can generate text without image inputs. + """ + # Create a small config for fast testing + text_config = { + "vocab_size": 100, + "hidden_size": 64, + "num_hidden_layers": 2, + "num_attention_heads": 4, + "num_key_value_heads": 2, + "intermediate_size": 128, + "max_position_embeddings": 512, + "rms_norm_eps": 1e-6, + "head_dim": 16, + } + vision_config = { + "hidden_size": 64, + "num_hidden_layers": 2, + "num_attention_heads": 4, + "intermediate_size": 128, + "image_size": 112, + "patch_size": 14, + } + + config = LightOnOCRConfig(text_config=text_config, vision_config=vision_config, image_token_id=10) + model = LightOnOCRForConditionalGeneration(config).to(torch_device) + model.eval() + + # Create text-only input + input_ids = torch.randint(0, config.text_config.vocab_size - 1, (1, 10), device=torch_device) + 1 + + with torch.no_grad(): + outputs = model.generate(input_ids=input_ids, max_new_tokens=5) + + self.assertIsNotNone(outputs) + self.assertEqual(outputs.shape[0], 1) + self.assertGreater(outputs.shape[1], input_ids.shape[1]) + + def test_model_forward_with_images(self): + """ + Test forward pass with image inputs. 
+ """ + text_config = { + "vocab_size": 100, + "hidden_size": 64, + "num_hidden_layers": 2, + "num_attention_heads": 4, + "num_key_value_heads": 2, + "intermediate_size": 128, + "max_position_embeddings": 512, + "rms_norm_eps": 1e-6, + "head_dim": 16, + } + vision_config = { + "hidden_size": 64, + "num_hidden_layers": 2, + "num_attention_heads": 4, + "intermediate_size": 128, + "image_size": 112, + "patch_size": 14, + } + + config = LightOnOCRConfig(text_config=text_config, vision_config=vision_config, image_token_id=10) + model = LightOnOCRForConditionalGeneration(config).to(torch_device) + model.eval() + + # Create inputs + batch_size = 2 + image_size = 112 + pixel_values = torch.randn(batch_size, 3, image_size, image_size, device=torch_device) + image_sizes = torch.tensor([[image_size, image_size]] * batch_size, dtype=torch.long, device=torch_device) + + # Calculate number of image tokens + num_patches = (image_size // 14) ** 2 # patch_size = 14 + num_image_tokens = num_patches // (config.spatial_merge_size**2) + + seq_len = num_image_tokens + 10 + input_ids = torch.randint(0, config.text_config.vocab_size - 1, (batch_size, seq_len), device=torch_device) + 1 + # Ensure no tokens accidentally equal image_token_id + input_ids[input_ids == config.image_token_id] = config.image_token_id + 1 + # Now place image tokens at the beginning + input_ids[:, :num_image_tokens] = config.image_token_id + + with torch.no_grad(): + outputs = model( + pixel_values=pixel_values, + input_ids=input_ids, + image_sizes=image_sizes, + ) + + self.assertIsNotNone(outputs) + self.assertIsNotNone(outputs.logits) + self.assertEqual(outputs.logits.shape[0], batch_size) + self.assertEqual(outputs.logits.shape[1], seq_len) + self.assertEqual(outputs.logits.shape[2], config.text_config.vocab_size) diff --git a/tests/models/lightonocr/test_processor_lightonocr.py b/tests/models/lightonocr/test_processor_lightonocr.py new file mode 100644 index 000000000000..33d3b725cd0c --- /dev/null +++ b/tests/models/lightonocr/test_processor_lightonocr.py @@ -0,0 +1,263 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import shutil +import tempfile +import unittest + +import numpy as np + +from transformers import AutoImageProcessor, AutoProcessor, AutoTokenizer +from transformers.testing_utils import require_torch, require_vision +from transformers.utils import is_torch_available, is_vision_available + +from ...test_processing_common import ProcessorTesterMixin + + +if is_vision_available(): + from PIL import Image + + from transformers import LightOnOCRProcessor + +if is_torch_available(): + import torch + + +@require_vision +@require_torch +class LightOnOCRProcessorTest(ProcessorTesterMixin, unittest.TestCase): + """Test suite for LightOnOCR processor.""" + + processor_class = LightOnOCRProcessor + + def setUp(self): + """Set up test fixtures.""" + self.tmpdirname = tempfile.mkdtemp() + + # Create a Pixtral image processor (LightOnOCR uses Pixtral vision architecture) + image_processor = AutoImageProcessor.from_pretrained( + "mistral-community/pixtral-12b", size={"longest_edge": 1024} + ) + + # Create a tokenizer (using Qwen2 as base) + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + + # Add special tokens for LightOnOCR + special_tokens_dict = { + "additional_special_tokens": [ + "<|image_pad|>", + "<|vision_pad|>", + "<|vision_end|>", + ] + } + tokenizer.add_special_tokens(special_tokens_dict) + + # Set special token attributes on the tokenizer for multimodal processing + tokenizer.image_token = "<|image_pad|>" + tokenizer.image_break_token = "<|vision_pad|>" + tokenizer.image_end_token = "<|vision_end|>" + tokenizer.image_token_id = tokenizer.convert_tokens_to_ids(tokenizer.image_token) + tokenizer.image_break_token_id = tokenizer.convert_tokens_to_ids(tokenizer.image_break_token) + tokenizer.image_end_token_id = tokenizer.convert_tokens_to_ids(tokenizer.image_end_token) + + # Add a basic multimodal-aware chat template to the tokenizer + # This template extracts text from the multimodal content format + tokenizer.chat_template = ( + "{% for message in messages %}" + "{% if loop.first and messages[0]['role'] != 'system' %}" + "{{ '<|im_start|>system\\nYou are a helpful assistant.<|im_end|>\\n' }}" + "{% endif %}" + "{{'<|im_start|>' + message['role'] + '\\n' }}" + "{% if message['content'] is string %}" + "{{ message['content'] }}" + "{% else %}" + "{% for content in message['content'] %}" + "{% if content['type'] == 'text' %}" + "{{ content['text'] }}" + "{% elif content['type'] == 'image' %}" + "{{ '<|image_pad|>' }}" + "{% endif %}" + "{% endfor %}" + "{% endif %}" + "{{ '<|im_end|>\\n' }}" + "{% endfor %}" + "{% if add_generation_prompt %}" + "{{ '<|im_start|>assistant\\n' }}" + "{% endif %}" + ) + + # Create and save processor + processor = LightOnOCRProcessor( + image_processor=image_processor, + tokenizer=tokenizer, + patch_size=14, + spatial_merge_size=2, + ) + processor.save_pretrained(self.tmpdirname) + + self.image_token = processor.image_token + + def tearDown(self): + """Clean up after tests.""" + shutil.rmtree(self.tmpdirname, ignore_errors=True) + + def get_tokenizer(self, **kwargs): + """Get tokenizer from saved processor.""" + return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer + + def get_image_processor(self, **kwargs): + """Get image processor from saved processor.""" + return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor + + def get_processor(self, **kwargs): + """Get processor from saved directory.""" + return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs) + + def prepare_image_inputs(self, 
batch_size=None): + """Prepare small dummy image inputs.""" + image = Image.new("RGB", (112, 112), color="red") + if batch_size is None: + return image + return [image] * batch_size + + def test_processor_creation(self): + """Test that processor can be created and loaded.""" + processor = self.get_processor() + self.assertIsInstance(processor, LightOnOCRProcessor) + self.assertIsNotNone(processor.tokenizer) + self.assertIsNotNone(processor.image_processor) + + def test_processor_with_text_only(self): + """Test processor with text input only.""" + processor = self.get_processor() + text = "This is a test sentence." + + inputs = processor(text=text, return_tensors="pt") + + self.assertIn("input_ids", inputs) + self.assertIn("attention_mask", inputs) + self.assertEqual(inputs["input_ids"].shape[0], 1) # batch size + + def test_processor_with_image_and_text(self): + """Test processor with both image and text inputs.""" + processor = self.get_processor() + image = self.prepare_image_inputs() + text = f"{self.image_token} Extract text from this image." + + inputs = processor(images=image, text=text, return_tensors="pt") + + self.assertIn("input_ids", inputs) + self.assertIn("attention_mask", inputs) + self.assertIn("pixel_values", inputs) + self.assertIn("image_sizes", inputs) + + # Check shapes + self.assertEqual(inputs["input_ids"].shape[0], 1) # batch size + self.assertEqual(len(inputs["pixel_values"].shape), 4) # (batch, channels, height, width) + self.assertEqual(len(inputs["image_sizes"]), 1) # one image + + def test_processor_image_token_expansion(self): + """Test that image token is properly expanded based on image size.""" + processor = self.get_processor() + image = self.prepare_image_inputs() + text = f"{self.image_token} Describe this image." + + inputs = processor(images=image, text=text, return_tensors="pt") + + # The image token should be expanded to multiple tokens based on patch size + # Count occurrences of image_token_id in input_ids + image_token_id = processor.image_token_id + num_image_tokens = (inputs["input_ids"] == image_token_id).sum().item() + + # Should have multiple image tokens (one per patch after spatial merging) + self.assertGreater(num_image_tokens, 1) + + def test_processor_batch_processing(self): + """Test processor with batch of inputs.""" + processor = self.get_processor() + images = self.prepare_image_inputs(batch_size=2) + texts = [f"{self.image_token} Extract text." for _ in range(2)] + + inputs = processor(images=images, text=texts, return_tensors="pt", padding=True) + + self.assertEqual(inputs["input_ids"].shape[0], 2) # batch size + self.assertEqual(inputs["pixel_values"].shape[0], 2) # two images + + def test_processor_model_input_names(self): + """Test that processor returns correct model input names.""" + processor = self.get_processor() + + expected_keys = {"input_ids", "attention_mask", "pixel_values", "image_sizes"} + model_input_names = set(processor.model_input_names) + + # Check that all expected keys are in model_input_names + for key in expected_keys: + self.assertIn(key, model_input_names) + + def test_processor_without_images(self): + """Test that processor handles text-only input correctly.""" + processor = self.get_processor() + text = "This is text without any images." 
+ + inputs = processor(text=text, return_tensors="pt") + + self.assertIn("input_ids", inputs) + self.assertIn("attention_mask", inputs) + self.assertNotIn("pixel_values", inputs) + self.assertNotIn("image_sizes", inputs) + + def test_processor_special_tokens(self): + """Test that special tokens are properly registered.""" + processor = self.get_processor() + + # Check that image tokens are properly defined + self.assertEqual(processor.image_token, "<|image_pad|>") + self.assertEqual(processor.image_break_token, "<|vision_pad|>") + self.assertEqual(processor.image_end_token, "<|vision_end|>") + + # Check that tokens have valid IDs + self.assertIsInstance(processor.image_token_id, int) + self.assertIsInstance(processor.image_break_token_id, int) + self.assertIsInstance(processor.image_end_token_id, int) + + def test_processor_return_types(self): + """Test different return types (pt, np, list).""" + processor = self.get_processor() + image = self.prepare_image_inputs() + text = f"{self.image_token} Test image." + + # Test PyTorch tensors + inputs_pt = processor(images=image, text=text, return_tensors="pt") + self.assertIsInstance(inputs_pt["input_ids"], torch.Tensor) + + # Test NumPy arrays + inputs_np = processor(images=image, text=text, return_tensors="np") + self.assertIsInstance(inputs_np["input_ids"], np.ndarray) + + # Test lists + inputs_list = processor(images=image, text=text, return_tensors=None) + self.assertIsInstance(inputs_list["input_ids"], list) + + def test_image_sizes_output(self): + """Test that image_sizes are correctly computed.""" + processor = self.get_processor() + image = Image.new("RGB", (300, 400), color="blue") # Different size + text = f"{self.image_token} Test." + + inputs = processor(images=image, text=text, return_tensors="pt") + + self.assertIn("image_sizes", inputs) + self.assertEqual(len(inputs["image_sizes"]), 1) + # Image size should be a tuple of (height, width) + self.assertEqual(len(inputs["image_sizes"][0]), 2) diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py index 1864e928b752..1f9cc213dc31 100644 --- a/utils/check_config_attributes.py +++ b/utils/check_config_attributes.py @@ -62,6 +62,10 @@ "Qwen2_5OmniTalkerConfig": ["use_sliding_window", "max_window_layers"], "Qwen3Config": ["max_window_layers", "use_sliding_window"], # now use `layer_types` instead "Qwen3MoeConfig": ["max_window_layers", "use_sliding_window"], + "LightOnOCRTextConfig": [ + "use_sliding_window", + "max_window_layers", + ], # inherited from Qwen3Config, now use `layer_types` instead # `cache_implementation` should be in the default generation config, but we don't yet support per-model # generation configs (TODO joao) "Gemma2Config": ["tie_word_embeddings", "cache_implementation"], diff --git a/utils/check_repo.py b/utils/check_repo.py index 58ff56484f27..1689413d9379 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -98,6 +98,8 @@ "Phi4MultimodalVisionModel", "Glm4vVisionModel", "Glm4vMoeVisionModel", + "LightOnOCRTextModel", + "LightOnOCRVisionModel", "EvollaSaProtPreTrainedModel", "BltLocalEncoder", # Building part of bigger (tested) model. Tested implicitly through BLTForCausalLM. "BltLocalDecoder", # Building part of bigger (tested) model. Tested implicitly through BLTForCausalLM. 
@@ -187,6 +189,8 @@ "Qwen2_5_VLTextModel", # Building part of bigger (tested) model "InternVLVisionModel", # Building part of bigger (tested) model "JanusVisionModel", # Building part of bigger (tested) model + "LightOnOCRText", # Building part of bigger (tested) model. Tested implicitly through LightOnOCRForConditionalGeneration. + "LightOnOCRVision", # Building part of bigger (tested) model. Tested implicitly through LightOnOCRForConditionalGeneration. "TimesFmModel", # Building part of bigger (tested) model "CsmDepthDecoderForCausalLM", # Building part of bigger (tested) model. Tested implicitly through CsmForConditionalGenerationIntegrationTest. "CsmDepthDecoderModel", # Building part of bigger (tested) model. Tested implicitly through CsmForConditionalGenerationIntegrationTest.
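For reference, a minimal sketch of the end-to-end OCR flow these tests exercise, distilled from the integration test above (checkpoint id, chat format, and decoding logic are copied from that test; `device_map="auto"` and the generation length are arbitrary choices for the sketch, not prescribed values):

```python
import torch
from transformers import AutoProcessor, LightOnOCRForConditionalGeneration

# Checkpoint name taken from test_lightonocr_ocr_integration.
model_id = "lightonai/LightOnOCR-1B-1025"

processor = AutoProcessor.from_pretrained(model_id)
model = LightOnOCRForConditionalGeneration.from_pretrained(model_id, device_map="auto")
model.eval()

# A single user turn containing only an image, as in the integration test.
chat = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "url": "https://huggingface.co/datasets/hf-internal-testing/fixtures_ocr/resolve/main/SROIE-receipt.jpeg",
            },
        ],
    }
]

inputs = processor.apply_chat_template(
    chat, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
).to(model.device)

with torch.no_grad():
    generated_ids = model.generate(**inputs, max_new_tokens=512, do_sample=False)

# Drop the prompt tokens before decoding, exactly as the test does.
text = processor.decode(generated_ids[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True)
print(text)
```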