
Conversation


@zhang-prog zhang-prog commented Nov 13, 2025

What does this PR do?

This PR adds the PaddleOCR-VL model from PaddleOCR to Hugging Face Transformers.

Relevant Links:

PaddleOCR
https://huggingface.co/PaddlePaddle/PaddleOCR-VL

Usage

Use a pipeline

from transformers import pipeline

pipe = pipeline("image-text-to-text", model="PaddlePaddle/PaddleOCR-VL", dtype="bfloat16")
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/ocr_demo2.jpg"},
            {"type": "text", "text": "OCR:"},
        ]
    }
]
result = pipe(text=messages)
print(result)

Load model directly

from transformers import AutoProcessor, AutoModelForImageTextToText

processor = AutoProcessor.from_pretrained("PaddlePaddle/PaddleOCR-VL")
model = AutoModelForImageTextToText.from_pretrained("PaddlePaddle/PaddleOCR-VL", dtype="bfloat16")
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/ocr_demo2.jpg"},
            {"type": "text", "text": "OCR:"},
        ]
    }
]
inputs = processor.apply_chat_template(
	messages,
	add_generation_prompt=True,
	tokenize=True,
	return_dict=True,
	return_tensors="pt",
).to(model.device)

outputs = model.generate(**inputs, max_new_tokens=100)
result = processor.decode(outputs[0][inputs["input_ids"].shape[-1]:-1])
print(result)

@zucchini-nlp zucchini-nlp self-requested a review November 13, 2025 09:07
Member

@zucchini-nlp zucchini-nlp left a comment


Hey @zhang-prog, thanks for the PR! Great model to have in transformers!

The main thing to fix first is the naming: it should clearly include "PaddlePaddleOCR" and follow the usual pattern depending on the modality. The config format also isn't right; it needs to be fully nested, with the text and vision configs inside. Additionally, there are no tests or docs, and several files are missing. You can run transformers add-new-model-like, which will generate a placeholder with the necessary files. I also left some smaller comments here and there. Let me know if you hit any issues.
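
For reference, a fully nested config looks roughly like this (a sketch only; the class names, attributes, and defaults below are illustrative, not the final PaddleOCR-VL API):

from transformers import PretrainedConfig

# Illustrative sketch of a nested VLM config; names and defaults are placeholders.
class PaddleOCRVisionConfig(PretrainedConfig):
    model_type = "paddleocr_vl_vision"

    def __init__(self, hidden_size=1024, patch_size=14, **kwargs):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.patch_size = patch_size

class PaddleOCRTextConfig(PretrainedConfig):
    model_type = "paddleocr_vl_text"

    def __init__(self, hidden_size=2048, vocab_size=100000, **kwargs):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size

class PaddleOCRVLConfig(PretrainedConfig):
    model_type = "paddleocr_vl"
    sub_configs = {"vision_config": PaddleOCRVisionConfig, "text_config": PaddleOCRTextConfig}

    def __init__(self, vision_config=None, text_config=None, **kwargs):
        super().__init__(**kwargs)
        # Accept plain dicts (as loaded from config.json) or ready-made config objects.
        if vision_config is None or isinstance(vision_config, dict):
            vision_config = PaddleOCRVisionConfig(**(vision_config or {}))
        if text_config is None or isinstance(text_config, dict):
            text_config = PaddleOCRTextConfig(**(text_config or {}))
        self.vision_config = vision_config
        self.text_config = text_config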

Comment on lines +91 to +98
if height < factor:
    width = round((width * factor) / height)
    height = factor

if width < factor:
    height = round((height * factor) / width)
    width = factor

Member


Same as Qwen, but with support for H/W smaller than a factor. I think we made Qwen-VL support small images as well, so probably importing it directly will give the expected result?

Author


Unlike Qwen-VL's smart_resize method, ours differs in how factor and max_pixels are handled, and those differences need to be preserved. Forcing the use of Qwen-VL's smart_resize method could result in decreased accuracy.

Member


You mean the default values of min-max pixels? We're passing the correct value from within image processor when calling smart_resize, so it shouldn't be a problem, no?

Author


Our method is not entirely identical to Qwen's, so it cannot be reused directly. The differences are as follows:

+   if height < factor:
+       width = round((width * factor) / height)
+       height = factor

+   if width < factor:
+       height = round((height * factor) / width)
+       width = factor

    if max(height, width) / min(height, width) > 200:
        raise ValueError(
            f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
        )
    h_bar = round(height / factor) * factor
    w_bar = round(width / factor) * factor
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
-       h_bar = math.floor(height / beta / factor) * factor
-       w_bar = math.floor(width / beta / factor) * factor
+       h_bar = max(factor, math.floor(height / beta / factor) * factor)
+       w_bar = max(factor, math.floor(width / beta / factor) * factor)
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor
    return h_bar, w_bar
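
For clarity, applying this diff gives roughly the following function (a sketch; the signature follows the Qwen-VL-style smart_resize, and the actual factor / min_pixels / max_pixels values come from our image processor):

import math

def smart_resize(height, width, factor, min_pixels, max_pixels):
    # Unlike Qwen-VL, images with a side smaller than `factor` are first upscaled
    # so that both sides are at least one patch large.
    if height < factor:
        width = round((width * factor) / height)
        height = factor
    if width < factor:
        height = round((height * factor) / width)
        width = factor
    if max(height, width) / min(height, width) > 200:
        raise ValueError(
            f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
        )
    h_bar = round(height / factor) * factor
    w_bar = round(width / factor) * factor
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        # Clamp to `factor` so downscaling never produces a zero-sized side.
        h_bar = max(factor, math.floor(height / beta / factor) * factor)
        w_bar = max(factor, math.floor(width / beta / factor) * factor)
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor
    return h_bar, w_bar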

Author

@zhang-prog zhang-prog Dec 5, 2025


The diff is shown above; we need to use our own method :)

tokenizer_class = "AutoTokenizer"

def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
    self.image_token = "<|IMAGE_PLACEHOLDER|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token
Member


Can you add the token in the tokenizer so we can assume it's always available?
https://huggingface.co/docs/transformers/en/main_classes/tokenizer#multimodal-tokenizer

Member


Comment not addressed. After adding it in the tokenizer, we can assume it exists and get tokenizer.image_token directly.

Author


I tried adding image_token to the special_tokens_map.json, but found it had no effect. What should I do? Do you have an example?

Member


Can you try with the above link? It should work unless something got broken in the latest refactor.

Author


Sure. We added image_token to tokenizer_config.json to solve this problem.
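
For reference, the same effect can also be achieved programmatically (a sketch, assuming the extra_special_tokens mechanism described in the multimodal tokenizer docs linked above; saving persists the mapping into tokenizer_config.json):

from transformers import AutoTokenizer

# Register the image placeholder as an extra special token so that
# tokenizer.image_token is always available afterwards.
tokenizer = AutoTokenizer.from_pretrained(
    "PaddlePaddle/PaddleOCR-VL",
    extra_special_tokens={"image_token": "<|IMAGE_PLACEHOLDER|>"},
)
tokenizer.save_pretrained("./paddleocr_vl_tokenizer")  # writes the mapping into tokenizer_config.json
print(tokenizer.image_token)  # "<|IMAGE_PLACEHOLDER|>"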

@zhang-prog
Author

@zucchini-nlp
We have refactored the code to address the issues you mentioned in your comments.
Please review the code again when you have time.
Thank you for your efforts!!!

Member

@zucchini-nlp zucchini-nlp left a comment


@zhang-prog thanks for iterating!

There are a couple of major comments which were not addressed and can be a blocker for merging.

  1. The model seems to not support batched inference in its current state. We need to enable batching before merging if possible. It should not be hard, I think, given that the image tower is quite similar to existing models.
  2. We also need tests to make sure everything actually works, and a documentation page. These files are usually auto-prefilled with empty files when you run transformers add-new-model-like.
  3. Let the modular copy automatically when possible. I think there are a few more modules which can be copied from similar models. If you struggle with finding a similar model, you can try out a modular detector.

image_embeds = self.get_image_features(pixel_values, image_grid_thw)
image_embeds = self.projector(image_embeds, image_grid_thw)
image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds)
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
Member


image_embeds.to(device=inputs_embeds.device, dtype=inputs_embeds.dtype) so we can support multi-GPU without problems.

Author


Done

Author

zhang-prog commented Nov 26, 2025

@zhang-prog thanks for iterating!

There are a couple of major comments which were not addressed and can be a blocker for merging.

  1. The model seems to not support batched inference in its current state. We need to enable batching before merging if possible. It should not be hard, I think, given that the image tower is quite similar to existing models.
  2. We also need tests to make sure everything actually works, and a documentation page. These files are usually auto-prefilled with empty files when you run transformers add-new-model-like.
  3. Let the modular copy automatically when possible. I think there are a few more modules which can be copied from similar models. If you struggle with finding a similar model, you can try out a modular detector.

@zucchini-nlp Thank you for your valuable insights! We’ve carefully addressed all comments and responded to your overall recommendations.

  1. We now support batched inference (bs > 1), like this:
from transformers import AutoProcessor, AutoModelForImageTextToText

processor = AutoProcessor.from_pretrained("PaddlePaddle/PaddleOCR-VL")
model = AutoModelForImageTextToText.from_pretrained("PaddlePaddle/PaddleOCR-VL", dtype="bfloat16")
messages1 = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/ocr_demo.jpg"},
            {"type": "text", "text": "OCR:"},
        ]
    }
]
messages2 = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/ocr_demo2.jpg"},
            {"type": "text", "text": "OCR:"},
        ]
    }
]
batch_messages = [messages1, messages2]
inputs = processor.apply_chat_template(
    batch_messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    padding=True,
    padding_side="left",
).to(model.device)

generated_ids = model.generate(**inputs, max_new_tokens=1024)
generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
result = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)
print(result)
  2. We still have some issues to discuss. I replied to your comment and will produce the final version of the documentation once it's completed.

  3. We also added PaddleOCRVisionConfig and PaddleOCRTextConfig to the modular file.

Thank you for your efforts. ❤️
PTAL.

@zhang-prog
Author

@zucchini-nlp How do I properly add documentation pages and unit tests? I tried to use transformers add-new-model-like, which generates the new modular_xxx.py files, but this process might not be the right approach.

@zhang-prog
Author

@zucchini-nlp PTAL. Thanks ❤️

@zucchini-nlp
Member

Sorry, taking a look. Got lost in my notifications.

Member

@zucchini-nlp zucchini-nlp left a comment


Nice, only a few comments, and I replied to your questions above.

For the docs and the tests, they need to be in docs/source/en/model_doc and in the tests folder. You can take a look at a recently merged model for an example: https://github.com/huggingface/transformers/pull/41112/files#diff-857421affc3c877bca95377cbb6adb3a8374b149fcbdcc6c759ea0408fa96897



logger = logging.get_logger(__name__)
rope_config_validation = partial(_rope_config_validation, ignore_keys={"mrope_section"})
Member


We had quite a few changes to RoPE recently. Rebasing on main will help to get rid of this. Ignore keys can now be passed to super() and rope validation is called in the base class :)

Author


Nice work!

But right now, we can only pass ignore_keys_at_rope_validation via kwargs and can’t specify it directly:

kwargs["ignore_keys_at_rope_validation"] = {"mrope_section"}
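
In context, that workaround looks roughly like this (a sketch with an illustrative config class; it assumes the base class consumes ignore_keys_at_rope_validation from kwargs, as described above):

from transformers import PretrainedConfig

# Illustrative class name; only the kwargs workaround is the point here.
class ExampleTextConfig(PretrainedConfig):
    model_type = "example_text"

    def __init__(self, rope_scaling=None, **kwargs):
        self.rope_scaling = rope_scaling
        # There is no explicit parameter for it, so it has to be smuggled in via kwargs
        # before calling super().__init__(), which (on current main) runs rope validation.
        kwargs["ignore_keys_at_rope_validation"] = {"mrope_section"}
        super().__init__(**kwargs)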

Maybe setting ignore_keys_at_rope_validation as a default parameter would make it more flexible?

Member


You mean as a class attribute or similar? That could be an option, yes; my initial idea was to not multiply the number of class attributes, although ignore_keys are static and could be defined once. We're still adjusting stuff and the current way is an rc0.

For this model, we can follow the API on the main branch, and I will try to find a better way to pass ignore_keys for subsequent releases.

Author


OK. I will remove kwargs["ignore_keys_at_rope_validation"] = {"mrope_section"} for now; otherwise, make style will not pass.

I will adjust this part of the code according to the main branch (your subsequent modifications).



class PaddleOCRVLModel(Qwen2VLModel):
    _checkpoint_conversion_mapping = {"^model": "language_model"}
Member


Do we need it, or is the model state dict converted to the correct format?

Author

@zhang-prog zhang-prog Dec 4, 2025


@zucchini-nlp
We need to maintain the weight mapping because vllm, sglang, and fastdeploy also need to load the current state dict. If we modify the state dict, we would have to submit PRs to these three repositories to change the weight mapping, which would be quite cumbersome.😢

However, I have noticed that the _checkpoint_conversion_mapping attribute was removed a week ago.

Are there any other methods to achieve this besides passing the key_mapping during initialization? We want this process to be invisible to users.

Author


The modification for _checkpoint_conversion_mapping is here: #42396

Now, we need to pass in key_mapping to ensure correct inference, which is a bit complicated. Is there a better way to do this?

example:

from transformers import AutoProcessor, AutoModelForImageTextToText

+ key_mapping = {
+    "^visual": "model.visual",
+    "^mlp_AR": "model.projector",
+    r"^model(?!(\.visual|\.projector))": "model.language_model",
+ }

processor = AutoProcessor.from_pretrained("PaddlePaddle/PaddleOCR-VL")
- model = AutoModelForImageTextToText.from_pretrained("PaddlePaddle/PaddleOCR-VL", dtype="bfloat16")
+ model = AutoModelForImageTextToText.from_pretrained("PaddlePaddle/PaddleOCR-VL", dtype="bfloat16", key_mapping=key_mapping)
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/ocr_demo2.jpg"},
            {"type": "text", "text": "OCR:"},
        ]
    }
]
inputs = processor.apply_chat_template(
	messages,
	add_generation_prompt=True,
	tokenize=True,
	return_dict=True,
	return_tensors="pt",
).to(model.device)

outputs = model.generate(**inputs, max_new_tokens=100)
result = processor.decode(outputs[0][inputs["input_ids"].shape[-1]:-1])
print(result)

Member


I see, no problem with keeping the state dict as is. I was mostly curious whether it was a typo or intended :)

Author


I understand:
add paddleocrvl to VLMS in transformers/src/transformers/conversion_mapping.py to use _checkpoint_conversion_mapping

@zhang-prog
Author

@zucchini-nlp

PTAL.❤️

_checkpoint_conversion_mapping and ignore_keys_at_rope_validation need to be discussed.

I am working on the docs and tests.

assert self.temporal_patch_size == 1
flatten_patches = patches.reshape(grid_t * grid_h * grid_w, channel, self.patch_size, self.patch_size)
if temporal_patch_size != 1:
    raise ValueError("temporal_patch_size must be 1!")
Member


super nit: I'd frame it as "temporal_patch_size must be 1, but got {temporal_patch_size}!" to give more information

Author


Done

@zucchini-nlp
Member

Great, looking good already. We can keep the conversion mapping as is, no issue for us! There are also a few unresolved comments from the past iterations, if you can take a look.

Ping me when the docs/tests are added and the CI shows ✅

@zhang-prog
Author

zhang-prog commented Dec 5, 2025

Don't merge. Working.....

@github-actions
Contributor

github-actions bot commented Dec 8, 2025

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto
