Commit 25aca01

junjiang-lab authored and copybara-github committed

Add jax lowering to cast f64 to f32 for aten ops add, sub, lt, mul, where.

PiperOrigin-RevId: 790985881

1 parent 46ff6d0 · commit 25aca01

13 files changed: +594 -26 lines
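Only some of the 13 changed files are shown below, and the f64-to-f32 casting named in the commit title does not appear in the visible hunks. As a rough, hypothetical sketch of the idea only (not this commit's actual code), a JAX lowering might demote double-precision operands before emitting the binary op:

import jax.numpy as jnp

def _demote_f64(x):
  # Hypothetical helper: cast f64 operands down to f32 before lowering.
  return x.astype(jnp.float32) if x.dtype == jnp.float64 else x

def lower_aten_add(lhs, rhs):
  # Illustrative only; the real lowering is registered against the aten op.
  return jnp.add(_demote_f64(lhs), _demote_f64(rhs))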

ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py

Lines changed: 2 additions & 0 deletions
@@ -16,6 +16,7 @@
 from typing import Any, Callable
 from ai_edge_torch import fx_infra
 from ai_edge_torch import lowertools
+from ai_edge_torch.odml_torch import optimization_barrier as optimization_barrier_lib
 import torch
 import torch.utils._pytree as pytree
 
@@ -276,6 +277,7 @@ def embedding(*args, **kwargs):
     # Explicitly reshape back to the original shape. This places the ReshapeOp
     # outside of the HLFB.
     output = torch.reshape(output, (*(original_idx_shape), embedding_dim))
+    output, _ = optimization_barrier_lib.optimization_barrier(output, idx)
     return output
 
   node.target = embedding

ai_edge_torch/_convert/test/test_convert.py

Lines changed: 18 additions & 0 deletions
@@ -576,6 +576,24 @@ def forward(self, x: torch.Tensor):
       self.fail(f"Conversion failed with bloat16 inputs: {err}")
     # pylint: enable=broad-except
 
+  def test_convert_model_with_torch_linspace_operation(self):
+    """Test converting a simple model with torch.linspace operation."""
+
+    class SampleModel(nn.Module):
+
+      def forward(self, x: torch.Tensor):
+        return torch.linspace(0.5, 10.5, steps=x.shape[0], dtype=torch.float64)
+
+    model = SampleModel().eval()
+    args = (torch.randint(0, 100, (10, 10), dtype=torch.int64),)
+
+    try:
+      # Expect this to fix the error during conversion
+      ai_edge_torch.convert(model, args)
+    except Exception as err:
+      self.fail(f"Conversion failed with int64 inputs: {err}")
+    # pylint: enable=broad-except
+
   def test_compile_model(self):
     """Tests AOT compilation of a simple Add module."""
 

ai_edge_torch/generative/examples/gemma/gemma1.py

Lines changed: 43 additions & 8 deletions
@@ -23,7 +23,7 @@
 import torch
 from torch import nn
 
-TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+TENSOR_NAMES_FUSED_QKV = loading_utils.ModelLoader.TensorNames(
     ff_up_proj="model.layers.{}.mlp.up_proj",
     ff_down_proj="model.layers.{}.mlp.down_proj",
     ff_gate_proj="model.layers.{}.mlp.gate_proj",

@@ -36,6 +36,24 @@
     lm_head=None,
 )
 
+TENSOR_NAMES_SEP_QKV = loading_utils.ModelLoader.TensorNames(
+    ff_up_proj="model.layers.{}.mlp.up_proj",
+    ff_down_proj="model.layers.{}.mlp.down_proj",
+    ff_gate_proj="model.layers.{}.mlp.gate_proj",
+    attn_query_proj="model.layers.{}.self_attn.q_proj",
+    attn_key_proj="model.layers.{}.self_attn.k_proj",
+    attn_value_proj="model.layers.{}.self_attn.v_proj",
+    attn_output_proj="model.layers.{}.self_attn.o_proj",
+    pre_attn_norm="model.layers.{}.input_layernorm",
+    post_attn_norm="model.layers.{}.post_attention_layernorm",
+    embedding="model.embed_tokens",
+    final_norm="model.norm",
+)
+
+TENSOR_NAMES_DICT = {
+    "safetensors": TENSOR_NAMES_SEP_QKV,
+    "kaggle": TENSOR_NAMES_FUSED_QKV,
+}
 
 class Gemma1(model_builder.DecoderOnlyModel):
   """A Gemma1 model built from the Edge Generative API layers."""

@@ -94,11 +112,28 @@ def build_2b_model(
     custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
     mask_cache_size: int = 0,
 ) -> nn.Module:
-  return model_builder.build_decoder_only_model(
-      checkpoint_path=checkpoint_path,
-      config=get_model_config_2b(),
-      tensor_names=TENSOR_NAMES,
-      model_class=Gemma1,
-      custom_loader=custom_loader,
-      mask_cache_size=mask_cache_size,
+
+  # A list to store the reasons for each failure
+  key_errors = []
+
+  for tensor_names in TENSOR_NAMES_DICT.values():
+    try:
+      return model_builder.build_decoder_only_model(
+          checkpoint_path=checkpoint_path,
+          config=get_model_config_2b(),
+          tensor_names=tensor_names,
+          model_class=Gemma1,
+          custom_loader=custom_loader,
+          mask_cache_size=mask_cache_size,
+      )
+    except KeyError as ke:
+      # Store the specific key that was missing for later
+      key_errors.append(f"Missing key: {ke}")
+      continue
+
+  # If the loop finishes, raise an error with all the collected details
+  error_details = "\n".join(key_errors)
+  raise RuntimeError(
+      "Failed to build model after trying all configurations. "
+      f"Encountered the following errors:\n{error_details}"
   )
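With this change, build_2b_model no longer assumes a single checkpoint layout: it tries each entry in TENSOR_NAMES_DICT and only gives up when every layout raises a KeyError, surfacing all the missing keys at once. A minimal standalone sketch of that fallback pattern (illustrative helper names, not this repo's API):

def build_with_fallback(builders):
  """Tries each builder in turn; reports every missing key if all fail."""
  key_errors = []
  for build in builders:
    try:
      return build()
    except KeyError as ke:
      key_errors.append(f"Missing key: {ke}")
  raise RuntimeError(
      "Failed to build model after trying all configurations:\n"
      + "\n".join(key_errors)
  )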
Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Verifies the reauthored SmolVLM2 Image Encoder model."""
+
+import logging
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.smolvlm2 import smolvlm2
+from ai_edge_torch.generative.examples.smolvlm2 import vision_encoder
+from PIL import Image
+import requests
+import torch
+import transformers
+
+_IMAGE_URL = flags.DEFINE_string(
+    "image_url",
+    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true",
+    "The image URI to encode.",
+)
+
+_CHECKPOINT = flags.DEFINE_string(
+    "checkpoint",
+    "HuggingFaceTB/SmolVLM2-2.2B-Instruct",
+    "The checkpoint to verify.",
+)
+
+_REAUTHORTHED_CHECKPOINT = flags.DEFINE_string(
+    "pretrained_weights",
+    None,
+    "The path to the model's pretrained weights.",
+)
+
+
+def main(_):
+  checkpoint = _CHECKPOINT.value
+  logging.info("Loading the original model from: %s", checkpoint)
+  original_model = transformers.AutoModelForImageTextToText.from_pretrained(
+      checkpoint
+  )
+  original_model = original_model.eval().model
+
+  logging.info("Building the reauthored checkpoint from: %s", checkpoint)
+  reauthored_checkpoint = _REAUTHORTHED_CHECKPOINT.value
+  if reauthored_checkpoint is None:
+    raise ValueError("reauthored_checkpoint is required.")
+
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = vision_encoder.build_image_encoder(reauthored_checkpoint)
+
+  logging.info("Loading the tokenizer from: %s", checkpoint)
+  processor = transformers.AutoProcessor.from_pretrained(checkpoint)
+
+  logging.info("Loading the image from: %s", _IMAGE_URL.value)
+  image = Image.open(requests.get(_IMAGE_URL.value, stream=True).raw)
+  pixel_values = processor(images=image, return_tensors="pt")["pixel_values"]
+
+  logging.info("Forwarding the original model...")
+  outputs_original = original_model.get_image_features(pixel_values)
+  logging.info("outputs_original's shape: %s", outputs_original.shape)
+
+  pixel_values = pixel_values.reshape(
+      pixel_values.shape[0] * pixel_values.shape[1], *pixel_values.shape[2:]
+  )
+  logging.info("Forwarding the reauthored model...")
+  outputs_reauthored = reauthored_model.forward(
+      pixel_values=pixel_values
+  )
+  logging.info("outputs_reauthored's shape: %s", outputs_reauthored.shape)
+
+  try:
+    assert torch.allclose(
+        outputs_original, outputs_reauthored, atol=1e-03, rtol=1e-04
+    )
+  except AssertionError as e:
+    logging.error("*** FAILED *** verify with an image")
+    raise e
+  else:
+    logging.info("*** PASSED *** verify with an image")
+
+
+if __name__ == "__main__":
+  app.run(main)

ai_edge_torch/generative/examples/smolvlm2/vision_encoder.py

Lines changed: 16 additions & 8 deletions
@@ -20,7 +20,7 @@
 """
 
 from dataclasses import dataclass
-from typing import Callable, Dict
+from typing import Callable, Dict, Optional
 
 from ai_edge_torch.generative.examples.paligemma import image_encoder
 import ai_edge_torch.generative.layers.model_config as cfg

@@ -127,9 +127,20 @@ def __init__(
   def forward(
       self,
       pixel_values: torch.Tensor,
-      export_config: export_cfg.ExportConfig = None,
+      export_config: Optional[export_cfg.ExportConfig] = None,
   ) -> torch.Tensor:
-    x = self.siglip_encoder(pixel_values)
+    # Embed the image according to SiplipVisionEmbeddings.
+    x = self.siglip_encoder.tok_embedding(pixel_values)
+    x = x.flatten(2).transpose(1, 2)
+    x = x + self.siglip_encoder.tok_embedding_position
+
+    # Pass a dummy mask because SDPA attention impl expects non-None mask.
+    mask = torch.zeros(x.shape[0], 1, x.shape[1], x.shape[1])
+    for _, block in enumerate(self.siglip_encoder.transformer_blocks):
+      x = block(x, mask=mask)
+    x = self.siglip_encoder.final_norm(x)
+
+    # Project the image embeddings to text hidden size.
     x = self.connector(x)
     return x
 
@@ -166,7 +177,7 @@ def get_image_encoder_config() -> cfg.ModelConfig:
       output_proj_use_bias=True,
   )
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.LAYER_NORM, epsilon=1e-6
+      type=cfg.NormalizationType.LAYER_NORM, epsilon=1e-6,
   )
   ff_config = cfg.FeedForwardConfig(
       type=cfg.FeedForwardType.SEQUENTIAL,

@@ -189,15 +200,13 @@ def get_image_encoder_config() -> cfg.ModelConfig:
       image_embedding=image_embedding_config,
       block_configs=block_config,
       final_norm_config=norm_config,
-      # num_mm_tokens_per_image=81,
-      # enable_hlfb=False
   )
   return config
 
 
 def build_image_encoder(
     checkpoint_path: str,
-    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    custom_loader: Optional[Callable[[str], Dict[str, torch.Tensor]]] = None,
 ) -> FullVisionEncoder:
   """Builds a FullVisionEncoder from the checkpoint path."""
   encoder_config = get_image_encoder_config()

@@ -208,7 +217,6 @@ def build_image_encoder(
   )
   loader.load(encoder.siglip_encoder, strict=False)
 
-  loader = loading_utils.ModelLoader(checkpoint_path, None, custom_loader)
   state = loader.get_state()
   converted_state = dict()
   converted_state["modality_projection.weight"] = state.pop(
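The rewritten forward unrolls the SigLIP encoder and threads an all-zeros mask through each transformer block because the SDPA attention path expects a non-None mask. Since a float mask is added to the attention scores, an all-zeros mask leaves the output unchanged; a minimal standalone check of that property (shapes are illustrative, not this model's):

import torch
import torch.nn.functional as F

# Query/key/value with illustrative shapes (batch, heads, seq_len, head_dim).
q = k = v = torch.randn(1, 2, 4, 8)
zero_mask = torch.zeros(1, 1, 4, 4)

out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=zero_mask)
out_plain = F.scaled_dot_product_attention(q, k, v)
assert torch.allclose(out_masked, out_plain)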

ai_edge_torch/odml_torch/experimental/torch_tfl/_decomps.py

Lines changed: 61 additions & 0 deletions
@@ -180,6 +180,11 @@ def _aten_rsqrt_decomp(x):
   return torch.ops.tfl.rsqrt(x)
 
 
+@register_decomp(torch.ops.aten.neg.default)
+def _aten_neg_decomp(x):
+  return torch.ops.tfl.neg(x)
+
+
 @register_decomp(torch.ops.aten.gelu.default)
 def _aten_gelu_decomp(x, approximate="none"):
   return torch.ops.tfl.gelu(x, approximate != "none")

@@ -317,6 +322,38 @@ def _aten_select_int_decomp(x, dim, index):
   return torch.ops.tfl.squeeze(sliced, [dim])
 
 
+@register_decomp(torch.ops.aten.slice.Tensor)
+def _aten_slice_tensor_decomp(x, dim=0, start=None, end=None, step=1):
+  rank = x.dim()
+  dim_size = x.shape[dim]
+
+  # Initialize begin, end, strides for tfl.strided_slice
+  begin = [0] * rank
+  end_vec = list(x.shape)
+  strides = [1] * rank
+
+  # The logic below is to match PyTorch's `slice` behavior.
+  # `start` and `end` can be negative, which means they count from the end.
+  # `start=None` defaults to 0.
+  # `end=None` or a large number defaults to `dim_size` after clamping.
+
+  start_val = 0 if start is None else start
+  if start_val < 0:
+    start_val += dim_size
+
+  end_val = dim_size if end is None else end
+  if end_val < 0:
+    end_val += dim_size
+
+  # Clamp start and end to be within the dimension size, following PyTorch's
+  # logic.
+  start_val = max(0, min(start_val, dim_size))
+  end_val = max(start_val, min(end_val, dim_size))
+
+  begin[dim], end_vec[dim], strides[dim] = start_val, end_val, step
+  return torch.ops.tfl.strided_slice(x, begin, end_vec, strides)
+
+
 @register_decomp(torch.ops.aten.where.self)
 def _aten_where_self_decomp(condition, x, y):
   x, y = _promote_types_for_binary_op(x, y)

@@ -351,3 +388,27 @@ def _aten__softmax_decomp(
   softmax_result = torch.ops.tfl.softmax(x_permuted)
   # Transpose the result back to the original dimensions.
   return torch.ops.tfl.transpose(softmax_result, dims)
+
+
+@register_decomp(torch.ops.aten.topk.default)
+def _aten_topk_decomp(self, k, dim=-1, largest=True, sorted=True):
+  if not largest:
+    raise ValueError("Only largest=True is supported for torch.topk.")
+
+  if dim < 0:
+    dim = self.dim() + dim
+
+  if dim != self.dim() - 1:
+    self = torch.transpose(self, dim, -1)
+
+  # Ignores sorted value: tfl.topk_v2 only supports sorted=True, but it doesn't
+  # affect the correctness of the output.
+  out, indices = torch.ops.tfl.topk_v2(self, k)
+
+  if dim != self.dim() - 1:
+    out = torch.transpose(out, dim, -1)
+    indices = torch.transpose(indices, dim, -1)
+
+  # torch.topk returns int64 indices, but tfl.topk_v2 returns indices in int32.
+  indices = indices.to(torch.int64)
+  return out, indices
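The slice decomposition's index arithmetic mirrors how PyTorch itself resolves negative and out-of-range start/end values before handing the result to tfl.strided_slice. A small standalone check of that clamping logic (values are illustrative):

import torch

# Verify that the begin/end handling above reproduces PyTorch slice semantics
# for a negative start and an end past the dimension size.
x = torch.arange(10)
dim_size = x.shape[0]
start, end, step = -3, 100, 1

start_val = start + dim_size if start < 0 else start
end_val = end + dim_size if end < 0 else end
start_val = max(0, min(start_val, dim_size))
end_val = max(start_val, min(end_val, dim_size))

assert torch.equal(x[start:end:step], x[start_val:end_val:step])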

ai_edge_torch/odml_torch/experimental/torch_tfl/_lowerings.py

Lines changed: 29 additions & 0 deletions
@@ -286,6 +286,18 @@ def _tfl_rsqrt_lowering(
   )
 
 
+@lower(torch.ops.tfl.neg.default)
+def _tfl_neg_lowering(
+    lctx: LoweringContext,
+    x: ir.Value,
+) -> ir.Value:
+  return _ir_operation(
+      "tfl.neg",
+      results=lowering_utils.node_meta_to_ir_types(lctx.node),
+      operands=[x],
+  )
+
+
 @lower(torch.ops.tfl.gelu.default)
 def _tfl_gelu_lowering(
     lctx: LoweringContext,

@@ -674,3 +686,20 @@ def _tfl_softmax_lowering(
           "beta": ir.FloatAttr.get(ir.F32Type.get(), beta),
       },
   )
+
+
+@lower(torch.ops.tfl.topk_v2.default)
+def _tfl_topk_v2_lowering(
+    lctx: LoweringContext,
+    x: ir.Value,
+    k: int,
+) -> tuple[ir.Value, ir.Value]:
+  return _ir_operation(
+      "tfl.topk_v2",
+      results=lowering_utils.node_meta_to_ir_types(lctx.node),
+      operands=[
+          x,
+          lowering_utils.numpy_array_constant(np.array(k, dtype=np.int32)),
+      ],
+      attributes={},
+  )
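The aten.topk decomposition added above feeds this tfl.topk_v2 lowering by first moving the target dimension to the last axis, since TFLite's TopKV2 operates on the last dimension only, and then transposing the results back. A small standalone check of that transpose round-trip in plain PyTorch (shapes are illustrative):

import torch

x = torch.randn(3, 5, 4)
dim, k = 1, 2
vals_ref, idx_ref = torch.topk(x, k, dim=dim)

# Move `dim` to the last axis, take topk there, then move the results back.
xt = torch.transpose(x, dim, -1)
vals, idx = torch.topk(xt, k)
vals = torch.transpose(vals, dim, -1)
idx = torch.transpose(idx, dim, -1)

assert torch.equal(vals, vals_ref) and torch.equal(idx, idx_ref)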
