1 change: 1 addition & 0 deletions src/transformers/modeling_utils.py
@@ -203,6 +203,7 @@ def is_local_dist_rank_0():
"qwen2_5_vl",
"videollava",
"vipllava",
"paddleocrvl",
]


1 change: 1 addition & 0 deletions src/transformers/models/__init__.py
@@ -257,6 +257,7 @@
from .ovis2 import *
from .owlv2 import *
from .owlvit import *
from .paddleocr_vl import *
from .paligemma import *
from .parakeet import *
from .patchtsmixer import *
2 changes: 2 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
@@ -302,6 +302,7 @@
("ovis2", "Ovis2Config"),
("owlv2", "Owlv2Config"),
("owlvit", "OwlViTConfig"),
("paddleocr_vl", "PaddleOCRVLConfig"),
("paligemma", "PaliGemmaConfig"),
("parakeet_ctc", "ParakeetCTCConfig"),
("parakeet_encoder", "ParakeetEncoderConfig"),
@@ -761,6 +762,7 @@
("ovis2", "Ovis2"),
("owlv2", "OWLv2"),
("owlvit", "OWL-ViT"),
("paddleocr_vl", "PaddleOCRVL"),
("paligemma", "PaliGemma"),
("parakeet", "Parakeet"),
("parakeet_ctc", "Parakeet"),
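With the `("paddleocr_vl", ...)` entries registered, `AutoConfig` can resolve any checkpoint whose `config.json` declares `"model_type": "paddleocr_vl"`. A minimal sketch of the effect, assuming a transformers build that already includes this PR (the repo id is illustrative, not taken from this diff):

```python
from transformers import AutoConfig

# Repo id is illustrative; any checkpoint whose config.json declares
# `"model_type": "paddleocr_vl"` resolves through the mapping added above.
config = AutoConfig.from_pretrained("PaddlePaddle/PaddleOCR-VL")
print(type(config).__name__)  # PaddleOCRVLConfig
```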
1 change: 1 addition & 0 deletions src/transformers/models/auto/image_processing_auto.py
@@ -154,6 +154,7 @@
("ovis2", ("Ovis2ImageProcessor", "Ovis2ImageProcessorFast")),
("owlv2", ("Owlv2ImageProcessor", "Owlv2ImageProcessorFast")),
("owlvit", ("OwlViTImageProcessor", "OwlViTImageProcessorFast")),
("paddleocr_vl", ("PaddleOCRVLImageProcessor", "PaddleOCRVLImageProcessorFast")),
("paligemma", ("SiglipImageProcessor", "SiglipImageProcessorFast")),
("perceiver", ("PerceiverImageProcessor", "PerceiverImageProcessorFast")),
("perception_lm", (None, "PerceptionLMImageProcessorFast")),
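The two-element tuple registers a slow and a fast image processor class; `AutoImageProcessor` prefers the fast variant when its backend is available. A sketch with the same illustrative repo id:

```python
from transformers import AutoImageProcessor

# `use_fast=True` selects PaddleOCRVLImageProcessorFast; `use_fast=False`
# falls back to the slow PaddleOCRVLImageProcessor.
image_processor = AutoImageProcessor.from_pretrained("PaddlePaddle/PaddleOCR-VL", use_fast=True)
```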
1 change: 1 addition & 0 deletions src/transformers/models/auto/modeling_auto.py
@@ -1052,6 +1052,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("mistral3", "Mistral3ForConditionalGeneration"),
("mllama", "MllamaForConditionalGeneration"),
("ovis2", "Ovis2ForConditionalGeneration"),
("paddleocr_vl", "PaddleOCRVLForConditionalGeneration"),
("paligemma", "PaliGemmaForConditionalGeneration"),
("perception_lm", "PerceptionLMForConditionalGeneration"),
("pix2struct", "Pix2StructForConditionalGeneration"),
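The surrounding entries are all `*ForConditionalGeneration` classes, so this hunk presumably extends the image-text-to-text model mapping; with it in place, the auto class dispatches on `config.model_type`. A hedged sketch:

```python
from transformers import AutoModelForImageTextToText

# Illustrative: dispatch goes config.model_type ("paddleocr_vl")
# -> PaddleOCRVLForConditionalGeneration via the mapping entry above.
model = AutoModelForImageTextToText.from_pretrained("PaddlePaddle/PaddleOCR-VL")
```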
1 change: 1 addition & 0 deletions src/transformers/models/auto/processing_auto.py
@@ -113,6 +113,7 @@
("ovis2", "Ovis2Processor"),
("owlv2", "Owlv2Processor"),
("owlvit", "OwlViTProcessor"),
("paddleocr_vl", "PaddleOCRVLProcessor"),
("paligemma", "PaliGemmaProcessor"),
("perception_lm", "PerceptionLMProcessor"),
("phi4_multimodal", "Phi4MultimodalProcessor"),
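`AutoProcessor` ties the pieces together: it loads `PaddleOCRVLProcessor`, which wraps the image processor and tokenizer registered in the other mappings. A sketch under the same illustrative repo id:

```python
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("PaddlePaddle/PaddleOCR-VL")

# The exact prompt/chat format comes from the model card; this text is a placeholder.
image = Image.new("RGB", (448, 448), "white")
inputs = processor(images=image, text="Recognize the text in this image.", return_tensors="pt")
```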
1 change: 1 addition & 0 deletions src/transformers/models/auto/tokenization_auto.py
@@ -522,6 +522,7 @@
("ovis2", (None, "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
("owlv2", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
("owlvit", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
("paddleocr_vl", (None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
("paligemma", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
("parakeet", (None, "ParakeetTokenizerFast" if is_tokenizers_available() else None)),
(
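The `(None, "LlamaTokenizerFast" ...)` entry registers no slow tokenizer, so loading requires the `tokenizers` library. Sketch:

```python
from transformers import AutoTokenizer

# Fast-only entry: the slow slot in the mapping is None, so this path
# needs the `tokenizers` package installed.
tokenizer = AutoTokenizer.from_pretrained("PaddlePaddle/PaddleOCR-VL")
assert tokenizer.is_fast
```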
28 changes: 28 additions & 0 deletions src/transformers/models/paddleocr_vl/__init__.py
@@ -0,0 +1,28 @@
# Copyright 2025 The PaddlePaddle Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
from .configuration_paddleocr_vl import *
from .modeling_paddleocr_vl import *
from .processing_paddleocr_vl import *
else:
import sys

_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
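At runtime the package replaces itself with a `_LazyModule`, so the configuration/modeling/processing submodules are only imported on first attribute access; the `TYPE_CHECKING` branch gives static analyzers the eager view. A small sketch of the observable behavior:

```python
import sys

import transformers.models.paddleocr_vl  # installs the lazy proxy

# The package object itself is a _LazyModule stand-in ...
print(type(sys.modules["transformers.models.paddleocr_vl"]).__name__)  # _LazyModule

# ... and the heavy submodules load only when a symbol is first accessed.
from transformers.models.paddleocr_vl import PaddleOCRVLConfig
```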
293 changes: 293 additions & 0 deletions src/transformers/models/paddleocr_vl/configuration_paddleocr_vl.py
@@ -0,0 +1,293 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_paddleocr_vl.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# Copyright 2025 The PaddlePaddle Team and The HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

from ...configuration_utils import PreTrainedConfig, PretrainedConfig
from ...modeling_rope_utils import RopeParameters, rope_config_validation


class PaddleOCRVLVisionConfig(PretrainedConfig):
model_type = "paddleocr_vl_vision"
base_config_key = "vision_config"

def __init__(
self,
hidden_size=768,
intermediate_size=3072,
num_hidden_layers=12,
num_attention_heads=12,
num_channels=3,
image_size=224,
patch_size=14,
hidden_act="gelu_pytorch_tanh",
layer_norm_eps=1e-6,
attention_dropout=0.0,
spatial_merge_size=2,
temporal_patch_size=2,
tokens_per_second=2,
**kwargs,
):
super().__init__(**kwargs)

self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_channels = num_channels
self.patch_size = patch_size
self.image_size = image_size
self.attention_dropout = attention_dropout
self.layer_norm_eps = layer_norm_eps
self.hidden_act = hidden_act
self.spatial_merge_size = spatial_merge_size
self.temporal_patch_size = temporal_patch_size
self.tokens_per_second = tokens_per_second


class PaddleOCRVLTextConfig(PretrainedConfig):
"""
Configuration class for the PaddleOCR-VL text backbone.

This class stores the configuration of an ERNIE-style decoder model, defining the model architecture.
It inherits from [`PretrainedConfig`] and can be used to control model outputs.
"""

model_type = "paddleocr_vl_text"

base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}

def __init__(
self,
vocab_size=32000,
hidden_size=768,
intermediate_size=11008,
max_position_embeddings=32768,
num_hidden_layers=2,
num_attention_heads=2,
rms_norm_eps=1e-6,
use_cache=False,
use_flash_attention=False,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
head_dim=128,
hidden_act="silu",
use_bias=False,
rope_theta=10000,
weight_share_add_bias=True,
ignored_index=-100,
attention_probs_dropout_prob=0.0,
hidden_dropout_prob=0.0,
compression_ratio: float = 1.0,
num_key_value_heads=None,
max_sequence_length=None,
tie_word_embeddings=False,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
**kwargs,
):
"""
Initialize configuration with default or specified parameters.

Args:
vocab_size (int): Size of the vocabulary (number of unique tokens)
hidden_size (int): Dimensionality of the encoder layers and the pooler layer
intermediate_size (int): Dimensionality of the "intermediate" (feed-forward) layer
max_position_embeddings (int): Maximum sequence length the model can handle
num_hidden_layers (int): Number of hidden layers in the Transformer encoder
num_attention_heads (int): Number of attention heads for each attention layer
rms_norm_eps (float): The epsilon used by the RMS normalization layers
use_cache (bool): Whether to use caching for faster generation (decoding)
use_flash_attention (bool): Whether to use FlashAttention for optimized attention computation
pad_token_id (int): Token ID used for padding sequences
bos_token_id (int): Token ID used for beginning-of-sequence
eos_token_id (int): Token ID used for end-of-sequence
head_dim (int): Dimensionality of each attention head
hidden_act (str): The non-linear activation function used in the decoder MLP
use_bias (bool): Whether to use bias terms in linear layers
rope_theta (float): The base period of the RoPE embeddings
weight_share_add_bias (bool): Whether to share bias weights in certain layers
ignored_index (int): Target value that is ignored during loss computation
attention_probs_dropout_prob (float): Dropout probability for attention weights
hidden_dropout_prob (float): Dropout probability for hidden layers
compression_ratio (float): Ratio for KV cache compression (1.0 = no compression)
num_key_value_heads (int): Number of key/value heads (for Grouped Query Attention)
max_sequence_length (int): Maximum sequence length for positional embeddings
tie_word_embeddings (bool): Whether to tie input and output embedding weights
rope_parameters (RopeParameters or dict): RoPE configuration, replacing the legacy `rope_scaling` kwarg
**kwargs: Additional keyword arguments passed to parent class
"""

super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.max_position_embeddings = max_position_embeddings
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.use_flash_attention = use_flash_attention
self.pad_token_id = pad_token_id
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
self.head_dim = head_dim
self.hidden_act = hidden_act
self.sliding_window = None
self.use_bias = use_bias
self.weight_share_add_bias = weight_share_add_bias
self.rope_theta = rope_theta
self.ignored_index = ignored_index
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.hidden_dropout_prob = hidden_dropout_prob
self.compression_ratio = compression_ratio
self.num_key_value_heads = num_key_value_heads
self.max_sequence_length = max_sequence_length
# Try to set `rope_scaling` if available, otherwise use `rope_parameters`
rope_scaling = kwargs.pop("rope_scaling", None)
self.rope_parameters = rope_scaling or rope_parameters
if self.rope_parameters is not None and self.rope_parameters["rope_type"] == "mrope":
self.rope_parameters["rope_type"] = "default"
rope_config_validation(self, ignore_keys={"mrope_section"})


class PaddleOCRVLConfig(PreTrainedConfig):
r"""
This is the configuration class to store the configuration of a [`PaddleOCRVLModel`]. It is used to instantiate a
PaddleOCR-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of
PaddleOCR-VL [PaddlePaddle/PaddleOCR-VL](https://huggingface.co/PaddlePaddle/PaddleOCR-VL).

Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PreTrainedConfig`] for more information.


Args:
text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `PaddleOCRVLTextConfig`):
The config object or dictionary of the text backbone.
vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `PaddleOCRVLVisionConfig`):
The config object or dictionary of the vision backbone.
image_token_id (`int`, *optional*, defaults to 151655):
The image token index to encode the image prompt.
video_token_id (`int`, *optional*, defaults to 151656):
The video token index to encode the video prompt.
vision_start_token_id (`int`, *optional*, defaults to 151652):
The token index to denote start of vision input.
vision_end_token_id (`int`, *optional*, defaults to 151653):
The token index to denote end of vision input.

```python
>>> from transformers import PaddleOCRVLForConditionalGeneration, PaddleOCRVLConfig

>>> # Initializing a PaddleOCRVL style configuration
>>> configuration = PaddleOCRVLConfig()

>>> # Initializing a model from the PaddleOCRVL style configuration
>>> model = PaddleOCRVLForConditionalGeneration(configuration)

>>> # Accessing the model configuration
>>> configuration = model.config
```"""

model_type = "paddleocr_vl"
sub_configs = {"vision_config": PaddleOCRVLVisionConfig, "text_config": PaddleOCRVLTextConfig}
keys_to_ignore_at_inference = ["past_key_values"]

def __init__(
self,
text_config=None,
vision_config=None,
image_token_id=151655,
video_token_id=151656,
vision_start_token_id=151652,
vision_end_token_id=151653,
**kwargs,
):
# We need to call super().__init__() here so that it does not reset values
# that live in the text config to the base-class defaults. The base config
# has many text-related defaults, and not all of them match `PaddleOCRVLTextConfig`.
super().__init__(**kwargs)

if isinstance(vision_config, dict):
self.vision_config = self.sub_configs["vision_config"](**vision_config)
elif vision_config is None:
self.vision_config = self.sub_configs["vision_config"]()
else:
self.vision_config = vision_config

if isinstance(text_config, dict):
self.text_config = self.sub_configs["text_config"](**text_config)
elif text_config is None:
# For BC use all kwargs to init `TextConfig`
self.text_config = self.sub_configs["text_config"](**kwargs)
else:
self.text_config = text_config

self.image_token_id = image_token_id
self.video_token_id = video_token_id
self.vision_start_token_id = vision_start_token_id
self.vision_end_token_id = vision_end_token_id

# Attention implementation to use. It sets it recursively on sub-configs so we call it again in the end
self._attn_implementation = kwargs.pop("attn_implementation", None)

def __setattr__(self, key, value):
if (
(text_config := super().__getattribute__("__dict__").get("text_config")) is not None
and key not in ["_name_or_path", "model_type", "dtype", "_attn_implementation_internal"]
and key in text_config.__dict__
):
setattr(text_config, key, value)
else:
super().__setattr__(key, value)

def __getattribute__(self, key):
if "text_config" in super().__getattribute__("__dict__") and key not in [
"_name_or_path",
"model_type",
"dtype",
"_attn_implementation_internal",
]:
text_config = super().__getattribute__("text_config")
if key in text_config.__dict__:
return getattr(text_config, key)

return super().__getattribute__(key)


__all__ = ["PaddleOCRVLConfig", "PaddleOCRVLTextConfig"]
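A hedged sketch of how the composite config behaves, based on the `__init__` and the `__setattr__`/`__getattribute__` overrides above: sub-configs are built from dicts (or defaults when `None`), and attributes that live on the text config are transparently delegated. Assumes a transformers build that includes this PR:

```python
from transformers import PaddleOCRVLConfig

config = PaddleOCRVLConfig(text_config={"hidden_size": 1024}, vision_config=None)

# Sub-configs are materialized from dicts, or from defaults when None.
print(type(config.text_config).__name__)    # PaddleOCRVLTextConfig
print(type(config.vision_config).__name__)  # PaddleOCRVLVisionConfig

# __getattribute__ delegation: text-config fields read through the parent ...
print(config.hidden_size)  # 1024

# ... and __setattr__ delegation writes them back onto text_config.
config.hidden_size = 2048
print(config.text_config.hidden_size)  # 2048
```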