@zhang-prog commented Nov 13, 2025

What does this PR do?

This PR adds the PaddleOCR-VL model from PaddleOCR to Hugging Face Transformers.

Relevant Links:

PaddleOCR
https://huggingface.co/PaddlePaddle/PaddleOCR-VL

Usage

Use a pipeline

from transformers import pipeline

pipe = pipeline("image-text-to-text", model="PaddlePaddle/PaddleOCR-VL", dtype="bfloat16")
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/ocr_demo2.jpg"},
            {"type": "text", "text": "OCR:"},
        ]
    }
]
result = pipe(text=messages)
print(result)

Load model directly

from transformers import AutoProcessor, AutoModelForImageTextToText

processor = AutoProcessor.from_pretrained("PaddlePaddle/PaddleOCR-VL")
model = AutoModelForImageTextToText.from_pretrained("PaddlePaddle/PaddleOCR-VL", dtype="bfloat16")
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/ocr_demo2.jpg"},
            {"type": "text", "text": "OCR:"},
        ]
    }
]
inputs = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
).to(model.device)

outputs = model.generate(**inputs, max_new_tokens=100)
# Keep only the newly generated tokens and drop special tokens such as EOS
result = processor.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
print(result)

@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto

@zucchini-nlp zucchini-nlp self-requested a review November 13, 2025 09:07
Member

@zucchini-nlp zucchini-nlp left a comment

Hey @zhang-prog, thanks for the PR! Great model to have in Transformers!

The main thing to fix first is the naming: it should clearly include "PaddlePaddleOCR" and follow the usual pattern for the modality. The config format also isn't right; it needs to be fully nested, with the text and vision configs inside. Additionally, there are no tests or docs, so several files are missing. You can run transformers add-new-model-like, which generates placeholders for the necessary files. I also left some smaller comments here and there. Let me know if you hit any issues.

Comment on lines +91 to +98
if height < factor:
    width = round((width * factor) / height)
    height = factor

if width < factor:
    height = round((height * factor) / width)
    width = factor

Member

Same as Qwen but with support for H/W smaller than the factor. I think we made Qwen-VL support small images as well, so importing it directly would probably give the expected result?

Author

Unlike Qwen-VL's smart_resize method, we have differences in factor and max_pixels that need to be preserved. Forcing the use of Qwen-VL's smart_resize method could result in decreased accuracy.

Member

You mean the default values of min-max pixels? We're passing the correct value from within image processor when calling smart_resize, so it shouldn't be a problem, no?

Author

Our method is not identical to Qwen's, so it cannot be reused directly. The differences are as follows:

+   if height < factor:
+       width = round((width * factor) / height)
+       height = factor

+   if width < factor:
+       height = round((height * factor) / width)
+       width = factor

    if max(height, width) / min(height, width) > 200:
        raise ValueError(
            f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
        )
    h_bar = round(height / factor) * factor
    w_bar = round(width / factor) * factor
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
-       h_bar = math.floor(height / beta / factor) * factor
-       w_bar = math.floor(width / beta / factor) * factor
+       h_bar = max(factor, math.floor(height / beta / factor) * factor)
+       w_bar = max(factor, math.floor(width / beta / factor) * factor)
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor
    return h_bar, w_bar
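For reference, a self-contained sketch of the complete resize function implied by the diff above: Qwen2-VL's smart_resize plus the two PaddleOCR-VL changes (upscaling a side shorter than factor, and clamping downscaled sides to at least factor). It is reconstructed only from the diff; the keyword-only signature is an assumption, and the real defaults for factor, min_pixels and max_pixels live in the PaddleOCR-VL image processor, so none are hard-coded here and the values in the calls are illustrative.

```python
import math


def smart_resize(height: int, width: int, *, factor: int, min_pixels: int, max_pixels: int) -> tuple[int, int]:
    """Return (h_bar, w_bar): both multiples of `factor`, aspect ratio roughly kept,
    area pulled into [min_pixels, max_pixels] unless the `factor` clamp kicks in."""
    # PaddleOCR-VL addition: upscale a side shorter than `factor`, keeping the aspect ratio.
    if height < factor:
        width = round((width * factor) / height)
        height = factor
    if width < factor:
        height = round((height * factor) / width)
        width = factor

    if max(height, width) / min(height, width) > 200:
        raise ValueError(
            f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
        )

    # Snap both sides to the nearest multiple of `factor`.
    h_bar = round(height / factor) * factor
    w_bar = round(width / factor) * factor
    if h_bar * w_bar > max_pixels:
        # Too many pixels: scale down, flooring to a multiple of `factor`.
        beta = math.sqrt((height * width) / max_pixels)
        # PaddleOCR-VL addition: never let a side collapse to 0.
        h_bar = max(factor, math.floor(height / beta / factor) * factor)
        w_bar = max(factor, math.floor(width / beta / factor) * factor)
    elif h_bar * w_bar < min_pixels:
        # Too few pixels: scale up, rounding up to a multiple of `factor`.
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor
    return h_bar, w_bar


# Illustrative values only; the real PaddleOCR-VL defaults may differ.
print(smart_resize(10, 100, factor=28, min_pixels=28 * 28 * 4, max_pixels=28 * 28 * 1280))  # (28, 280)
print(smart_resize(28, 4200, factor=28, min_pixels=28 * 28 * 4, max_pixels=28 * 28 * 10))  # (28, 1064)
```

Note the second call: once a side is clamped up to factor, the result can exceed max_pixels (28 x 1064 > 7840). That is the trade-off the max(factor, ...) clamp makes to avoid a zero-sized side, and it is one concrete way this function behaves differently from Qwen's.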

tokenizer_class = "AutoTokenizer"

def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
    self.image_token = "<|IMAGE_PLACEHOLDER|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token
Member

Can you add the token to the tokenizer so we can assume it's always available?
https://huggingface.co/docs/transformers/en/main_classes/tokenizer#multimodal-tokenizer

Member

Comment not addressed. After adding it to the tokenizer, we can assume it exists and get tokenizer.image_token directly.

Author

I tried adding image_token to the special_tokens_map.json, but found it had no effect. What should I do? Do you have an example?

Member

Can you try following the link above? It should work unless something broke in the latest refactor.

@zhang-prog
Author

@zucchini-nlp
We have refactored the code to address the issues you mentioned in your comments.
Please review the code again when you have time.
Thank you for your efforts!!!

Member

@zucchini-nlp zucchini-nlp left a comment

@zhang-prog thanks for iterating!

There are a couple of major comments that were not addressed and may block merging.

  1. The model does not seem to support batched inference in its current state. We need to enable batching before merging, if possible. It should not be hard, I think, given that the image tower is quite similar to existing models.
  2. We also need tests to make sure everything actually works, plus a documentation page. These files are usually prefilled as empty placeholders when you run transformers add-new-model-like.
  3. Let modular copy code automatically when possible. I think a few more modules can be copied from similar models. If you struggle to find a similar model, you can try out a modular detector.

image_embeds = self.get_image_features(pixel_values, image_grid_thw)
image_embeds = self.projector(image_embeds, image_grid_thw)
image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds)
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
Member

Cast it with image_embeds.to(device=inputs_embeds.device, dtype=inputs_embeds.dtype) so that multi-GPU setups work without problems.

@zhang-prog
Author

zhang-prog commented Nov 26, 2025

@zucchini-nlp Thank you for your valuable insights! We’ve carefully addressed all comments and responded to your overall recommendations.

  1. We support batched inference (batch size > 1), for example:
from transformers import AutoProcessor, AutoModelForImageTextToText

processor = AutoProcessor.from_pretrained("PaddlePaddle/PaddleOCR-VL")
model = AutoModelForImageTextToText.from_pretrained("PaddlePaddle/PaddleOCR-VL", dtype="bfloat16")
messages1 = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/ocr_demo.jpg"},
            {"type": "text", "text": "OCR:"},
        ]
    }
]
messages2 = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/ocr_demo2.jpg"},
            {"type": "text", "text": "OCR:"},
        ]
    }
]
batch_messages = [messages1, messages2]
inputs = processor.apply_chat_template(
    batch_messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    padding=True,
    padding_side="left",
).to(model.device)

generated_ids = model.generate(**inputs, max_new_tokens=1024)
generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
result = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)
print(result)
  2. We still have some open issues to discuss. I replied to your comment and will write the final version of the documentation once they are resolved.

  3. We also added PaddleOCRVisionConfig and PaddleOCRTextConfig to the modular file.
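A note on why the batched example pads on the left: with padding_side="left", every prompt ends at the same (last) column of input_ids, so generate() appends new tokens directly after each real prompt, and slicing each output row at len(in_ids) leaves exactly the generated tokens. The sketch below simulates this with plain Python lists; the token ids and the pad id 0 are made up, and it does not call Transformers.

```python
PAD = 0  # hypothetical pad token id


def left_pad(batch, pad_id=PAD):
    """Pad every sequence on the left to the length of the longest one."""
    max_len = max(len(seq) for seq in batch)
    return [[pad_id] * (max_len - len(seq)) + seq for seq in batch]


def trim_prompts(input_ids, generated_ids):
    """Same idea as the example's list comprehension: drop the prompt prefix of each row."""
    return [out_ids[len(in_ids):] for in_ids, out_ids in zip(input_ids, generated_ids)]


prompts = [[11, 12, 13], [21, 22, 23, 24, 25]]
input_ids = left_pad(prompts)  # [[0, 0, 11, 12, 13], [21, 22, 23, 24, 25]]
# Simulated generate(): each row is the padded prompt followed by its new tokens.
new_tokens = [[101, 102], [201, 202]]
generated_ids = [row + new for row, new in zip(input_ids, new_tokens)]
print(trim_prompts(input_ids, generated_ids))  # [[101, 102], [201, 202]]
```

With right padding, the shorter prompt would be followed by pad tokens, and a decoder-only model would continue generating after those pads instead of after the real prompt.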

Thank you for your efforts. ❤️
PTAL.

@zhang-prog
Author

@zucchini-nlp How do I properly add documentation pages and unit tests? I tried to use transformers add-new-model-like, which generates the new modular_xxx.py files, but this process might not be the right approach.

@zhang-prog
Author

@zucchini-nlp PTAL. Thanks❤️

@zucchini-nlp
Member

Sorry, taking a look. It got lost in my notifications.

Member

@zucchini-nlp zucchini-nlp left a comment

Nice, only a few comments left. I also replied to your questions above.

For the docs and the tests: the docs need to go in docs/source/en/model_doc and the tests in the tests folder. You can take a look at a recently merged model for an example: https://github.com/huggingface/transformers/pull/41112/files#diff-857421affc3c877bca95377cbb6adb3a8374b149fcbdcc6c759ea0408fa96897



logger = logging.get_logger(__name__)
rope_config_validation = partial(_rope_config_validation, ignore_keys={"mrope_section"})
Member

We had quite a few changes to RoPE recently. Rebasing on main will let you get rid of this: the keys to ignore can now be passed to super(), and RoPE validation is called in the base class :)

Comment on lines +222 to +223
images = make_list_of_images(images)

Member

Please add images = self.fetch_images(images) below; otherwise chat templates will not work with slow processors.

Comment on lines +286 to +291
self.temporal_patch_size,
channel,
grid_h,
self.patch_size,
grid_w,
self.patch_size,
Member

Ideally we should use the args passed by users, i.e. patch_size/temporal_patch_size, not the self attributes. That way the processor stays configurable at call time.

self.patch_size,
)
patches = patches.transpose(0, 3, 5, 2, 1, 4, 6)
assert self.temporal_patch_size == 1
Member

And here as well: instead of the assert, use if temporal_patch_size != 1: raise ValueError(...) with an informative message saying that no value other than 1 should be passed.
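A sketch of what these two comments ask for: patch_size and temporal_patch_size arrive as call-time arguments instead of being read from self, and the assert becomes a ValueError. The reshape/transpose follows the snippet quoted above; the function name patchify, the (frames, channels, height, width) input layout and the final flatten to (num_patches, channels, patch_size, patch_size) are assumptions for illustration, written in NumPy rather than taken from the PR.

```python
import numpy as np


def patchify(patches: np.ndarray, patch_size: int, temporal_patch_size: int = 1):
    """Split a (frames, channels, height, width) array into flattened patches.

    `patch_size` and `temporal_patch_size` come from the caller, so they can be
    overridden at call time instead of being read from processor attributes.
    """
    if temporal_patch_size != 1:
        raise ValueError(
            f"temporal_patch_size must be 1 for this image processor, got {temporal_patch_size}. "
            "Please do not pass any other value."
        )
    frames, channel, height, width = patches.shape
    if height % patch_size or width % patch_size:
        raise ValueError(f"height and width must be multiples of patch_size={patch_size}, got {height}x{width}")
    grid_t = frames // temporal_patch_size
    grid_h, grid_w = height // patch_size, width // patch_size

    patches = patches.reshape(grid_t, temporal_patch_size, channel, grid_h, patch_size, grid_w, patch_size)
    # -> (grid_t, grid_h, grid_w, channel, temporal_patch_size, patch_size, patch_size)
    patches = patches.transpose(0, 3, 5, 2, 1, 4, 6)
    # temporal_patch_size == 1, so that axis can be folded away.
    flat = patches.reshape(grid_t * grid_h * grid_w, channel, patch_size, patch_size)
    return flat, (grid_t, grid_h, grid_w)


image = np.arange(1 * 3 * 4 * 6).reshape(1, 3, 4, 6)
flat, grid = patchify(image, patch_size=2)
print(flat.shape, grid)  # (6, 3, 2, 2) (1, 2, 3)
```

Patches come out in row-major grid order, so flat[1] is the top-row patch covering columns 2 to 3.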

Comment on lines +399 to +404
self.temporal_patch_size,
channel,
grid_h,
self.patch_size,
grid_w,
self.patch_size,
Member

Same comment here about the self attributes.

Comment on lines +551 to +552
self.text_config = config.text_config
self.vision_config = config.vision_config
Member

We don't use them outside of __init__, so there is no need to assign them as attributes. We can use them directly to get the layer dimensions below.

self.vision_config = config.vision_config
self.merge_kernel_size = (self.vision_config.spatial_merge_size, self.vision_config.spatial_merge_size)

self.hidden_size = self.vision_config.hidden_size * self.merge_kernel_size[0] * self.merge_kernel_size[1]
Member

nit: same here. If it is not used in the forward call, we don't need to assign it as a self attribute.
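The hidden_size line above comes from merging each spatial_merge_size x spatial_merge_size window of vision patches into one token by concatenating their features, so the projector input width is vision hidden size times merge**2. Below is a NumPy sketch of that merge that computes the width as a local, in the spirit of the review comment. The function name, the row-major (t, h, w) patch order and the (p1, p2, feature) concatenation order are assumptions about how the PaddleOCR-VL projector rearranges features, not confirmed from its code.

```python
import numpy as np


def merge_patches(features: np.ndarray, grid_thw: tuple, merge_size: int) -> np.ndarray:
    """Concatenate each merge_size x merge_size window of patch features into one token.

    `features` has shape (t * h * w, hidden) with patches in row-major (t, h, w) order.
    Returns shape (t * (h // merge_size) * (w // merge_size), hidden * merge_size**2).
    """
    t, h, w = grid_thw
    hidden = features.shape[-1]
    # Local variable, not a self attribute: it is only needed to size the layers.
    merged_hidden = hidden * merge_size * merge_size
    x = features.reshape(t, h // merge_size, merge_size, w // merge_size, merge_size, hidden)
    # -> (t, h // m, w // m, m, m, hidden): bring the patches of each window together.
    x = x.transpose(0, 1, 3, 2, 4, 5)
    return x.reshape(-1, merged_hidden)


feats = np.arange(16 * 3).reshape(16, 3)  # t=1, h=4, w=4 patches, hidden=3
tokens = merge_patches(feats, (1, 4, 4), merge_size=2)
print(tokens.shape)  # (4, 12)
```

Token 0 is the concatenation of patches 0, 1, 4 and 5, i.e. the top-left 2x2 window of the 4x4 grid.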



class PaddleOCRVLModel(Qwen2VLModel):
    _checkpoint_conversion_mapping = {"^model": "language_model"}
Member

Do we need this, or is the model state dict already converted to the correct format?

Comment on lines +977 to +983
# calculate RoPE index once per generation in the pre-fill stage only
if (
    (cache_position is not None and cache_position[0] == 0)
    or self.rope_deltas is None
    or (past_key_values is None or past_key_values.get_seq_length() == 0)
):
    position_ids, rope_deltas = self.get_rope_index(
Member

Qwen-VL recently updated how it builds position ids due to expectations from PEFT; this needs a rebase and a matching update in PaddleOCR.
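For context on what get_rope_index computes and why it only runs at prefill: in Qwen2-VL-style multimodal RoPE, text tokens get the same position on all three axes, image tokens get (temporal, height, width) indices offset from the preceding text, and rope_deltas records how far the final position lags behind the sequence length, so decode steps can use cache_position + rope_deltas without recomputing. The NumPy sketch below is a simplified single-sample, single-image version modeled on Qwen2-VL's get_rope_index; the function name and signature are hypothetical, grid_h/grid_w are assumed to be already divided by the spatial merge size, and it does not reflect the PEFT-related update mentioned in the review.

```python
import numpy as np


def mrope_positions(n_text_before: int, llm_grid: tuple, n_text_after: int):
    """Return (position_ids of shape (3, seq_len), rope_delta) for text + one image + text."""
    grid_t, grid_h, grid_w = llm_grid
    # Text before the image: identical positions on the t/h/w axes.
    before = np.tile(np.arange(n_text_before), (3, 1))
    start = n_text_before
    # Image tokens: (t, h, w) coordinates, offset by where the image starts.
    t_idx, h_idx, w_idx = np.meshgrid(
        np.arange(grid_t), np.arange(grid_h), np.arange(grid_w), indexing="ij"
    )
    image = np.stack([t_idx.ravel(), h_idx.ravel(), w_idx.ravel()]) + start
    # Text after the image resumes one past the largest image position.
    resume = image.max() + 1
    after = np.tile(np.arange(n_text_after) + resume, (3, 1))
    position_ids = np.concatenate([before, image, after], axis=1)
    seq_len = position_ids.shape[1]
    # How far positions lag behind plain token indices; reused at every decode step.
    rope_delta = int(position_ids.max()) + 1 - seq_len
    return position_ids, rope_delta


pos, delta = mrope_positions(2, (1, 2, 2), 1)
print(pos)
print(delta)  # -2
# Decode step: the next token has cache_position 7 (= seq_len), so its position is 7 + delta = 5.
```

Because rope_delta is fixed once the prompt (and its image grid) is known, the expensive 3D index computation only needs to happen at prefill, which is what the condition in the quoted snippet checks.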
