Skip to content

Commit 0aa59e1

Browse files
junjiang-lab authored and
copybara-github committed
Add a flag "cast_i64_inputs_to_i32" to cast all i64 tensor to i32.
PiperOrigin-RevId: 791393616
1 parent 46ff6d0 commit 0aa59e1

File tree

16 files changed

+476
-12
lines changed

16 files changed

+476
-12
lines changed

ai_edge_torch/_convert/conversion.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232

3333
def _run_convert_passes(
3434
exported_program: torch.export.ExportedProgram,
35+
cast_i64_inputs_to_i32: bool,
3536
) -> torch.export.ExportedProgram:
3637
exported_program = generative_fx_passes.run_generative_passes(
3738
exported_program
@@ -46,6 +47,10 @@ def _run_convert_passes(
4647
fx_passes.CastInputsBf16ToF32Pass(),
4748
]
4849

50+
if cast_i64_inputs_to_i32:
51+
print("---------------> Casting i64 inputs to i32")
52+
passes += [fx_passes.CastInputsI64ToI32Pass()]
53+
4954
# Debuginfo is not injected automatically by odml_torch. Only inject
5055
# debuginfo via fx pass when using torch_xla.
5156
if ai_edge_torch.config.use_torch_xla:
@@ -82,6 +87,7 @@ def convert_signatures(
8287
signatures: list[signature.Signature],
8388
*,
8489
strict_export: Union[Literal["auto"], bool] = True,
90+
cast_i64_inputs_to_i32: bool = False,
8591
quant_config: Optional[qcfg.QuantConfig] = None,
8692
_tfl_converter_flags: Optional[dict[str, Any]] = None,
8793
_saved_model_dir: Optional[str] = None,
@@ -96,6 +102,8 @@ def convert_signatures(
96102
and ensure the soundness of the exported graph. When
97103
strict_export="auto", the function will try to export module in both
98104
modes and use the first one succeeds for downstream conversion.
105+
cast_i64_inputs_to_i32: If true, casts all inputs with torch.int64 type to
106+
torch.int32.
99107
quant_config: User-defined quantization method and scheme of the model.
100108
_tfl_converter_flags: A nested dictionary allowing setting flags for the
101109
underlying tflite converter.
@@ -147,7 +155,10 @@ def export(**kwargs):
147155
]
148156

149157
# Apply default fx passes
150-
exported_programs = list(map(_run_convert_passes, exported_programs))
158+
exported_programs = [
159+
_run_convert_passes(ep, cast_i64_inputs_to_i32)
160+
for ep in exported_programs
161+
]
151162
tflite_model = lowertools.exported_programs_to_tflite(
152163
exported_programs,
153164
signatures,

ai_edge_torch/_convert/converter.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ def convert(
132132
sample_kwargs=None,
133133
*,
134134
strict_export: Union[Literal["auto"], bool] = True,
135+
cast_i64_inputs_to_i32: bool = False,
135136
quant_config: Optional[qcfg.QuantConfig] = None,
136137
dynamic_shapes: Optional[Union[dict[str, Any], Tuple[Any, ...]]] = None,
137138
_ai_edge_converter_flags: Optional[dict[str, Any]] = None,
@@ -159,6 +160,8 @@ def convert(
159160
and ensure the soundness of the exported graph. When
160161
strict_export="auto", the function will try to export module in both
161162
modes and use the first one succeeds for downstream conversion.
163+
cast_i64_inputs_to_i32: If true, casts all inputs with torch.int64 type to
164+
torch.int32.
162165
quant_config: User-defined quantization method and scheme of the model.
163166
dynamic_shapes: Optional dict or tuple that specify dynamic shape
164167
specifications for each input in original order. See
@@ -203,6 +206,7 @@ def convert(
203206
converted_model = conversion.convert_signatures(
204207
self._signatures,
205208
strict_export=strict_export,
209+
cast_i64_inputs_to_i32=cast_i64_inputs_to_i32,
206210
quant_config=quant_config,
207211
_tfl_converter_flags=_ai_edge_converter_flags,
208212
_saved_model_dir=_saved_model_dir,
@@ -271,6 +275,7 @@ def convert(
271275
sample_kwargs=None,
272276
*,
273277
strict_export: Union[Literal["auto"], bool] = True,
278+
cast_i64_inputs_to_i32: bool = False,
274279
quant_config: Optional[qcfg.QuantConfig] = None,
275280
dynamic_shapes: Optional[Union[dict[str, Any], Tuple[Any, ...]]] = None,
276281
_ai_edge_converter_flags: Optional[dict[str, Any]] = None,
@@ -289,6 +294,8 @@ def convert(
289294
and ensure the soundness of the exported graph. When strict_export="auto",
290295
the function will try to export module in both modes and use the first one
291296
succeeds for downstream conversion.
297+
cast_i64_inputs_to_i32: If true, casts all inputs with torch.int64 type to
298+
torch.int32.
292299
quant_config: User-defined quantization method and scheme of the model.
293300
dynamic_shapes: Optional dict or tuple that specify dynamic shape
294301
specifications for each input in original order. See
@@ -317,6 +324,7 @@ def convert(
317324
sample_args,
318325
sample_kwargs,
319326
strict_export=strict_export,
327+
cast_i64_inputs_to_i32=cast_i64_inputs_to_i32,
320328
quant_config=quant_config,
321329
dynamic_shapes=dynamic_shapes,
322330
_ai_edge_converter_flags=_ai_edge_converter_flags,

ai_edge_torch/_convert/fx_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
from ai_edge_torch._convert.fx_passes.build_aten_composite_pass import BuildAtenCompositePass
1919
from ai_edge_torch._convert.fx_passes.cast_inputs_bf16_to_f32_pass import CastInputsBf16ToF32Pass
20+
from ai_edge_torch._convert.fx_passes.cast_inputs_i64_to_i32_pass import CastInputsI64ToI32Pass
2021
from ai_edge_torch._convert.fx_passes.eliminate_dead_code_pass import EliminateDeadCodePass
2122
from ai_edge_torch._convert.fx_passes.inject_mlir_debuginfo_pass import InjectMlirDebuginfoPass
2223
from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import OptimizeLayoutTransposesPass

ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from typing import Any, Callable
1717
from ai_edge_torch import fx_infra
1818
from ai_edge_torch import lowertools
19+
from ai_edge_torch.odml_torch import optimization_barrier as optimization_barrier_lib
1920
import torch
2021
import torch.utils._pytree as pytree
2122

@@ -276,6 +277,7 @@ def embedding(*args, **kwargs):
276277
# Explicitly reshape back to the original shape. This places the ReshapeOp
277278
# outside of the HLFB.
278279
output = torch.reshape(output, (*(original_idx_shape), embedding_dim))
280+
output, _ = optimization_barrier_lib.optimization_barrier(output, idx)
279281
return output
280282

281283
node.target = embedding
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Copyright 2025 The AI Edge Torch Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Pass to cast all inputs with torch.int64 type to torch.int32."""
16+
17+
18+
from ai_edge_torch import fx_infra
19+
import torch
20+
21+
22+
def cast_i32(x):
  """Casts a tensor to torch.int32.

  This function is used as the `call_function` target that
  `CastInputsI64ToI32Pass` inserts after int64-producing nodes.

  Args:
    x: The tensor to cast.

  Returns:
    The result of `x.to(torch.int32)`.
  """
  # The helper's name and the module docstring both promise an int32 cast;
  # casting to float32 here would hand floating-point values to consumers that
  # expect integers.
  return x.to(torch.int32)
25+
26+
27+
class CastInputsI64ToI32Pass(fx_infra.ExportedProgramPassBase):
  """Inserts a `cast_i32` call after graph nodes whose value is torch.int64.

  Both "placeholder" (input) nodes and "call_function" nodes are considered. A
  node is rewritten when its `meta["val"]` reports dtype torch.int64 and the
  node has at least one user; all of its users are then redirected to the
  inserted cast node.
  """

  def call(self, exported_program: torch.export.ExportedProgram):
    """Runs the pass on `exported_program`, mutating its graph in place.

    Args:
      exported_program: The exported program whose graph is rewritten.

    Returns:
      An `ExportedProgramPassResult` wrapping the same exported program and a
      flag telling whether any cast node was inserted.
    """
    graph = exported_program.graph
    modified = False
    for node in graph.nodes:
      if node.op not in ("placeholder", "call_function"):
        continue
      # `val` may be absent or be a non-tensor object without a `dtype`
      # attribute; getattr keeps such nodes from raising AttributeError.
      if getattr(node.meta.get("val"), "dtype", None) != torch.int64:
        continue
      # Nothing consumes this value, so there is nothing to redirect.
      if not node.users:
        continue

      modified = True
      first_user = next(iter(node.users))
      with graph.inserting_before(first_user):
        cast_node = graph.call_function(cast_i32, (node,))
      node.replace_all_uses_with(cast_node)
      # replace_all_uses_with also rewired cast_node's own argument to point at
      # itself (cast_node is a user of node); restore the original input.
      cast_node.replace_input_with(cast_node, node)

    exported_program.graph_module.recompile()
    return fx_infra.ExportedProgramPassResult(exported_program, modified)

ai_edge_torch/_convert/test/test_convert.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919
from typing import Tuple
2020

2121
import ai_edge_torch
22+
from ai_edge_torch import fx_infra
2223
from ai_edge_torch._convert import conversion_utils
24+
from ai_edge_torch.odml_torch.experimental import torch_tfl
2325
from ai_edge_torch.quantize import pt2e_quantizer
2426
from ai_edge_torch.testing import model_coverage
2527
import numpy as np
@@ -576,6 +578,39 @@ def forward(self, x: torch.Tensor):
576578
self.fail(f"Conversion failed with bloat16 inputs: {err}")
577579
# pylint: enable=broad-except
578580

581+
def test_convert_model_with_i64_inputs_legalization_error(self):
  """Tests that cast_i64_inputs_to_i32 fixes an int64 legalization failure.

  Without the flag, converting the sample model is expected to fail with a
  'tfl.less' legalization error. With the flag set, the same conversion is
  expected to succeed.
  """

  class SampleModel(nn.Module):

    def forward(self, x: torch.Tensor):
      return torch.linspace(0.5, 10.5, steps=x.shape[0], dtype=torch.float64)

  model = SampleModel().eval()
  args = (torch.randint(0, 100, (10, 10), dtype=torch.int64),)

  # assertRaises is used instead of try / self.fail / except: self.fail raises
  # an AssertionError, which a broad `except Exception` would catch and then
  # misreport as an unexpected error message.
  # pylint: disable=broad-except
  with self.assertRaises(Exception) as raised:
    ai_edge_torch.convert(model, args, cast_i64_inputs_to_i32=False)
  self.assertIn(
      "failed to legalize operation 'tfl.less'", str(raised.exception)
  )

  try:
    # Expect the flag to fix the error during conversion.
    ai_edge_torch.convert(model, args, cast_i64_inputs_to_i32=True)
  except Exception as err:
    self.fail(f"Conversion failed with int64 inputs: {err}")
  # pylint: enable=broad-except
613+
579614
def test_compile_model(self):
580615
"""Tests AOT compilation of a simple Add module."""
581616

ai_edge_torch/generative/examples/gemma/gemma1.py

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import torch
2424
from torch import nn
2525

26-
TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
26+
TENSOR_NAMES_FUSED_QKV = loading_utils.ModelLoader.TensorNames(
2727
ff_up_proj="model.layers.{}.mlp.up_proj",
2828
ff_down_proj="model.layers.{}.mlp.down_proj",
2929
ff_gate_proj="model.layers.{}.mlp.gate_proj",
@@ -36,6 +36,24 @@
3636
lm_head=None,
3737
)
3838

39+
TENSOR_NAMES_SEP_QKV = loading_utils.ModelLoader.TensorNames(
40+
ff_up_proj="model.layers.{}.mlp.up_proj",
41+
ff_down_proj="model.layers.{}.mlp.down_proj",
42+
ff_gate_proj="model.layers.{}.mlp.gate_proj",
43+
attn_query_proj="model.layers.{}.self_attn.q_proj",
44+
attn_key_proj="model.layers.{}.self_attn.k_proj",
45+
attn_value_proj="model.layers.{}.self_attn.v_proj",
46+
attn_output_proj="model.layers.{}.self_attn.o_proj",
47+
pre_attn_norm="model.layers.{}.input_layernorm",
48+
post_attn_norm="model.layers.{}.post_attention_layernorm",
49+
embedding="model.embed_tokens",
50+
final_norm="model.norm",
51+
)
52+
53+
TENSOR_NAMES_DICT = {
54+
"safetensors": TENSOR_NAMES_SEP_QKV,
55+
"kaggle": TENSOR_NAMES_FUSED_QKV,
56+
}
3957

4058
class Gemma1(model_builder.DecoderOnlyModel):
4159
"""A Gemma1 model built from the Edge Generative API layers."""
@@ -94,11 +112,28 @@ def build_2b_model(
94112
custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
95113
mask_cache_size: int = 0,
96114
) -> nn.Module:
97-
return model_builder.build_decoder_only_model(
98-
checkpoint_path=checkpoint_path,
99-
config=get_model_config_2b(),
100-
tensor_names=TENSOR_NAMES,
101-
model_class=Gemma1,
102-
custom_loader=custom_loader,
103-
mask_cache_size=mask_cache_size,
115+
116+
# A list to store the reasons for each failure
117+
key_errors = []
118+
119+
for tensor_names in TENSOR_NAMES_DICT.values():
120+
try:
121+
return model_builder.build_decoder_only_model(
122+
checkpoint_path=checkpoint_path,
123+
config=get_model_config_2b(),
124+
tensor_names=tensor_names,
125+
model_class=Gemma1,
126+
custom_loader=custom_loader,
127+
mask_cache_size=mask_cache_size,
128+
)
129+
except KeyError as ke:
130+
# Store the specific key that was missing for later
131+
key_errors.append(f"Missing key: {ke}")
132+
continue
133+
134+
# If the loop finishes, raise an error with all the collected details
135+
error_details = "\n".join(key_errors)
136+
raise RuntimeError(
137+
"Failed to build model after trying all configurations. "
138+
f"Encountered the following errors:\n{error_details}"
104139
)
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
# Copyright 2024 The AI Edge Torch Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Verifies the reauthored SmolVLM2 Image Encoder model."""
16+
17+
import logging
18+
19+
from absl import app
20+
from absl import flags
21+
from ai_edge_torch.generative.examples.smolvlm2 import smolvlm2
22+
from ai_edge_torch.generative.examples.smolvlm2 import vision_encoder
23+
from PIL import Image
24+
import requests
25+
import torch
26+
import transformers
27+
28+
_IMAGE_URL = flags.DEFINE_string(
    "image_url",
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true",
    "The image URI to encode.",
)

_CHECKPOINT = flags.DEFINE_string(
    "checkpoint",
    "HuggingFaceTB/SmolVLM2-2.2B-Instruct",
    "The checkpoint to verify.",
)

_REAUTHORED_CHECKPOINT = flags.DEFINE_string(
    "pretrained_weights_path",
    None,
    "The path to the model's pretrained weights.",
)

# Seconds to wait on the image download before giving up instead of hanging.
_DOWNLOAD_TIMEOUT_SECS = 60


def main(_):
  """Compares image features from the original and reauthored encoders.

  Loads the original model named by --checkpoint, builds the reauthored image
  encoder from --pretrained_weights_path, runs both on the image fetched from
  --image_url, and checks that the outputs are close.

  Raises:
    ValueError: If --pretrained_weights_path is not set.
    AssertionError: If the two outputs differ beyond the tolerances.
  """
  # Validate the required flag before the expensive model load below.
  reauthored_checkpoint = _REAUTHORED_CHECKPOINT.value
  if reauthored_checkpoint is None:
    raise ValueError("--pretrained_weights_path is required.")

  checkpoint = _CHECKPOINT.value
  logging.info("Loading the original model from: %s", checkpoint)
  original_model = transformers.AutoModelForImageTextToText.from_pretrained(
      checkpoint
  )
  original_model = original_model.eval().model

  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
  reauthored_model = vision_encoder.build_image_encoder(reauthored_checkpoint)

  logging.info("Loading the processor from: %s", checkpoint)
  processor = transformers.AutoProcessor.from_pretrained(checkpoint)

  logging.info("Loading the image from: %s", _IMAGE_URL.value)
  response = requests.get(
      _IMAGE_URL.value, stream=True, timeout=_DOWNLOAD_TIMEOUT_SECS
  )
  # Surface HTTP failures here rather than as an opaque image decoding error.
  response.raise_for_status()
  image = Image.open(response.raw)
  pixel_values = processor(images=image, return_tensors="pt")["pixel_values"]

  logging.info("Forwarding the original model...")
  outputs_original = original_model.get_image_features(pixel_values)
  logging.info("outputs_original's shape: %s", outputs_original.shape)

  # Merge the two leading dimensions into one before calling the reauthored
  # encoder.
  pixel_values = pixel_values.reshape(
      pixel_values.shape[0] * pixel_values.shape[1], *pixel_values.shape[2:]
  )
  logging.info("Forwarding the reauthored model...")
  outputs_reauthored = reauthored_model.forward(pixel_values=pixel_values)
  logging.info("outputs_reauthored's shape: %s", outputs_reauthored.shape)

  # An explicit check is used instead of `assert`, which is stripped when
  # Python runs with -O and would let a mismatch pass silently.
  if not torch.allclose(
      outputs_original, outputs_reauthored, atol=1e-03, rtol=1e-04
  ):
    logging.error("*** FAILED *** verify with an image")
    raise AssertionError(
        "Outputs of the original and reauthored models do not match."
    )
  logging.info("*** PASSED *** verify with an image")


if __name__ == "__main__":
  app.run(main)

0 commit comments

Comments
 (0)