Commit 6dd60a9

Reorganize LightOnOCR components to place vision before text and remove debug print
1 parent c7ea243

File tree

5 files changed

+787 -697 lines changed


src/transformers/modeling_utils.py

Lines changed: 0 additions & 1 deletion

@@ -478,7 +478,6 @@ def _get_tied_weight_keys(module: nn.Module, prefix=""):
     for name, submodule in module.named_children():
         local_prefix = f"{prefix}.{name}" if prefix else name
         tied_weight_keys.extend(_get_tied_weight_keys(submodule, prefix=local_prefix))
-    print(f"tied_weight_keys : {tied_weight_keys}")
     return tied_weight_keys
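For context, this hunk only drops the stray debug print; the depth-first traversal it sat in is unchanged. Below is a minimal, self-contained sketch of that traversal pattern. Reading each module's keys from a `_tied_weights_keys` attribute is an assumption made for illustration, not a claim about the exact transformers internals.

```python
import torch.nn as nn


def get_tied_weight_keys(module: nn.Module, prefix: str = "") -> list[str]:
    """Depth-first collection of tied-weight key names, qualified by module path."""
    tied_weight_keys = []
    # Assumed attribute: modules that tie weights declare the key names here.
    for key in getattr(module, "_tied_weights_keys", None) or []:
        tied_weight_keys.append(f"{prefix}.{key}" if prefix else key)
    # Recurse into children, extending the dotted prefix one level per child.
    for name, submodule in module.named_children():
        local_prefix = f"{prefix}.{name}" if prefix else name
        tied_weight_keys.extend(get_tied_weight_keys(submodule, prefix=local_prefix))
    return tied_weight_keys
```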

src/transformers/models/lightonocr/configuration_lightonocr.py

Lines changed: 90 additions & 90 deletions
@@ -10,6 +10,96 @@
 from ...modeling_rope_utils import RopeParameters, rope_config_validation, standardize_rope_params


+class LightOnOCRVisionConfig(PreTrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`LightOnOCRVisionModel`]. It is used to instantiate a
+    LightOnOCR vision encoder according to the specified arguments, defining the model architecture. Instantiating a
+    configuration with the defaults will yield a configuration similar to that of the vision encoder used by LightOnOCR-12B.
+
+    e.g. [lightonocr-hf/lightonocr-9b](https://huggingface.co/lightonocr-hf/lightonocr-9b)
+
+    Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PreTrainedConfig`] for more information.
+
+    Args:
+        hidden_size (`int`, *optional*, defaults to 1024):
+            Dimension of the hidden representations.
+        intermediate_size (`int`, *optional*, defaults to 4096):
+            Dimension of the MLP representations.
+        num_hidden_layers (`int`, *optional*, defaults to 24):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 16):
+            Number of attention heads in the Transformer encoder.
+        num_channels (`int`, *optional*, defaults to 3):
+            Number of input channels in the input images.
+        image_size (`int`, *optional*, defaults to 1024):
+            Maximum dimension of the input images.
+        patch_size (`int`, *optional*, defaults to 16):
+            Size of the image patches.
+        hidden_act (`str`, *optional*, defaults to `"gelu"`):
+            Activation function used in the hidden layers.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            Dropout probability for the attention layers.
+        rope_parameters (`RopeParameters`, *optional*):
+            The parameters of the rotary position embeddings.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+
+    Example:
+
+    ```python
+    >>> from transformers import LightOnOCRVisionModel, LightOnOCRVisionConfig
+
+    >>> # Initializing a LightOnOCR-12B style configuration
+    >>> configuration = LightOnOCRVisionConfig()
+
+    >>> # Initializing a model (with randomly initialized weights) from the configuration
+    >>> model = LightOnOCRVisionModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "lightonocr_vision"
+
+    def __init__(
+        self,
+        hidden_size: Optional[int] = 1024,
+        intermediate_size: Optional[int] = 4096,
+        num_hidden_layers: Optional[int] = 24,
+        num_attention_heads: Optional[int] = 16,
+        num_channels: Optional[int] = 3,
+        image_size: Optional[int] = 1024,
+        patch_size: Optional[int] = 16,
+        hidden_act: Optional[str] = "gelu",
+        attention_dropout: Optional[float] = 0.0,
+        rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
+        initializer_range: Optional[float] = 0.02,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.num_channels = num_channels
+        self.patch_size = patch_size
+        self.image_size = image_size
+        self.attention_dropout = attention_dropout
+        self.hidden_act = hidden_act
+        self.head_dim = hidden_size // num_attention_heads
+        self.initializer_range = initializer_range
+        # Use the legacy `rope_scaling` kwarg if provided, otherwise fall back to `rope_parameters`
+        rope_scaling = kwargs.pop("rope_scaling", None)
+        self.rope_parameters = rope_scaling or rope_parameters
+
+        # Validate the correctness of the rotary position embedding parameters
+        rope_theta = kwargs.get("rope_theta", 10000.0)
+        standardize_rope_params(self, rope_theta=rope_theta)
+        rope_config_validation(self)
+
+
 class LightOnOCRTextConfig(PreTrainedConfig):
     r"""
     This is the configuration class to store the configuration of a [`LightOnOCRTextModel`]. It is used to instantiate a
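To make the derived fields in the constructor above concrete, here is a short usage sketch; it assumes the LightOnOCR classes from this diff are importable from your transformers build.

```python
from transformers import LightOnOCRVisionConfig

# head_dim is derived in __init__ rather than passed in:
# hidden_size // num_attention_heads (1024 // 16 = 64 with the defaults).
config = LightOnOCRVisionConfig()
assert config.head_dim == 64

custom = LightOnOCRVisionConfig(hidden_size=768, num_attention_heads=16)
assert custom.head_dim == 48

# Note: a legacy `rope_scaling` kwarg, when given, takes precedence over
# `rope_parameters`, and `rope_theta` defaults to 10000.0 during validation.
```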
@@ -176,96 +266,6 @@ def __init__
         )


-class LightOnOCRVisionConfig(PreTrainedConfig):
-    r"""
-    This is the configuration class to store the configuration of a [`LightOnOCRVisionModel`]. It is used to instantiate an
-    LightOnOCR vision encoder according to the specified arguments, defining the model architecture. Instantiating a configuration
-    with the defaults will yield a similar configuration to the vision encoder used by LightOnOCR-12B.
-
-    e.g. [lightonocr-hf/lightonocr-9b](https://huggingface.co/lightonocr-hf/lightonocr-9b)
-
-    Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
-    documentation from [`PreTrainedConfig`] for more information.
-
-    Args:
-        hidden_size (`int`, *optional*, defaults to 1024):
-            Dimension of the hidden representations.
-        intermediate_size (`int`, *optional*, defaults to 4096):
-            Dimension of the MLP representations.
-        num_hidden_layers (`int`, *optional*, defaults to 24):
-            Number of hidden layers in the Transformer encoder.
-        num_attention_heads (`int`, *optional*, defaults to 16):
-            Number of attention heads in the Transformer encoder.
-        num_channels (`int`, *optional*, defaults to 3):
-            Number of input channels in the input images.
-        image_size (`int`, *optional*, defaults to 1024):
-            Max dimension of the input images.
-        patch_size (`int`, *optional*, defaults to 16):
-            Size of the image patches.
-        hidden_act (`str`, *optional*, defaults to `"gelu"`):
-            Activation function used in the hidden layers.
-        attention_dropout (`float`, *optional*, defaults to 0.0):
-            Dropout probability for the attention layers.
-        rope_parameters (`RopeParameters`, *optional*):
-            The RopeParameters
-        initializer_range (`float`, *optional*, defaults to 0.02):
-            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
-
-    Example:
-
-    ```python
-    >>> from transformers import LightOnOCRVisionModel, LightOnOCRVisionConfig
-
-    >>> # Initializing a LightOnOCR-12B style configuration
-    >>> config = LightOnOCRVisionConfig()
-
-    >>> # Initializing a model (with randomly initialized weights) from the configuration
-    >>> model = LightOnOCRVisionModel(configuration)
-
-    >>> # Accessing the model configuration
-    >>> configuration = model.config
-    ```"""
-
-    model_type = "lightonocr_vision"
-
-    def __init__(
-        self,
-        hidden_size: Optional[int] = 1024,
-        intermediate_size: Optional[int] = 4096,
-        num_hidden_layers: Optional[int] = 24,
-        num_attention_heads: Optional[int] = 16,
-        num_channels: Optional[int] = 3,
-        image_size: Optional[int] = 1024,
-        patch_size: Optional[int] = 16,
-        hidden_act: Optional[str] = "gelu",
-        attention_dropout: Optional[float] = 0.0,
-        rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
-        initializer_range: Optional[float] = 0.02,
-        **kwargs,
-    ):
-        super().__init__(**kwargs)
-
-        self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size
-        self.num_hidden_layers = num_hidden_layers
-        self.num_attention_heads = num_attention_heads
-        self.num_channels = num_channels
-        self.patch_size = patch_size
-        self.image_size = image_size
-        self.attention_dropout = attention_dropout
-        self.hidden_act = hidden_act
-        self.head_dim = hidden_size // num_attention_heads
-        self.initializer_range = initializer_range
-        # Try to set `rope_scaling` if available, otherwise use `rope_parameters`
-        rope_scaling = kwargs.pop("rope_scaling", None)
-        self.rope_parameters = rope_scaling or rope_parameters
-
-        # Validate the correctness of rotary position embeddings parameters
-        rope_theta = kwargs.get("rope_theta", 10000.0)
-        standardize_rope_params(self, rope_theta=rope_theta)
-        rope_config_validation(self)
-
-
 class LightOnOCRConfig(PretrainedConfig):
     r"""
     This is the configuration class to store the configuration of a [`LightOnOCRForConditionalGeneration`]. It is used to instantiate a
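The diff context for `LightOnOCRConfig` is cut off at this point. Purely for orientation: composite configs in transformers conventionally nest their sub-configurations, so a plausible usage, unconfirmed by this diff, would look like the sketch below. The `vision_config`/`text_config` keyword names are assumptions.

```python
from transformers import LightOnOCRConfig, LightOnOCRTextConfig, LightOnOCRVisionConfig

# Hypothetical composition following the usual multimodal-config pattern;
# the keyword names are assumed, not taken from this diff.
config = LightOnOCRConfig(
    vision_config=LightOnOCRVisionConfig(image_size=1024, patch_size=16),
    text_config=LightOnOCRTextConfig(),
)
```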
