1515from ...tokenization_utils_base import PreTokenizedInput , TextInput
1616
1717
18- class PixtralProcessorKwargs (ProcessingKwargs , total = False ):
18+ class LightOnOCRProcessorKwargs (ProcessingKwargs , total = False ):
1919 _defaults = {
2020 "text_kwargs" : {
2121 "padding" : False ,
2222 "return_mm_token_type_ids" : False ,
2323 },
24+ "images_kwargs" : {
25+ "patch_size" : None , # Will be set from processor config
26+ },
2427 "common_kwargs" : {
2528 "return_tensors" : "pt" ,
2629 },
@@ -106,6 +109,8 @@ def __init__(
106109 ):
107110 self .patch_size = patch_size
108111 self .spatial_merge_size = spatial_merge_size
112+ # Calculate effective patch size for image processing
113+ self .effective_patch_size = patch_size * spatial_merge_size
109114 self .image_token = image_token
110115 self .image_token_id = tokenizer .convert_tokens_to_ids (self .image_token )
111116 self .image_break_token = image_break_token
@@ -114,6 +119,10 @@ def __init__(
114119 self .image_break_token_id = tokenizer .convert_tokens_to_ids (self .image_break_token )
115120 self .image_end_token_id = tokenizer .convert_tokens_to_ids (self .image_end_token )
116121 self .image_ids = [self .image_token_id , self .image_break_token_id , self .image_end_token_id ]
122+
123+ # Set the default patch_size for images_kwargs
124+ LightOnOCRProcessorKwargs ._defaults ["images_kwargs" ]["patch_size" ] = self .effective_patch_size
125+
117126 super ().__init__ (image_processor , tokenizer , chat_template = chat_template )
118127
119128 def __call__ (
@@ -125,14 +134,12 @@ def __call__(
125134 if images is None and text is None :
126135 raise ValueError ("You must provide either text or images" )
127136 output_kwargs = self ._merge_kwargs (
128- PixtralProcessorKwargs ,
137+ LightOnOCRProcessorKwargs ,
129138 tokenizer_init_kwargs = self .tokenizer .init_kwargs ,
130139 ** kwargs ,
131140 )
132141
133- patch_size = self .patch_size * self .spatial_merge_size
134142 if images is not None :
135- output_kwargs ["images_kwargs" ]["patch_size" ] = patch_size
136143 image_inputs = self .image_processor (images , ** output_kwargs ["images_kwargs" ])
137144 else :
138145 image_inputs = {}
@@ -145,8 +152,8 @@ def __call__(
145152 # Expand image token if image is present
146153 if image_inputs .get ("pixel_values" ) is not None :
147154 height , width = image_inputs ["image_sizes" ][0 ]
148- num_height_tokens = height // patch_size
149- num_width_tokens = width // patch_size
155+ num_height_tokens = height // self . effective_patch_size
156+ num_width_tokens = width // self . effective_patch_size
150157 num_patches = num_height_tokens * num_width_tokens
151158
152159 # Replace single image token with repeated tokens
@@ -182,33 +189,26 @@ def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
182189 """
183190 vision_data = {}
184191 if image_sizes is not None :
185- images_kwargs = PixtralProcessorKwargs ._defaults .get ("images_kwargs" , {})
192+ images_kwargs = LightOnOCRProcessorKwargs ._defaults .get ("images_kwargs" , {})
186193 images_kwargs .update (kwargs )
187194
188195 size = images_kwargs .get ("size" , None ) or self .image_processor .size
189- patch_size = self .patch_size * self .spatial_merge_size
190196
191197 num_image_tokens = []
192198 for height , width in image_sizes :
193199 resized_height , resized_width = get_resize_output_image_size (
194200 np .zeros ((height , width , 3 )),
195201 size = (size ["longest_edge" ], size ["longest_edge" ]),
196- patch_size = (patch_size , patch_size ),
202+ patch_size = (self . effective_patch_size , self . effective_patch_size ),
197203 )
198- num_height_tokens = resized_height // patch_size
199- num_width_tokens = resized_width // patch_size
204+ num_height_tokens = resized_height // self . effective_patch_size
205+ num_width_tokens = resized_width // self . effective_patch_size
200206 num_image_tokens .append ((num_width_tokens + 1 ) * num_height_tokens )
201207
202208 num_image_patches = [1 ] * len (image_sizes )
203209 vision_data .update ({"num_image_tokens" : num_image_tokens , "num_image_patches" : num_image_patches })
204210
205211 return MultiModalData (** vision_data )
206212
207- @property
208- def model_input_names (self ):
209- tokenizer_input_names = self .tokenizer .model_input_names
210- image_processor_input_names = self .image_processor .model_input_names
211- return tokenizer_input_names + image_processor_input_names + ["image_sizes" ]
212-
213213
214214__all__ = ["LightOnOCRProcessor" ]
0 commit comments