Skip to content

Commit 6f07906

Browse files
ai-edge-botcopybara-github
authored andcommitted
Add numeric verification for the SmolVLM2 Image Encoder Model.
PiperOrigin-RevId: 790259698
1 parent 46ff6d0 commit 6f07906

File tree

9 files changed

+323
-1
lines changed

9 files changed

+323
-1
lines changed

ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from typing import Any, Callable
1717
from ai_edge_torch import fx_infra
1818
from ai_edge_torch import lowertools
19+
from ai_edge_torch.odml_torch import optimization_barrier as optimization_barrier_lib
1920
import torch
2021
import torch.utils._pytree as pytree
2122

@@ -276,6 +277,7 @@ def embedding(*args, **kwargs):
276277
# Explicitly reshape back to the original shape. This places the ReshapeOp
277278
# outside of the HLFB.
278279
output = torch.reshape(output, (*(original_idx_shape), embedding_dim))
280+
output, _ = optimization_barrier_lib.optimization_barrier(output, idx)
279281
return output
280282

281283
node.target = embedding
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
# Copyright 2024 The AI Edge Torch Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Verifies the reauthored SmolVLM2 Image Encoder model."""
16+
17+
import logging
18+
19+
from absl import app
20+
from absl import flags
21+
from ai_edge_torch.generative.examples.smolvlm2 import smolvlm2
22+
from ai_edge_torch.generative.examples.smolvlm2 import vision_encoder
23+
from PIL import Image
24+
import requests
25+
import torch
26+
import transformers
27+
28+
_IMAGE_URL = flags.DEFINE_string(
29+
"image_url",
30+
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true",
31+
"The image URI to encode.",
32+
)
33+
34+
_CHECKPOINT = flags.DEFINE_string(
35+
"checkpoint",
36+
"HuggingFaceTB/SmolVLM2-2.2B-Instruct",
37+
"The checkpoint to verify.",
38+
)
39+
40+
_REAUTHORTHED_CHECKPOINT = flags.DEFINE_string(
41+
"pretrained_weights_path",
42+
None,
43+
"The path to the model's pretrained weights.",
44+
)
45+
46+
47+
def main(_):
48+
checkpoint = _CHECKPOINT.value
49+
logging.info("Loading the original model from: %s", checkpoint)
50+
original_model = transformers.AutoModelForImageTextToText.from_pretrained(
51+
checkpoint
52+
)
53+
original_model = original_model.eval().model
54+
55+
logging.info("Building the reauthored checkpoint from: %s", checkpoint)
56+
reauthored_checkpoint = _REAUTHORTHED_CHECKPOINT.value
57+
if reauthored_checkpoint is None:
58+
raise ValueError("reauthored_checkpoint is required.")
59+
60+
logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
61+
reauthored_model = vision_encoder.build_image_encoder(reauthored_checkpoint)
62+
63+
logging.info("Loading the tokenizer from: %s", checkpoint)
64+
processor = transformers.AutoProcessor.from_pretrained(checkpoint)
65+
66+
logging.info("Loading the image from: %s", _IMAGE_URL.value)
67+
image = Image.open(requests.get(_IMAGE_URL.value, stream=True).raw)
68+
pixel_values = processor(images=image, return_tensors="pt")["pixel_values"]
69+
70+
logging.info("Forwarding the original model...")
71+
outputs_original = original_model.get_image_features(pixel_values)
72+
logging.info("outputs_original's shape: %s", outputs_original.shape)
73+
74+
pixel_values = pixel_values.reshape(
75+
pixel_values.shape[0] * pixel_values.shape[1], *pixel_values.shape[2:]
76+
)
77+
logging.info("Forwarding the reauthored model...")
78+
outputs_reauthored = reauthored_model.forward(
79+
pixel_values=pixel_values
80+
)
81+
logging.info("outputs_reauthored's shape: %s", outputs_reauthored.shape)
82+
83+
try:
84+
assert torch.allclose(
85+
outputs_original, outputs_reauthored, atol=1e-03, rtol=1e-04
86+
)
87+
except AssertionError as e:
88+
logging.error("*** FAILED *** verify with an image")
89+
raise e
90+
else:
91+
logging.info("*** PASSED *** verify with an image")
92+
93+
94+
if __name__ == "__main__":
95+
app.run(main)

ai_edge_torch/generative/examples/smolvlm2/vision_encoder.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,18 @@ def forward(
129129
pixel_values: torch.Tensor,
130130
export_config: export_cfg.ExportConfig = None,
131131
) -> torch.Tensor:
132-
x = self.siglip_encoder(pixel_values)
132+
# Embed the image according to SiplipVisionEmbeddings.
133+
x = self.siglip_encoder.tok_embedding(pixel_values)
134+
x = x.flatten(2).transpose(1, 2)
135+
x = x + self.siglip_encoder.tok_embedding_position
136+
137+
# Pass a dummy mask because SDPA attention impl expects non-None mask.
138+
mask = torch.zeros(x.shape[0], 1, x.shape[1], x.shape[1])
139+
for _, block in enumerate(self.siglip_encoder.transformer_blocks):
140+
x = block(x, mask=mask)
141+
x = self.siglip_encoder.final_norm(x)
142+
143+
# Project the image embeddings to text hidden size.
133144
x = self.connector(x)
134145
return x
135146

ai_edge_torch/odml_torch/experimental/torch_tfl/_decomps.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,11 @@ def _aten_rsqrt_decomp(x):
180180
return torch.ops.tfl.rsqrt(x)
181181

182182

183+
@register_decomp(torch.ops.aten.neg.default)
184+
def _aten_neg_decomp(x):
185+
return torch.ops.tfl.neg(x)
186+
187+
183188
@register_decomp(torch.ops.aten.gelu.default)
184189
def _aten_gelu_decomp(x, approximate="none"):
185190
return torch.ops.tfl.gelu(x, approximate != "none")
@@ -317,6 +322,38 @@ def _aten_select_int_decomp(x, dim, index):
317322
return torch.ops.tfl.squeeze(sliced, [dim])
318323

319324

325+
@register_decomp(torch.ops.aten.slice.Tensor)
326+
def _aten_slice_tensor_decomp(x, dim=0, start=None, end=None, step=1):
327+
rank = x.dim()
328+
dim_size = x.shape[dim]
329+
330+
# Initialize begin, end, strides for tfl.strided_slice
331+
begin = [0] * rank
332+
end_vec = list(x.shape)
333+
strides = [1] * rank
334+
335+
# The logic below is to match PyTorch's `slice` behavior.
336+
# `start` and `end` can be negative, which means they count from the end.
337+
# `start=None` defaults to 0.
338+
# `end=None` or a large number defaults to `dim_size` after clamping.
339+
340+
start_val = 0 if start is None else start
341+
if start_val < 0:
342+
start_val += dim_size
343+
344+
end_val = dim_size if end is None else end
345+
if end_val < 0:
346+
end_val += dim_size
347+
348+
# Clamp start and end to be within the dimension size, following PyTorch's
349+
# logic.
350+
start_val = max(0, min(start_val, dim_size))
351+
end_val = max(start_val, min(end_val, dim_size))
352+
353+
begin[dim], end_vec[dim], strides[dim] = start_val, end_val, step
354+
return torch.ops.tfl.strided_slice(x, begin, end_vec, strides)
355+
356+
320357
@register_decomp(torch.ops.aten.where.self)
321358
def _aten_where_self_decomp(condition, x, y):
322359
x, y = _promote_types_for_binary_op(x, y)

ai_edge_torch/odml_torch/experimental/torch_tfl/_lowerings.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,18 @@ def _tfl_rsqrt_lowering(
286286
)
287287

288288

289+
@lower(torch.ops.tfl.neg.default)
290+
def _tfl_neg_lowering(
291+
lctx: LoweringContext,
292+
x: ir.Value,
293+
) -> ir.Value:
294+
return _ir_operation(
295+
"tfl.neg",
296+
results=lowering_utils.node_meta_to_ir_types(lctx.node),
297+
operands=[x],
298+
)
299+
300+
289301
@lower(torch.ops.tfl.gelu.default)
290302
def _tfl_gelu_lowering(
291303
lctx: LoweringContext,

ai_edge_torch/odml_torch/experimental/torch_tfl/_ops.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,11 @@ def tfl_rsqrt(x: torch.Tensor) -> torch.Tensor:
110110
return torch.rsqrt(x)
111111

112112

113+
@custom_op_with_fake("tfl::neg")
114+
def tfl_neg(x: torch.Tensor) -> torch.Tensor:
115+
return torch.neg(x)
116+
117+
113118
@custom_op_with_fake("tfl::gelu")
114119
def tfl_gelu(x: torch.Tensor, approximate: bool = False) -> torch.Tensor:
115120
gelu_approximate = "tanh" if approximate else "none"

ai_edge_torch/odml_torch/experimental/torch_tfl/test/test_torch_tfl_impls.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ def _assert_export_and_close(
152152
("aten_cos_1", torch.ops.aten.cos.default, (rnd(torch.float32, (1, 10)),), dict()),
153153
("aten_rsqrt_0", torch.ops.aten.rsqrt.default, (rnd(torch.float32, (10, 10)),), dict()),
154154
("aten_rsqrt_1", torch.ops.aten.rsqrt.default, (rnd(torch.float32, (1, 10)),), dict()),
155+
("aten_neg_0", torch.ops.aten.neg.default, (rnd(torch.float32, (10, 10)),), dict()),
155156
("aten_gelu_0", torch.ops.aten.gelu.default, (rnd(torch.float32, (10, 10)),), dict()),
156157
("aten_gelu_1", torch.ops.aten.gelu.default, (rnd(torch.float32, (10, 10)),), dict(approximate="tanh")),
157158
("aten_gelu_2", torch.ops.aten.gelu.default, (rnd(torch.float32, (1, 10)),), dict()),
@@ -186,6 +187,14 @@ def _assert_export_and_close(
186187
("aten_squeeze_dims_0", torch.ops.aten.squeeze.dims, (rnd(torch.float32, (2, 1, 2, 1, 2)), [1, 2, 3],), dict()),
187188
("aten_select_int_0", torch.ops.aten.select.int, (rnd(torch.float32, (2, 3, 4)), 0, 1,), dict()),
188189
("aten_select_int_1", torch.ops.aten.select.int, (rnd(torch.float32, (2, 3, 4)), 1, 1,), dict()),
190+
("aten_slice_tensor_0", torch.ops.aten.slice.Tensor, (rnd(torch.float32, (10, 10)),), dict(dim=0, start=1, end=3)),
191+
("aten_slice_tensor_1", torch.ops.aten.slice.Tensor, (rnd(torch.float32, (10, 10)),), dict(dim=1, start=2, end=5)),
192+
("aten_slice_tensor_2", torch.ops.aten.slice.Tensor, (rnd(torch.float32, (10, 10)),), dict(dim=0, start=None, end=5)),
193+
("aten_slice_tensor_3", torch.ops.aten.slice.Tensor, (rnd(torch.float32, (10, 10)),), dict(dim=0, start=2, end=None)),
194+
("aten_slice_tensor_4", torch.ops.aten.slice.Tensor, (rnd(torch.float32, (10, 10)),), dict(dim=0, start=-5, end=-2)),
195+
("aten_slice_tensor_5", torch.ops.aten.slice.Tensor, (rnd(torch.float32, (10, 10)),), dict(dim=0, start=1, end=8, step=2)),
196+
("aten_slice_tensor_6", torch.ops.aten.slice.Tensor, (rnd(torch.float32, (10, 10)),), dict(dim=1, start=2, end=100)),
197+
("aten_slice_tensor_7", torch.ops.aten.slice.Tensor, (rnd(torch.float32, (10, 10)),), dict(dim=0, start=None, end=None)),
189198
("aten_where_self_0", torch.ops.aten.where.self, (rnd(torch.bool, (10, 10)), rnd(torch.float32, (10, 10)), rnd(torch.float32, (10, 10)),), dict()),
190199
("aten_embedding_0", torch.ops.aten.embedding.default, (rnd(torch.float32, (10, 10)), torch.tensor([[0, 2, 4, 6, 8], [1, 3, 5, 7, 9]]),), dict()),
191200
("aten__softmax_0", torch.ops.aten._softmax.default, (rnd(torch.float32, (10, 10)), -1, False), dict()),
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# Copyright 2025 The AI Edge Torch Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Optimization barrier op definition and lowering."""
16+
17+
from ai_edge_torch.odml_torch import _torch_library
18+
from ai_edge_torch.odml_torch.lowerings import registry
19+
from jax._src.lib.mlir import ir
20+
from jax._src.lib.mlir.dialects import hlo as stablehlo
21+
import torch
22+
import torch.utils._pytree as pytree
23+
24+
_torch_library.ODML_TORCH_LIB.define(
25+
"optimization_barrier(Tensor[] inputs) -> Tensor[]"
26+
)
27+
28+
optimization_barrier_op = torch.ops.odml_torch.optimization_barrier.default
29+
30+
31+
def optimization_barrier(*inputs: pytree.PyTree):
32+
"""Apply optimization barrier to the tensors nested within arbitrary pytrees.
33+
34+
Args:
35+
*inputs: A list of tensors or tensor pytrees.
36+
37+
Returns:
38+
The tensors after optimization barrier in the same pytrees structures.
39+
"""
40+
if len(inputs) == 1:
41+
inputs = inputs[0]
42+
tensors, spec = pytree.tree_flatten(inputs)
43+
tensors = optimization_barrier_op(tuple(tensors))
44+
outputs = pytree.tree_unflatten(tensors, spec)
45+
return outputs
46+
47+
48+
@torch.library.impl(
49+
_torch_library.ODML_TORCH_LIB,
50+
"optimization_barrier",
51+
"CompositeExplicitAutograd",
52+
)
53+
def _optimization_barrier_impl(inputs: tuple[torch.Tensor, ...]):
54+
return tuple(inputs)
55+
56+
57+
@torch.library.impl(
58+
_torch_library.ODML_TORCH_LIB,
59+
"optimization_barrier",
60+
"Meta",
61+
)
62+
def _optimization_barrier_fake(inputs: tuple[torch.Tensor, ...]):
63+
return tuple([torch.empty_like(x) for x in inputs])
64+
65+
66+
@registry.lower(torch.ops.odml_torch.optimization_barrier.default)
67+
def _optimization_barrier_lowering(
68+
lctx, inputs: tuple[ir.Value, ...]
69+
) -> ir.Value:
70+
del lctx
71+
return stablehlo.optimization_barrier(inputs)
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Copyright 2025 The AI Edge Torch Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
from ai_edge_torch import odml_torch
16+
from ai_edge_torch.odml_torch import optimization_barrier as optimization_barrier_lib # Import to register the op.
17+
import torch
18+
19+
from absl.testing import absltest as googletest
20+
21+
optimization_barrier = optimization_barrier_lib.optimization_barrier
22+
23+
24+
class TestOptimizationBarrier(googletest.TestCase):
25+
"""Test optimization barrier op implementation and lowering."""
26+
27+
def test_applied_optimization_barrier_op(self):
28+
"""Test optimization barrier op application and lowering."""
29+
30+
class TestModel(torch.nn.Module):
31+
32+
def forward(self, x, y):
33+
x, _ = optimization_barrier(x, y)
34+
return x
35+
36+
x = torch.randn(1, 5)
37+
ep = torch.export.export(TestModel().eval(), (x, x))
38+
mlir = odml_torch.export.exported_program_to_mlir(ep)
39+
mlir_text = mlir.get_text()
40+
self.assertEqual(
41+
mlir_text.count(
42+
"stablehlo.optimization_barrier %arg1, %arg1 : tensor<1x5xf32>,"
43+
" tensor<1x5xf32>"
44+
),
45+
1,
46+
)
47+
48+
def test_input_single_tensor(self):
49+
"""Test optimization barrier with single tensor input."""
50+
x = torch.randn(1, 5)
51+
y = optimization_barrier(x)
52+
self.assertIsInstance(y, torch.Tensor)
53+
self.assertEqual(y.shape, (1, 5))
54+
55+
def test_input_multiple_tensors(self):
56+
"""Test optimization barrier with multiple tensors input."""
57+
x = torch.randn(1, 5)
58+
y = torch.randn(1, 6)
59+
z = optimization_barrier(x, y)
60+
self.assertIsInstance(z, tuple)
61+
self.assertLen(z, 2)
62+
self.assertIsInstance(z[0], torch.Tensor)
63+
self.assertIsInstance(z[1], torch.Tensor)
64+
self.assertEqual(z[0].shape, (1, 5))
65+
self.assertEqual(z[1].shape, (1, 6))
66+
67+
def test_input_nested_tensors(self):
68+
"""Test optimization barrier with nested tensor inputs."""
69+
x = {"foo": torch.randn(1, 5), "bar": torch.randn(1, 6)}
70+
z = optimization_barrier(x)
71+
self.assertIsInstance(z, dict)
72+
self.assertLen(z, 2)
73+
self.assertIsInstance(z["foo"], torch.Tensor)
74+
self.assertIsInstance(z["bar"], torch.Tensor)
75+
self.assertEqual(z["foo"].shape, (1, 5))
76+
self.assertEqual(z["bar"].shape, (1, 6))
77+
78+
79+
if __name__ == "__main__":
80+
googletest.main()

0 commit comments

Comments
 (0)