diff --git a/keras_hub/api/__init__.py b/keras_hub/api/__init__.py index 2aa98bf3f9..810f8fa921 100644 --- a/keras_hub/api/__init__.py +++ b/keras_hub/api/__init__.py @@ -4,6 +4,7 @@ since your modifications would be overwritten. """ +from keras_hub import export as export from keras_hub import layers as layers from keras_hub import metrics as metrics from keras_hub import models as models diff --git a/keras_hub/api/export/__init__.py b/keras_hub/api/export/__init__.py new file mode 100644 index 0000000000..32a373f4b5 --- /dev/null +++ b/keras_hub/api/export/__init__.py @@ -0,0 +1,36 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras_hub.src.export.configs import ( + AudioToTextExporterConfig as AudioToTextExporterConfig, +) +from keras_hub.src.export.configs import ( + CausalLMExporterConfig as CausalLMExporterConfig, +) +from keras_hub.src.export.configs import ( + DepthEstimatorExporterConfig as DepthEstimatorExporterConfig, +) +from keras_hub.src.export.configs import ( + ImageClassifierExporterConfig as ImageClassifierExporterConfig, +) +from keras_hub.src.export.configs import ( + ImageSegmenterExporterConfig as ImageSegmenterExporterConfig, +) +from keras_hub.src.export.configs import ( + ObjectDetectorExporterConfig as ObjectDetectorExporterConfig, +) +from keras_hub.src.export.configs import ( + SAMImageSegmenterExporterConfig as SAMImageSegmenterExporterConfig, +) +from keras_hub.src.export.configs import ( + Seq2SeqLMExporterConfig as Seq2SeqLMExporterConfig, +) +from keras_hub.src.export.configs import ( + TextClassifierExporterConfig as TextClassifierExporterConfig, +) +from keras_hub.src.export.configs import ( + TextToImageExporterConfig as TextToImageExporterConfig, +) diff --git a/keras_hub/src/export/__init__.py b/keras_hub/src/export/__init__.py new file mode 100644 index 0000000000..06e51a2411 --- /dev/null +++ b/keras_hub/src/export/__init__.py @@ -0,0 +1,12 @@ +# Export configurations and convenience functions +from keras_hub.src.export.configs import AudioToTextExporterConfig +from keras_hub.src.export.configs import CausalLMExporterConfig +from keras_hub.src.export.configs import ImageClassifierExporterConfig +from keras_hub.src.export.configs import ImageSegmenterExporterConfig +from keras_hub.src.export.configs import KerasHubExporterConfig +from keras_hub.src.export.configs import ObjectDetectorExporterConfig +from keras_hub.src.export.configs import Seq2SeqLMExporterConfig +from keras_hub.src.export.configs import TextClassifierExporterConfig +from keras_hub.src.export.configs import TextToImageExporterConfig +from keras_hub.src.export.configs import get_exporter_config +from keras_hub.src.export.litert import export_litert diff --git a/keras_hub/src/export/configs.py b/keras_hub/src/export/configs.py new file mode 100644 index 0000000000..d168ecf0ab --- /dev/null +++ b/keras_hub/src/export/configs.py @@ -0,0 +1,996 @@ +"""Configuration classes for different Keras-Hub model types. + +This module provides specific configurations for exporting different types +of Keras-Hub models. Each configuration knows how to generate the appropriate +input signature for its model type, which is then used by Keras Core's export. +""" + +from abc import ABC +from abc import abstractmethod + +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.audio_to_text import AudioToText +from keras_hub.src.models.causal_lm import CausalLM +from keras_hub.src.models.depth_estimator import DepthEstimator +from keras_hub.src.models.image_classifier import ImageClassifier +from keras_hub.src.models.image_segmenter import ImageSegmenter +from keras_hub.src.models.object_detector import ObjectDetector +from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM +from keras_hub.src.models.text_classifier import TextClassifier +from keras_hub.src.models.text_to_image import TextToImage + + +class KerasHubExporterConfig(ABC): + """Base configuration class for Keras-Hub model exporters. + + This class defines the interface for exporter configurations that specify + how different types of Keras-Hub models should be exported. Each subclass + provides domain-specific knowledge about input signatures for its model + type. + """ + + # Model type this exporter handles (e.g., "causal_lm", "text_classifier") + MODEL_TYPE = None + + # Expected input structure for this model type + EXPECTED_INPUTS = [] + + def __init__(self, model, **kwargs): + """Initialize the exporter configuration. + + Args: + model: `keras.Model`. The Keras-Hub model to export. + **kwargs: Additional configuration parameters. + """ + self.model = model + self.config_kwargs = kwargs + self._validate_model() + + def _validate_model(self): + """Validate that the model is compatible with this exporter.""" + if not self._is_model_compatible(): + raise ValueError( + f"Model {self.model.__class__.__name__} is not compatible " + f"with {self.__class__.__name__}" + ) + + @abstractmethod + def _is_model_compatible(self): + """Check if the model is compatible with this exporter. + + Returns: + `bool`. True if compatible, False otherwise + """ + pass + + @abstractmethod + def get_input_signature(self, sequence_length=None): + """Get the input signature for this model type. + + Args: + sequence_length: `int` or `None`. Optional sequence length for + input tensors. + + Returns: + `dict`. Dictionary mapping input names to tensor specifications. + """ + pass + + +def _get_text_input_signature(model, sequence_length=None): + """Get input signature for text models with token_ids and padding_mask. + + Args: + model: The model instance. + sequence_length: `int` or `None`. Sequence length. If None, uses + dynamic shape to support variable-length inputs via + resize_tensor_input at runtime. + + Returns: + `dict`. Dictionary mapping input names to their specifications + """ + return { + "token_ids": keras.layers.InputSpec( + dtype="int32", shape=(None, sequence_length) + ), + "padding_mask": keras.layers.InputSpec( + dtype="int32", shape=(None, sequence_length) + ), + } + + +def _get_seq2seq_input_signature(model, sequence_length=None): + """Get input signature for seq2seq models with encoder/decoder tokens. + + Args: + model: The model instance. + sequence_length: `int` or `None`. Sequence length. If None, uses + dynamic shape to support variable-length inputs via + resize_tensor_input at runtime. + + Returns: + `dict`. Dictionary mapping input names to their specifications + """ + return { + "encoder_token_ids": keras.layers.InputSpec( + dtype="int32", shape=(None, sequence_length) + ), + "encoder_padding_mask": keras.layers.InputSpec( + dtype="int32", shape=(None, sequence_length) + ), + "decoder_token_ids": keras.layers.InputSpec( + dtype="int32", shape=(None, sequence_length) + ), + "decoder_padding_mask": keras.layers.InputSpec( + dtype="int32", shape=(None, sequence_length) + ), + } + + +def _infer_image_size(model): + """Infer image size from model preprocessor or inputs. + + Args: + model: The model instance. + + Returns: + `tuple`. Image size as (height, width). + + Raises: + ValueError: If image_size cannot be determined. + """ + image_size = None + + # Get from preprocessor + if hasattr(model, "preprocessor") and model.preprocessor: + if hasattr(model.preprocessor, "image_size"): + image_size = model.preprocessor.image_size + + # Try to infer from model inputs + if image_size is None and hasattr(model, "inputs") and model.inputs: + input_shape = model.inputs[0].shape + if ( + len(input_shape) == 4 + and input_shape[1] is not None + and input_shape[2] is not None + ): + image_size = (input_shape[1], input_shape[2]) + + if image_size is None: + raise ValueError( + "Could not determine image size from model. " + "Model should have a preprocessor with image_size " + "attribute, or model inputs should have concrete shapes." + ) + + if isinstance(image_size, int): + image_size = (image_size, image_size) + + return image_size + + +def _infer_image_dtype(model): + """Infer image dtype from model inputs. + + Args: + model: The model instance. + + Returns: + `str`. Image dtype (defaults to "float32"). + """ + if hasattr(model, "inputs") and model.inputs: + model_dtype = model.inputs[0].dtype + return model_dtype.name if hasattr(model_dtype, "name") else model_dtype + return "float32" + + +@keras_hub_export("keras_hub.export.CausalLMExporterConfig") +class CausalLMExporterConfig(KerasHubExporterConfig): + """Exporter configuration for Causal Language Models (GPT, LLaMA, etc.).""" + + MODEL_TYPE = "causal_lm" + + def __init__(self, model): + super().__init__(model) + # Determine expected inputs based on whether model is multimodal + # Check for Gemma3-style vision encoder + if ( + hasattr(model, "backbone") + and hasattr(model.backbone, "vision_encoder") + and model.backbone.vision_encoder is not None + ): + self.EXPECTED_INPUTS = [ + "token_ids", + "padding_mask", + "images", + "vision_mask", + "vision_indices", + ] + # Check for PaliGemma-style multimodal (has image_encoder or + # vit attributes) + elif self._is_paligemma_style_multimodal(model): + self.EXPECTED_INPUTS = [ + "token_ids", + "padding_mask", + "images", + "response_mask", + ] + # Check for Parseq-style vision (has image_encoder in backbone) + elif self._is_parseq_style_vision(model): + self.EXPECTED_INPUTS = ["token_ids", "padding_mask", "images"] + else: + self.EXPECTED_INPUTS = ["token_ids", "padding_mask"] + + def _is_paligemma_style_multimodal(self, model): + """Check if model is PaliGemma-style multimodal (vision + language).""" + if hasattr(model, "backbone"): + backbone = model.backbone + # PaliGemma has vit parameters or image-related attributes + if hasattr(backbone, "image_size") and ( + hasattr(backbone, "vit_num_layers") + or hasattr(backbone, "vit_patch_size") + ): + return True + return False + + def _is_parseq_style_vision(self, model): + """Check if model is Parseq-style vision model (OCR causal LM).""" + if hasattr(model, "backbone"): + backbone = model.backbone + # Parseq has an image_encoder attribute + if hasattr(backbone, "image_encoder"): + return True + return False + + def _is_model_compatible(self): + """Check if model is a causal language model. + + Returns: + `bool`. True if compatible, False otherwise + """ + return isinstance(self.model, CausalLM) + + def get_input_signature(self, sequence_length=None): + """Get input signature for causal LM models. + + Args: + sequence_length: `int`, `None`, or `dict`. Optional sequence length. + If None, uses preprocessor's sequence_length if available, + otherwise exports with dynamic shape for flexibility. If dict, + should contain 'sequence_length' and 'image_size' for + multimodal models. + + Returns: + `dict`. Dictionary mapping input names to their specifications + """ + # If no sequence_length provided, try to get it from preprocessor + if ( + sequence_length is None + and hasattr(self.model, "preprocessor") + and self.model.preprocessor is not None + ): + if hasattr(self.model.preprocessor, "sequence_length"): + sequence_length = self.model.preprocessor.sequence_length + elif hasattr(self.model.preprocessor, "max_sequence_length"): + sequence_length = self.model.preprocessor.max_sequence_length + + # Use dynamic shape (None) by default for TFLite flexibility + # Users can resize at runtime via interpreter.resize_tensor_input() + + # Handle dict param for multimodal models + if isinstance(sequence_length, dict): + seq_len = sequence_length.get("sequence_length", None) + else: + seq_len = sequence_length + + signature = _get_text_input_signature(self.model, seq_len) + + # Check if Gemma3-style multimodal (vision encoder) + if ( + hasattr(self.model.backbone, "vision_encoder") + and self.model.backbone.vision_encoder is not None + ): + # Add Gemma3 vision inputs + if isinstance(sequence_length, dict): + image_size = sequence_length.get("image_size", None) + if image_size is not None and isinstance(image_size, tuple): + image_size = image_size[0] # Use first dimension if tuple + else: + image_size = getattr(self.model.backbone, "image_size", 224) + + if image_size is None: + image_size = getattr(self.model.backbone, "image_size", 224) + + signature.update( + { + "images": keras.layers.InputSpec( + dtype="float32", + shape=(None, None, image_size, image_size, 3), + ), + "vision_mask": keras.layers.InputSpec( + dtype="int32", # Use int32 instead of bool for + # TFLite compatibility + shape=(None, None), + ), + "vision_indices": keras.layers.InputSpec( + dtype="int32", shape=(None, None) + ), + } + ) + # Check if PaliGemma-style multimodal + elif self._is_paligemma_style_multimodal(self.model): + # Get image size from backbone + image_size = getattr(self.model.backbone, "image_size", 224) + if isinstance(sequence_length, dict): + image_size = sequence_length.get("image_size", image_size) + + # Handle tuple image_size (height, width) + if isinstance(image_size, tuple): + image_height, image_width = image_size[0], image_size[1] + else: + image_height, image_width = image_size, image_size + + signature.update( + { + "images": keras.layers.InputSpec( + dtype="float32", + shape=(None, image_height, image_width, 3), + ), + "response_mask": keras.layers.InputSpec( + dtype="int32", shape=(None, seq_len) + ), + } + ) + # Check if Parseq-style vision + elif self._is_parseq_style_vision(self.model): + # Get image size from backbone's image_encoder + if hasattr(self.model.backbone, "image_encoder") and hasattr( + self.model.backbone.image_encoder, "image_shape" + ): + image_shape = self.model.backbone.image_encoder.image_shape + image_height, image_width = image_shape[0], image_shape[1] + else: + image_height, image_width = 32, 128 # Default for Parseq + + if isinstance(sequence_length, dict): + image_height = sequence_length.get("image_height", image_height) + image_width = sequence_length.get("image_width", image_width) + + signature.update( + { + "images": keras.layers.InputSpec( + dtype="float32", + shape=(None, image_height, image_width, 3), + ), + } + ) + + return signature + + +@keras_hub_export("keras_hub.export.TextClassifierExporterConfig") +class TextClassifierExporterConfig(KerasHubExporterConfig): + """Exporter configuration for Text Classification models.""" + + MODEL_TYPE = "text_classifier" + + def __init__(self, model): + super().__init__(model) + # Determine expected inputs based on model characteristics + inputs = ["token_ids"] + + if self._model_uses_padding_mask(): + inputs.append("padding_mask") + + if self._model_uses_segment_ids(): + inputs.append("segment_ids") + + self.EXPECTED_INPUTS = inputs + + def _model_uses_segment_ids(self): + """Check if the model expects segment_ids input. + + Returns: + bool: True if model uses segment_ids, False otherwise + """ + # Check if model has a backbone with num_segments attribute + if hasattr(self.model, "backbone"): + backbone = self.model.backbone + # RoformerV2 and similar models have num_segments + if hasattr(backbone, "num_segments"): + return True + return False + + def _model_uses_padding_mask(self): + """Check if the model expects padding_mask input. + + Returns: + bool: True if model uses padding_mask, False otherwise + """ + # RoformerV2 doesn't use padding_mask in its preprocessor + # Check the model's backbone type + if hasattr(self.model, "backbone"): + backbone_class_name = self.model.backbone.__class__.__name__ + # RoformerV2 doesn't use padding_mask + if "RoformerV2" in backbone_class_name: + return False + # ESM computes attention mask internally from token_ids + if "ESM" in backbone_class_name: + return False + return True + + def _is_model_compatible(self): + """Check if model is a text classifier. + + Returns: + `bool`. True if compatible, False otherwise + """ + return isinstance(self.model, TextClassifier) + + def get_input_signature(self, sequence_length=None): + """Get input signature for text classifier models. + + Args: + sequence_length: `int` or `None`. Optional sequence length. If None, + uses preprocessor's sequence_length if available, otherwise + exports with dynamic shape for flexibility. + + Returns: + `dict`. Dictionary mapping input names to their specifications + """ + # If no sequence_length provided, try to get it from preprocessor + if ( + sequence_length is None + and hasattr(self.model, "preprocessor") + and self.model.preprocessor is not None + ): + if hasattr(self.model.preprocessor, "sequence_length"): + sequence_length = self.model.preprocessor.sequence_length + elif hasattr(self.model.preprocessor, "max_sequence_length"): + sequence_length = self.model.preprocessor.max_sequence_length + + # Use dynamic shape (None) by default for TFLite flexibility + # Users can resize at runtime via interpreter.resize_tensor_input() + signature = { + "token_ids": keras.layers.InputSpec( + dtype="int32", shape=(None, sequence_length) + ) + } + + # Add padding_mask if needed + if self._model_uses_padding_mask(): + signature["padding_mask"] = keras.layers.InputSpec( + dtype="int32", shape=(None, sequence_length) + ) + + # Add segment_ids if needed + if self._model_uses_segment_ids(): + signature["segment_ids"] = keras.layers.InputSpec( + dtype="int32", shape=(None, sequence_length) + ) + + return signature + + +@keras_hub_export("keras_hub.export.Seq2SeqLMExporterConfig") +class Seq2SeqLMExporterConfig(KerasHubExporterConfig): + """Exporter configuration for Sequence-to-Sequence Language Models.""" + + MODEL_TYPE = "seq2seq_lm" + EXPECTED_INPUTS = [ + "encoder_token_ids", + "encoder_padding_mask", + "decoder_token_ids", + "decoder_padding_mask", + ] + + def _is_model_compatible(self): + """Check if model is a seq2seq language model. + + Returns: + `bool`. True if compatible, False otherwise + """ + return isinstance(self.model, Seq2SeqLM) + + def get_input_signature(self, sequence_length=None): + """Get input signature for seq2seq models. + + Args: + sequence_length: `int` or `None`. Optional sequence length. If None, + uses preprocessor's sequence_length if available, otherwise + exports with dynamic shape for flexibility. + + Returns: + `dict`. Dictionary mapping input names to their specifications + """ + # If no sequence_length provided, try to get it from preprocessor + if ( + sequence_length is None + and hasattr(self.model, "preprocessor") + and self.model.preprocessor is not None + ): + if hasattr(self.model.preprocessor, "sequence_length"): + sequence_length = self.model.preprocessor.sequence_length + elif hasattr(self.model.preprocessor, "max_sequence_length"): + sequence_length = self.model.preprocessor.max_sequence_length + + # Use dynamic shape (None) by default for TFLite flexibility + # Users can resize at runtime via interpreter.resize_tensor_input() + return _get_seq2seq_input_signature(self.model, sequence_length) + + +@keras_hub_export("keras_hub.export.AudioToTextExporterConfig") +class AudioToTextExporterConfig(KerasHubExporterConfig): + """Exporter configuration for Audio-to-Text models. + + AudioToText models process audio input and generate text output, + such as speech recognition or audio transcription models. + """ + + MODEL_TYPE = "audio_to_text" + EXPECTED_INPUTS = [ + "encoder_input_values", # Audio features + "encoder_padding_mask", + "decoder_token_ids", + "decoder_padding_mask", + ] + + def _is_model_compatible(self): + """Check if model is an audio-to-text model. + + Returns: + `bool`. True if compatible, False otherwise + """ + return isinstance(self.model, AudioToText) + + def get_input_signature(self, sequence_length=None, audio_length=None): + """Get input signature for audio-to-text models. + + Args: + sequence_length: `int` or `None`. Optional text sequence length. + If None, exports with dynamic shape for flexibility. + audio_length: `int` or `None`. Optional audio sequence length. + If None, exports with dynamic shape for flexibility. + + Returns: + `dict`. Dictionary mapping input names to their specifications + """ + # Audio features come from the audio encoder + # Text tokens go to the decoder + return { + "encoder_input_values": keras.layers.InputSpec( + dtype="float32", shape=(None, audio_length) + ), + "encoder_padding_mask": keras.layers.InputSpec( + dtype="int32", shape=(None, audio_length) + ), + "decoder_token_ids": keras.layers.InputSpec( + dtype="int32", shape=(None, sequence_length) + ), + "decoder_padding_mask": keras.layers.InputSpec( + dtype="int32", shape=(None, sequence_length) + ), + } + + +@keras_hub_export("keras_hub.export.ImageClassifierExporterConfig") +class ImageClassifierExporterConfig(KerasHubExporterConfig): + """Exporter configuration for Image Classification models.""" + + MODEL_TYPE = "image_classifier" + EXPECTED_INPUTS = ["images"] + + def _is_model_compatible(self): + """Check if model is an image classifier. + Returns: + `bool`. True if compatible, False otherwise + """ + return isinstance(self.model, ImageClassifier) + + def get_input_signature(self, image_size=None): + """Get input signature for image classifier models. + Args: + image_size: `int`, `tuple` or `None`. Optional image size. + Returns: + Single `InputSpec` for the images input (not a dict, since + ImageClassifier models expect a single tensor, not dict inputs). + """ + if image_size is None: + image_size = _infer_image_size(self.model) + elif isinstance(image_size, int): + image_size = (image_size, image_size) + + dtype = _infer_image_dtype(self.model) + + # Return single InputSpec (not dict) for single-input models + return keras.layers.InputSpec(dtype=dtype, shape=(None, *image_size, 3)) + + +@keras_hub_export("keras_hub.export.ObjectDetectorExporterConfig") +class ObjectDetectorExporterConfig(KerasHubExporterConfig): + """Exporter configuration for Object Detection models.""" + + MODEL_TYPE = "object_detector" + EXPECTED_INPUTS = ["images"] # ObjectDetector models only take images + + def _is_model_compatible(self): + """Check if model is an object detector. + Returns: + `bool`. True if compatible, False otherwise + """ + return isinstance(self.model, ObjectDetector) + + def get_input_signature(self, image_size=None): + """Get input signature for object detector models. + + Note: ObjectDetector models only take 'images' as input, + not 'image_shape'. The image_shape parameter is used to determine + the input dimensions. + + Args: + image_size: `int`, `tuple` or `None`. Optional image size. + Returns: + Single InputSpec for images (not a dict, as there's only one input) + """ + if image_size is None: + # Try to infer from preprocessor, but fall back to dynamic shapes + # for object detectors which support variable input sizes + try: + image_size = _infer_image_size(self.model) + except ValueError: + # If cannot infer, use dynamic shapes + image_size = None + elif isinstance(image_size, int): + image_size = (image_size, image_size) + + dtype = _infer_image_dtype(self.model) + + if image_size is not None: + # Use concrete shapes when image_size is available + return keras.layers.InputSpec( + dtype=dtype, shape=(None, *image_size, 3) + ) + else: + # Use dynamic shapes for variable input sizes + return keras.layers.InputSpec( + dtype=dtype, shape=(None, None, None, 3) + ) + + +@keras_hub_export("keras_hub.export.ImageSegmenterExporterConfig") +class ImageSegmenterExporterConfig(KerasHubExporterConfig): + """Exporter configuration for Image Segmentation models.""" + + MODEL_TYPE = "image_segmenter" + EXPECTED_INPUTS = [ + "inputs" + ] # ImageSegmenter models use 'inputs' not 'images' + + def _is_model_compatible(self): + """Check if model is an image segmenter. + Returns: + `bool`. True if compatible, False otherwise + """ + return isinstance(self.model, ImageSegmenter) + + def get_input_signature(self, image_size=None): + """Get input signature for image segmenter models. + + Note: ImageSegmenter models use 'inputs' as the input name, + not 'images'. + + Args: + image_size: `int`, `tuple` or `None`. Optional image size. + Returns: + Single InputSpec for inputs (not a dict, as there's only one input) + """ + if image_size is None: + image_size = _infer_image_size(self.model) + elif isinstance(image_size, int): + image_size = (image_size, image_size) + + dtype = _infer_image_dtype(self.model) + + return keras.layers.InputSpec( + dtype=dtype, shape=(None, *image_size, 3), name="inputs" + ) + + +@keras_hub_export("keras_hub.export.SAMImageSegmenterExporterConfig") +class SAMImageSegmenterExporterConfig(KerasHubExporterConfig): + """Exporter configuration for SAM (Segment Anything Model). + + SAM requires multiple prompt inputs (points, boxes, masks) in addition + to images. For TFLite/LiteRT export, we use fixed shapes to avoid issues + with 0-sized dimensions in the XNNPack delegate. + + Mobile SAM implementations typically use fixed shapes: + - 1 point prompt (padded with zeros if not used) + - 1 box prompt (padded with zeros if not used) + - 1 mask prompt (zero-filled means "no mask") + """ + + MODEL_TYPE = "image_segmenter" + EXPECTED_INPUTS = ["images", "points", "labels", "boxes", "masks"] + + def _is_model_compatible(self): + """Check if model is a SAM image segmenter. + Returns: + `bool`. True if compatible, False otherwise + """ + if not isinstance(self.model, ImageSegmenter): + return False + # Check if backbone is SAM - must have SAM in backbone class name + if hasattr(self.model, "backbone"): + backbone_class_name = self.model.backbone.__class__.__name__ + # Only SAM models should use this config + if "SAM" in backbone_class_name.upper(): + return True + return False + + def get_input_signature(self, image_size=None): + """Get input signature for SAM models. + Args: + image_size: `int`, `tuple` or `None`. Optional image size. + Returns: + `dict`. Dictionary mapping input names to their specifications + """ + if image_size is None: + image_size = _infer_image_size(self.model) + elif isinstance(image_size, int): + image_size = (image_size, image_size) + + dtype = _infer_image_dtype(self.model) + + # For SAM, mask inputs should be at 4 * image_embedding_size resolution + # image_embedding_size is typically image_size // 16 for patch_size=16 + image_embedding_size = (image_size[0] // 16, image_size[1] // 16) + mask_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) + + return { + "images": keras.layers.InputSpec( + dtype=dtype, shape=(None, *image_size, 3) + ), + "points": keras.layers.InputSpec( + dtype="float32", + shape=(None, 1, 2), # Fixed: 1 point + ), + "labels": keras.layers.InputSpec( + dtype="float32", + shape=(None, 1), # Fixed: 1 label + ), + "boxes": keras.layers.InputSpec( + dtype="float32", + shape=(None, 1, 2, 2), # Fixed: 1 box + ), + "masks": keras.layers.InputSpec( + dtype="float32", + shape=( + None, + 1, + *mask_size, + 1, + ), # Fixed: 1 mask at correct resolution + ), + } + + +@keras_hub_export("keras_hub.export.DepthEstimatorExporterConfig") +class DepthEstimatorExporterConfig(KerasHubExporterConfig): + """Exporter configuration for Depth Estimation models.""" + + MODEL_TYPE = "depth_estimator" + EXPECTED_INPUTS = ["images"] + + def _is_model_compatible(self): + """Check if model is a depth estimator. + Returns: + `bool`. True if compatible, False otherwise + """ + return isinstance(self.model, DepthEstimator) + + def get_input_signature(self, image_size=None): + """Get input signature for depth estimation models. + Args: + image_size: `int`, `tuple` or `None`. Optional image size. + Returns: + `dict`. Dictionary mapping input names to their specifications + """ + if image_size is None: + image_size = _infer_image_size(self.model) + elif isinstance(image_size, int): + image_size = (image_size, image_size) + + dtype = _infer_image_dtype(self.model) + + return { + "images": keras.layers.InputSpec( + dtype=dtype, shape=(None, *image_size, 3) + ), + } + + +@keras_hub_export("keras_hub.export.TextToImageExporterConfig") +class TextToImageExporterConfig(KerasHubExporterConfig): + """Exporter configuration for Text-to-Image models. + + TextToImage models generate images from text prompts, + such as Stable Diffusion, DALL-E, or similar generative models. + """ + + MODEL_TYPE = "text_to_image" + EXPECTED_INPUTS = [ + "images", + "latents", + "clip_l_token_ids", + "clip_l_negative_token_ids", + "clip_g_token_ids", + "clip_g_negative_token_ids", + "num_steps", + "guidance_scale", + ] + + def _is_model_compatible(self): + """Check if model is a text-to-image model. + + Returns: + `bool`. True if compatible, False otherwise + """ + return isinstance(self.model, TextToImage) + + def _is_stable_diffusion_3(self): + """Check if model is Stable Diffusion 3. + + Returns: + `bool`. True if model is SD3, False otherwise + """ + return "StableDiffusion3" in self.model.__class__.__name__ + + def get_input_signature( + self, sequence_length=None, image_size=None, latent_shape=None + ): + """Get input signature for text-to-image models. + + Args: + sequence_length: `int` or `None`. Optional text sequence length. + If None, exports with dynamic shape for flexibility. + image_size: `tuple`, `int` or `None`. Optional image size. If None, + infers from model. + latent_shape: `tuple` or `None`. Optional latent shape. If None, + infers from model. + + Returns: + `dict`. Dictionary mapping input names to their specifications + """ + # Check if this is Stable Diffusion 3 which has dual CLIP encoders + if self._is_stable_diffusion_3(): + # Get image size from backbone if available + if image_size is None: + if hasattr(self.model, "backbone") and hasattr( + self.model.backbone, "image_shape" + ): + image_shape_tuple = self.model.backbone.image_shape + image_size = (image_shape_tuple[0], image_shape_tuple[1]) + else: + # Try to infer from inputs + if hasattr(self.model, "input") and isinstance( + self.model.input, dict + ): + if "images" in self.model.input: + img_shape = self.model.input["images"].shape + if ( + img_shape[1] is not None + and img_shape[2] is not None + ): + image_size = (img_shape[1], img_shape[2]) + if image_size is None: + raise ValueError( + "Could not determine image size for " + "StableDiffusion3. " + "Please provide image_size parameter." + ) + elif isinstance(image_size, int): + image_size = (image_size, image_size) + + # Get latent shape from backbone if available + if latent_shape is None: + if hasattr(self.model, "backbone") and hasattr( + self.model.backbone, "latent_shape" + ): + latent_shape_tuple = self.model.backbone.latent_shape + # latent_shape is (None, h, w, c), we need (h, w, c) + if latent_shape_tuple[0] is None: + latent_shape = latent_shape_tuple[1:] + else: + latent_shape = latent_shape_tuple + else: + # Default latent shape for SD3 (typically 1/8 of image size + # with 16 channels) + latent_shape = (image_size[0] // 8, image_size[1] // 8, 16) + + return { + "images": keras.layers.InputSpec( + dtype="float32", shape=(None, *image_size, 3) + ), + "latents": keras.layers.InputSpec( + dtype="float32", shape=(None, *latent_shape) + ), + "clip_l_token_ids": keras.layers.InputSpec( + dtype="int32", shape=(None, sequence_length) + ), + "clip_l_negative_token_ids": keras.layers.InputSpec( + dtype="int32", shape=(None, sequence_length) + ), + "clip_g_token_ids": keras.layers.InputSpec( + dtype="int32", shape=(None, sequence_length) + ), + "clip_g_negative_token_ids": keras.layers.InputSpec( + dtype="int32", shape=(None, sequence_length) + ), + "num_steps": keras.layers.InputSpec( + dtype="int32", shape=(None,) + ), + "guidance_scale": keras.layers.InputSpec( + dtype="float32", shape=(None,) + ), + } + else: + # For other text-to-image models, use simple text inputs + return _get_text_input_signature(self.model, sequence_length) + + +def get_exporter_config(model): + """Get the appropriate exporter configuration for a model instance. + + This function automatically detects the model type and returns the + corresponding exporter configuration. + + Args: + model: A Keras-Hub model instance (e.g., CausalLM, TextClassifier). + + Returns: + An instance of the appropriate KerasHubExporterConfig subclass. + + Raises: + ValueError: If the model type is not supported for export. + """ + # Mapping of model classes to their config classes + # NOTE: Order matters! More specific configs must be checked first: + # - AudioToText before Seq2SeqLM (AudioToText is a subclass of Seq2SeqLM) + # - Seq2SeqLM before CausalLM (Seq2SeqLM is a subclass of CausalLM) + # - SAMImageSegmenterExporterConfig before ImageSegmenterExporterConfig + _MODEL_TYPE_TO_CONFIG = [ + (AudioToText, AudioToTextExporterConfig), + (Seq2SeqLM, Seq2SeqLMExporterConfig), + (CausalLM, CausalLMExporterConfig), + (TextClassifier, TextClassifierExporterConfig), + (ImageClassifier, ImageClassifierExporterConfig), + (ObjectDetector, ObjectDetectorExporterConfig), + (ImageSegmenter, SAMImageSegmenterExporterConfig), # Check SAM first + (ImageSegmenter, ImageSegmenterExporterConfig), # Then generic + (DepthEstimator, DepthEstimatorExporterConfig), + (TextToImage, TextToImageExporterConfig), + ] + + # Find matching config class + for model_class, config_class in _MODEL_TYPE_TO_CONFIG: + if isinstance(model, model_class): + # Try to create config and check compatibility + try: + config = config_class(model) + return config + except ValueError: + # Model not compatible with this config, try next one + continue + + # Model type not supported + supported_types = ", ".join( + set(cls.__name__ for cls, _ in _MODEL_TYPE_TO_CONFIG) + ) + raise ValueError( + f"Could not find exporter config for model type " + f"'{model.__class__.__name__}'. " + f"Supported types: {supported_types}" + ) diff --git a/keras_hub/src/export/configs_test.py b/keras_hub/src/export/configs_test.py new file mode 100644 index 0000000000..5d3ce68e7c --- /dev/null +++ b/keras_hub/src/export/configs_test.py @@ -0,0 +1,302 @@ +"""Tests for export configuration classes.""" + +import keras + +from keras_hub.src.export.configs import CausalLMExporterConfig +from keras_hub.src.export.configs import ImageClassifierExporterConfig +from keras_hub.src.export.configs import ImageSegmenterExporterConfig +from keras_hub.src.export.configs import ObjectDetectorExporterConfig +from keras_hub.src.export.configs import Seq2SeqLMExporterConfig +from keras_hub.src.export.configs import TextClassifierExporterConfig +from keras_hub.src.tests.test_case import TestCase + + +class MockPreprocessor: + """Mock preprocessor for testing.""" + + def __init__(self, sequence_length=None, image_size=None): + if sequence_length is not None: + self.sequence_length = sequence_length + if image_size is not None: + self.image_size = image_size + + +class MockCausalLM(keras.Model): + """Mock Causal LM model for testing.""" + + def __init__(self, preprocessor=None): + super().__init__() + self.preprocessor = preprocessor + self.dense = keras.layers.Dense(10) + + def call(self, inputs): + return self.dense(inputs["token_ids"]) + + +class MockTextClassifier(keras.Model): + """Mock Text Classifier model for testing.""" + + def __init__(self, preprocessor=None): + super().__init__() + self.preprocessor = preprocessor + self.dense = keras.layers.Dense(5) + + def call(self, inputs): + return self.dense(inputs["token_ids"]) + + +class MockImageClassifier(keras.Model): + """Mock Image Classifier model for testing.""" + + def __init__(self, preprocessor=None): + super().__init__() + self.preprocessor = preprocessor + self.dense = keras.layers.Dense(1000) + + def call(self, inputs): + return self.dense(inputs) + + +class CausalLMExporterConfigTest(TestCase): + """Tests for CausalLMExporterConfig class.""" + + def test_model_type_and_expected_inputs(self): + """Test MODEL_TYPE and EXPECTED_INPUTS are correctly set.""" + from keras_hub.src.models.causal_lm import CausalLM + + # Create a minimal mock CausalLM + class MockCausalLMForTest(CausalLM): + def __init__(self): + keras.Model.__init__(self) + self.preprocessor = None + + model = MockCausalLMForTest() + config = CausalLMExporterConfig(model) + self.assertEqual(config.MODEL_TYPE, "causal_lm") + self.assertEqual(config.EXPECTED_INPUTS, ["token_ids", "padding_mask"]) + + def test_get_input_signature_default(self): + """Test get_input_signature with dynamic shape (default).""" + from keras_hub.src.models.causal_lm import CausalLM + + class MockCausalLMForTest(CausalLM): + def __init__(self): + keras.Model.__init__(self) + self.preprocessor = None + + model = MockCausalLMForTest() + config = CausalLMExporterConfig(model) + signature = config.get_input_signature() + + self.assertIn("token_ids", signature) + self.assertIn("padding_mask", signature) + # Default is now dynamic shape (None) for flexibility + self.assertEqual(signature["token_ids"].shape, (None, None)) + self.assertEqual(signature["padding_mask"].shape, (None, None)) + + def test_get_input_signature_from_preprocessor(self): + """Test get_input_signature uses preprocessor's sequence_length by + default.""" + from keras_hub.src.models.causal_lm import CausalLM + + class MockCausalLMForTest(CausalLM): + def __init__(self, preprocessor): + keras.Model.__init__(self) + self.preprocessor = preprocessor + + preprocessor = MockPreprocessor(sequence_length=256) + model = MockCausalLMForTest(preprocessor) + config = CausalLMExporterConfig(model) + # Without explicit sequence_length parameter, uses preprocessor's + # sequence_length + signature = config.get_input_signature() + + # Should use preprocessor's sequence_length by default + self.assertEqual(signature["token_ids"].shape, (None, 256)) + self.assertEqual(signature["padding_mask"].shape, (None, 256)) + + def test_get_input_signature_dynamic_when_no_preprocessor(self): + """Test get_input_signature uses dynamic shape when no preprocessor.""" + from keras_hub.src.models.causal_lm import CausalLM + + class MockCausalLMForTest(CausalLM): + def __init__(self): + keras.Model.__init__(self) + self.preprocessor = None + + model = MockCausalLMForTest() + config = CausalLMExporterConfig(model) + # Without preprocessor, uses dynamic shape + signature = config.get_input_signature() + + # Should use dynamic shape when no preprocessor available + self.assertEqual(signature["token_ids"].shape, (None, None)) + self.assertEqual(signature["padding_mask"].shape, (None, None)) + + def test_get_input_signature_custom_length(self): + """Test get_input_signature with custom sequence length.""" + from keras_hub.src.models.causal_lm import CausalLM + + class MockCausalLMForTest(CausalLM): + def __init__(self): + keras.Model.__init__(self) + self.preprocessor = None + + model = MockCausalLMForTest() + config = CausalLMExporterConfig(model) + signature = config.get_input_signature(sequence_length=512) + + # Should use provided sequence length + self.assertEqual(signature["token_ids"].shape, (None, 512)) + self.assertEqual(signature["padding_mask"].shape, (None, 512)) + + +class TextClassifierExporterConfigTest(TestCase): + """Tests for TextClassifierExporterConfig class.""" + + def test_model_type_and_expected_inputs(self): + """Test MODEL_TYPE and EXPECTED_INPUTS are correctly set.""" + from keras_hub.src.models.text_classifier import TextClassifier + + class MockTextClassifierForTest(TextClassifier): + def __init__(self): + keras.Model.__init__(self) + self.preprocessor = None + + model = MockTextClassifierForTest() + config = TextClassifierExporterConfig(model) + self.assertEqual(config.MODEL_TYPE, "text_classifier") + self.assertEqual(config.EXPECTED_INPUTS, ["token_ids", "padding_mask"]) + + def test_get_input_signature_default(self): + """Test get_input_signature with dynamic shape (default).""" + from keras_hub.src.models.text_classifier import TextClassifier + + class MockTextClassifierForTest(TextClassifier): + def __init__(self): + keras.Model.__init__(self) + self.preprocessor = None + + model = MockTextClassifierForTest() + config = TextClassifierExporterConfig(model) + signature = config.get_input_signature() + + self.assertIn("token_ids", signature) + self.assertIn("padding_mask", signature) + # Default is now dynamic shape (None) for flexibility + self.assertEqual(signature["token_ids"].shape, (None, None)) + + +class ImageClassifierExporterConfigTest(TestCase): + """Tests for ImageClassifierExporterConfig class.""" + + def test_model_type_and_expected_inputs(self): + """Test MODEL_TYPE and EXPECTED_INPUTS are correctly set.""" + from keras_hub.src.models.image_classifier import ImageClassifier + + class MockImageClassifierForTest(ImageClassifier): + def __init__(self): + keras.Model.__init__(self) + self.preprocessor = None + + model = MockImageClassifierForTest() + config = ImageClassifierExporterConfig(model) + self.assertEqual(config.MODEL_TYPE, "image_classifier") + self.assertEqual(config.EXPECTED_INPUTS, ["images"]) + + def test_get_input_signature_with_preprocessor(self): + """Test get_input_signature infers image size from preprocessor.""" + from keras_hub.src.models.image_classifier import ImageClassifier + + class MockImageClassifierForTest(ImageClassifier): + def __init__(self, preprocessor): + keras.Model.__init__(self) + self.preprocessor = preprocessor + + preprocessor = MockPreprocessor(image_size=(224, 224)) + model = MockImageClassifierForTest(preprocessor) + config = ImageClassifierExporterConfig(model) + signature = config.get_input_signature() + + # ImageClassifier returns single InputSpec (not dict) + self.assertIsInstance(signature, keras.layers.InputSpec) + # Image shape should be (batch, height, width, channels) + expected_shape = (None, 224, 224, 3) + self.assertEqual(signature.shape, expected_shape) + + +class Seq2SeqLMExporterConfigTest(TestCase): + """Tests for Seq2SeqLMExporterConfig class.""" + + def test_model_type_and_expected_inputs(self): + """Test MODEL_TYPE and EXPECTED_INPUTS are correctly set.""" + from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM + + class MockSeq2SeqLMForTest(Seq2SeqLM): + def __init__(self): + keras.Model.__init__(self) + self.preprocessor = None + + model = MockSeq2SeqLMForTest() + config = Seq2SeqLMExporterConfig(model) + self.assertEqual(config.MODEL_TYPE, "seq2seq_lm") + # Seq2Seq models have both encoder and decoder inputs + self.assertIn("encoder_token_ids", config.EXPECTED_INPUTS) + self.assertIn("decoder_token_ids", config.EXPECTED_INPUTS) + + +class ObjectDetectorExporterConfigTest(TestCase): + """Tests for ObjectDetectorExporterConfig class.""" + + def test_model_type_and_expected_inputs(self): + """Test MODEL_TYPE and EXPECTED_INPUTS are correctly set.""" + from keras_hub.src.models.object_detector import ObjectDetector + + class MockObjectDetectorForTest(ObjectDetector): + def __init__(self): + keras.Model.__init__(self) + self.preprocessor = None + + model = MockObjectDetectorForTest() + config = ObjectDetectorExporterConfig(model) + self.assertEqual(config.MODEL_TYPE, "object_detector") + # ObjectDetector only takes images input (not image_shape) + self.assertEqual(config.EXPECTED_INPUTS, ["images"]) + + def test_get_input_signature_with_preprocessor(self): + """Test get_input_signature infers from preprocessor.""" + from keras_hub.src.models.object_detector import ObjectDetector + + class MockObjectDetectorForTest(ObjectDetector): + def __init__(self, preprocessor): + keras.Model.__init__(self) + self.preprocessor = preprocessor + + preprocessor = MockPreprocessor(image_size=(512, 512)) + model = MockObjectDetectorForTest(preprocessor) + config = ObjectDetectorExporterConfig(model) + signature = config.get_input_signature() + + # ObjectDetector returns single InputSpec for images (not dict) + self.assertIsInstance(signature, keras.layers.InputSpec) + # Images shape should be (batch, height, width, channels) + self.assertEqual(signature.shape, (None, 512, 512, 3)) + + +class ImageSegmenterExporterConfigTest(TestCase): + """Tests for ImageSegmenterExporterConfig class.""" + + def test_model_type_and_expected_inputs(self): + """Test MODEL_TYPE and EXPECTED_INPUTS are correctly set.""" + from keras_hub.src.models.image_segmenter import ImageSegmenter + + class MockImageSegmenterForTest(ImageSegmenter): + def __init__(self): + keras.Model.__init__(self) + self.preprocessor = None + + model = MockImageSegmenterForTest() + config = ImageSegmenterExporterConfig(model) + self.assertEqual(config.MODEL_TYPE, "image_segmenter") + # ImageSegmenter uses 'inputs' not 'images' + self.assertEqual(config.EXPECTED_INPUTS, ["inputs"]) diff --git a/keras_hub/src/export/litert.py b/keras_hub/src/export/litert.py new file mode 100644 index 0000000000..6e6b59d242 --- /dev/null +++ b/keras_hub/src/export/litert.py @@ -0,0 +1,41 @@ +"""LiteRT exporter for Keras-Hub models. + +This module provides LiteRT export functionality specifically designed for +Keras-Hub models, handling their unique input structures and requirements. + +The exporter supports dynamic shape inputs by default, leveraging TFLite's +native capability to resize input tensors at runtime. When applicable parameters +are not specified, models are exported with flexible dimensions that can be +resized via `interpreter.resize_tensor_input()` before inference. +""" + + +# Convenience function for direct export +def export_litert(model, filepath, **kwargs): + """Export a Keras-Hub model to Litert format. + + This is a convenience function that automatically detects the model type + and exports it using the appropriate configuration. + + Args: + model: `keras.Model`. The Keras-Hub model to export. + filepath: `str`. Path where to save the model (without extension). + **kwargs: `dict`. Additional arguments passed to exporter. + """ + from keras.src.export.litert import export_litert as keras_export_litert + + from keras_hub.src.export.configs import get_exporter_config + + # Get the appropriate configuration for this model type + config = get_exporter_config(model) + + # Get domain-specific input signature from config + input_signature = config.get_input_signature() + + # Call Keras Core's export_litert directly + keras_export_litert( + model, + filepath, + input_signature=input_signature, + **kwargs, + ) diff --git a/keras_hub/src/export/litert_models_test.py b/keras_hub/src/export/litert_models_test.py new file mode 100644 index 0000000000..0d442676d3 --- /dev/null +++ b/keras_hub/src/export/litert_models_test.py @@ -0,0 +1,369 @@ +"""Tests for LiteRT export with specific production models. + +This test suite validates LiteRT export functionality for production +model presets including CausalLM, ImageClassifier, ObjectDetector, +and ImageSegmenter models. + +Each test validates export correctness by: +1. Loading a model from preset +2. Exporting it to LiteRT format +3. Running numerical verification to ensure exported model produces + equivalent outputs +4. Comparing outputs statistically against predefined thresholds + +This ensures that exported models maintain functional correctness and +numerical stability. +""" + +import gc + +import keras +import numpy as np +import pytest + +from keras_hub.src.models.gemma3.gemma3_causal_lm import Gemma3CausalLM +from keras_hub.src.models.gpt2.gpt2_causal_lm import GPT2CausalLM +from keras_hub.src.models.image_classifier import ImageClassifier +from keras_hub.src.models.image_segmenter import ImageSegmenter +from keras_hub.src.models.llama3.llama3_causal_lm import Llama3CausalLM +from keras_hub.src.models.object_detector import ObjectDetector +from keras_hub.src.tests.test_case import TestCase + +# Model configurations for testing +CAUSAL_LM_MODELS = [ + { + "preset": "llama3.2_1b", + "model_class": Llama3CausalLM, + "sequence_length": 128, + "test_name": "llama3_2_1b", + "output_thresholds": {"*": {"max": 1e-3, "mean": 1e-5}}, + }, + { + "preset": "gemma3_1b", + "model_class": Gemma3CausalLM, + "sequence_length": 128, + "test_name": "gemma3_1b", + "output_thresholds": {"*": {"max": 1e-3, "mean": 3e-5}}, + }, + { + "preset": "gpt2_base_en", + "model_class": GPT2CausalLM, + "sequence_length": 128, + "test_name": "gpt2_base_en", + "output_thresholds": {"*": {"max": 5e-4, "mean": 5e-5}}, + }, +] + +IMAGE_CLASSIFIER_MODELS = [ + { + "preset": "resnet_50_imagenet", + "test_name": "resnet_50", + "input_range": (0.0, 1.0), + "output_thresholds": {"*": {"max": 5e-5, "mean": 1e-5}}, + }, + { + "preset": "efficientnet_b0_ra_imagenet", + "test_name": "efficientnet_b0", + "input_range": (0.0, 1.0), + "output_thresholds": {"*": {"max": 5e-5, "mean": 1e-5}}, + }, + { + "preset": "densenet_121_imagenet", + "test_name": "densenet_121", + "input_range": (0.0, 1.0), + "output_thresholds": {"*": {"max": 5e-5, "mean": 1e-5}}, + }, + { + "preset": "mobilenet_v3_small_100_imagenet", + "test_name": "mobilenet_v3_small", + "input_range": (0.0, 1.0), + "output_thresholds": {"*": {"max": 5e-5, "mean": 1e-5}}, + }, +] + +OBJECT_DETECTOR_MODELS = [ + { + "preset": "dfine_small_coco", + "test_name": "dfine_small", + "input_range": (0.0, 1.0), + "output_thresholds": { + "intermediate_predicted_corners": {"max": 5.0, "mean": 0.05}, + "intermediate_logits": {"max": 5.0, "mean": 0.1}, + "enc_topk_logits": {"max": 5.0, "mean": 0.03}, + "logits": {"max": 2.0, "mean": 0.03}, + "*": {"max": 1.0, "mean": 0.03}, + }, + }, + { + "preset": "dfine_medium_coco", + "test_name": "dfine_medium", + "input_range": (0.0, 1.0), + "output_thresholds": { + "intermediate_predicted_corners": {"max": 50.0, "mean": 0.15}, + "intermediate_logits": {"max": 5.0, "mean": 0.1}, + "enc_topk_logits": {"max": 5.0, "mean": 0.03}, + "logits": {"max": 5.0, "mean": 0.03}, + "*": {"max": 1.0, "mean": 0.03}, + }, + }, + { + "preset": "retinanet_resnet50_fpn_coco", + "test_name": "retinanet_resnet50", + "input_range": (0.0, 1.0), + "output_thresholds": { + "enc_topk_logits": {"max": 5.0, "mean": 0.03}, + "logits": {"max": 2.0, "mean": 0.03}, + "*": {"max": 1.0, "mean": 0.03}, + }, + }, +] + +IMAGE_SEGMENTER_MODELS = [ + { + "preset": "deeplab_v3_plus_resnet50_pascalvoc", + "test_name": "deeplab_v3_plus", + "input_range": (0.0, 1.0), + "output_thresholds": {"*": {"max": 1.0, "mean": 1e-2}}, + }, +] + + +@pytest.mark.skipif( + keras.backend.backend() != "tensorflow", + reason="LiteRT export only supports TensorFlow backend.", +) +@pytest.mark.parametrize( + "model_config", + CAUSAL_LM_MODELS, + ids=lambda x: f"{x['test_name']}-{x['preset']}", +) +def test_causal_lm_litert_export(model_config): + """Test LiteRT export for CausalLM models. + + Validates that the model can be successfully exported to LiteRT format + and produces numerically equivalent outputs. + """ + preset = model_config["preset"] + model_class = model_config["model_class"] + sequence_length = model_config["sequence_length"] + output_thresholds = model_config.get( + "output_thresholds", {"*": {"max": 3e-5, "mean": 3e-6}} + ) + + model = None + try: + # Load model from preset once + model = model_class.from_preset(preset, load_weights=True) + + # Set sequence length before export + model.preprocessor.sequence_length = sequence_length + + # Get vocab_size from the loaded model + vocab_size = model.backbone.vocabulary_size + + # Prepare test inputs with fixed random seed for reproducibility + np.random.seed(42) + input_data = { + "token_ids": np.random.randint( + 1, vocab_size, size=(1, sequence_length), dtype=np.int32 + ), + "padding_mask": np.ones((1, sequence_length), dtype=np.int32), + } + + # Validate LiteRT export with numerical verification + TestCase().run_litert_export_test( + model=model, + input_data=input_data, + expected_output_shape=(1, sequence_length, vocab_size), + comparison_mode="statistical", + output_thresholds=output_thresholds, + ) + + finally: + # Clean up model, free memory + if model is not None: + del model + gc.collect() + + +@pytest.mark.skipif( + keras.backend.backend() != "tensorflow", + reason="LiteRT export only supports TensorFlow backend.", +) +@pytest.mark.parametrize( + "model_config", + IMAGE_CLASSIFIER_MODELS, + ids=lambda x: f"{x['test_name']}-{x['preset']}", +) +def test_image_classifier_litert_export(model_config): + """Test LiteRT export for ImageClassifier models. + + Validates that the model can be successfully exported to LiteRT format + and produces numerically equivalent outputs. + """ + preset = model_config["preset"] + input_range = model_config.get("input_range", (0.0, 1.0)) + output_thresholds = model_config.get( + "output_thresholds", {"*": {"max": 1e-4, "mean": 4e-5}} + ) + + model = None + try: + # Load model once + model = ImageClassifier.from_preset(preset) + + # Get actual image size from model preprocessor or backbone + image_size = getattr(model.preprocessor, "image_size", None) + if image_size is None and hasattr(model.backbone, "image_shape"): + image_shape = model.backbone.image_shape + if isinstance(image_shape, (list, tuple)) and len(image_shape) >= 2: + image_size = tuple(image_shape[:2]) + elif isinstance(image_shape, int): + image_size = (image_shape, image_shape) + + if image_size is None: + raise ValueError(f"Could not determine image size for {preset}") + + input_shape = image_size + (3,) # Add channels + + # Prepare test input + test_image = np.random.uniform( + input_range[0], input_range[1], size=(1,) + input_shape + ).astype(np.float32) + + # Validate LiteRT export with numerical verification + TestCase().run_litert_export_test( + model=model, + input_data=test_image, + expected_output_shape=None, # Output shape varies by model + comparison_mode="statistical", + output_thresholds=output_thresholds, + ) + + finally: + # Clean up model, free memory + if model is not None: + del model + gc.collect() + + +@pytest.mark.skipif( + keras.backend.backend() != "tensorflow", + reason="LiteRT export only supports TensorFlow backend.", +) +@pytest.mark.parametrize( + "model_config", + OBJECT_DETECTOR_MODELS, + ids=lambda x: f"{x['test_name']}-{x['preset']}", +) +def test_object_detector_litert_export(model_config): + """Test LiteRT export for ObjectDetector models. + + Validates that the model can be successfully exported to LiteRT format + and produces numerically equivalent outputs. + """ + preset = model_config["preset"] + input_range = model_config.get("input_range", (0.0, 1.0)) + output_thresholds = model_config.get( + "output_thresholds", {"*": {"max": 1.0, "mean": 0.02}} + ) + + model = None + try: + # Load model once + model = ObjectDetector.from_preset(preset) + + # Get actual image size from model preprocessor or backbone + image_size = getattr(model.preprocessor, "image_size", None) + if image_size is None and hasattr(model.backbone, "image_shape"): + image_shape = model.backbone.image_shape + if isinstance(image_shape, (list, tuple)) and len(image_shape) >= 2: + image_size = tuple(image_shape[:2]) + elif isinstance(image_shape, int): + image_size = (image_shape, image_shape) + + if image_size is None: + raise ValueError(f"Could not determine image size for {preset}") + + # ObjectDetector only needs images input (not image_shape) + test_inputs = np.random.uniform( + input_range[0], + input_range[1], + size=(1,) + image_size + (3,), + ).astype(np.float32) + + # Validate LiteRT export with numerical verification + TestCase().run_litert_export_test( + model=model, + input_data=test_inputs, + expected_output_shape=None, # Output varies by model + comparison_mode="statistical", + output_thresholds=output_thresholds, + ) + + finally: + # Clean up model, free memory + if model is not None: + del model + gc.collect() + + +@pytest.mark.skipif( + keras.backend.backend() != "tensorflow", + reason="LiteRT export only supports TensorFlow backend.", +) +@pytest.mark.parametrize( + "model_config", + IMAGE_SEGMENTER_MODELS, + ids=lambda x: f"{x['test_name']}-{x['preset']}", +) +def test_image_segmenter_litert_export(model_config): + """Test LiteRT export for ImageSegmenter models. + + Validates that the model can be successfully exported to LiteRT format + and produces numerically equivalent outputs. + """ + preset = model_config["preset"] + input_range = model_config.get("input_range", (0.0, 1.0)) + output_thresholds = model_config.get( + "output_thresholds", {"*": {"max": 1.0, "mean": 1e-2}} + ) + + model = None + try: + # Load model once + model = ImageSegmenter.from_preset(preset) + + # Get actual image size from model preprocessor or backbone + image_size = getattr(model.preprocessor, "image_size", None) + if image_size is None and hasattr(model.backbone, "image_shape"): + image_shape = model.backbone.image_shape + if isinstance(image_shape, (list, tuple)) and len(image_shape) >= 2: + image_size = tuple(image_shape[:2]) + elif isinstance(image_shape, int): + image_size = (image_shape, image_shape) + + if image_size is None: + raise ValueError(f"Could not determine image size for {preset}") + + input_shape = image_size + (3,) # Add channels + + # Prepare test input + test_image = np.random.uniform( + input_range[0], input_range[1], size=(1,) + input_shape + ).astype(np.float32) + + # Validate LiteRT export with numerical verification + TestCase().run_litert_export_test( + model=model, + input_data=test_image, + expected_output_shape=None, # Output shape varies by model + comparison_mode="statistical", + output_thresholds=output_thresholds, + ) + + finally: + # Clean up model, free memory + if model is not None: + del model + gc.collect() diff --git a/keras_hub/src/export/litert_test.py b/keras_hub/src/export/litert_test.py new file mode 100644 index 0000000000..c8cb111d2f --- /dev/null +++ b/keras_hub/src/export/litert_test.py @@ -0,0 +1,707 @@ +"""Tests for LiteRT export functionality.""" + +import os +import shutil +import tempfile + +import keras +import numpy as np +import pytest + +from keras_hub.src.export.litert import export_litert +from keras_hub.src.tests.test_case import TestCase + +# Lazy import LiteRT interpreter with fallback logic +LITERT_AVAILABLE = False +if keras.backend.backend() == "tensorflow": + try: + from ai_edge_litert.interpreter import Interpreter + + LITERT_AVAILABLE = True + except ImportError: + import tensorflow as tf + + Interpreter = tf.lite.Interpreter + + +@pytest.mark.skipif( + keras.backend.backend() != "tensorflow", + reason="LiteRT export only supports TensorFlow backend.", +) +class LiteRTExportTest(TestCase): + """Tests for LiteRT export functionality.""" + + def setUp(self): + """Set up test fixtures.""" + super().setUp() + self.temp_dir = tempfile.mkdtemp() + + def tearDown(self): + """Clean up test fixtures.""" + super().tearDown() + # Clean up temporary files + if os.path.exists(self.temp_dir): + shutil.rmtree(self.temp_dir) + + def test_export_litert_function_exists(self): + """Test that export_litert function is available.""" + # Simply test that the function can be imported and called + self.assertTrue(callable(export_litert)) + + +@pytest.mark.skipif( + keras.backend.backend() != "tensorflow", + reason="LiteRT export only supports TensorFlow backend.", +) +class CausalLMExportTest(TestCase): + """Tests for exporting CausalLM models to LiteRT.""" + + def setUp(self): + """Set up test fixtures.""" + super().setUp() + self.temp_dir = tempfile.mkdtemp() + + def tearDown(self): + """Clean up test fixtures.""" + super().tearDown() + import shutil + + if os.path.exists(self.temp_dir): + shutil.rmtree(self.temp_dir) + + def test_export_causal_lm_mock(self): + """Test exporting a mock CausalLM model with dynamic shape support.""" + from keras_hub.src.models.causal_lm import CausalLM + + # Create a minimal mock CausalLM + class SimpleCausalLM(CausalLM): + def __init__(self): + super().__init__() + self.preprocessor = None + self.embedding = keras.layers.Embedding(1000, 64) + self.dense = keras.layers.Dense(1000) + + def call(self, inputs): + if isinstance(inputs, dict): + token_ids = inputs["token_ids"] + else: + token_ids = inputs + x = self.embedding(token_ids) + return self.dense(x) + + model = SimpleCausalLM() + model.build( + input_shape={ + "token_ids": (None, None), # Dynamic sequence length + "padding_mask": (None, None), + } + ) + + # Export using the model's export method with dynamic shapes + export_path = os.path.join(self.temp_dir, "test_causal_lm") + model.export(export_path, format="litert") + + # Verify the file was created + tflite_path = export_path + ".tflite" + self.assertTrue(os.path.exists(tflite_path)) + + # Load and verify the exported model + interpreter = Interpreter(model_path=tflite_path) + + input_details = interpreter.get_input_details() + + # Verify that inputs support dynamic shapes (shape_signature has -1) + # This is the key improvement - TFLite now exports with dynamic shapes + for input_detail in input_details: + if "shape_signature" in input_detail: + # Check that sequence dimension is dynamic (-1) + self.assertEqual(input_detail["shape_signature"][1], -1) + + # Resize tensors to specific sequence length before allocating + # This demonstrates TFLite's dynamic shape support + seq_len = 128 + interpreter.resize_tensor_input(input_details[0]["index"], [1, seq_len]) + interpreter.resize_tensor_input(input_details[1]["index"], [1, seq_len]) + interpreter.allocate_tensors() + + # Delete the TFLite file after loading to free disk space + if os.path.exists(tflite_path): + os.remove(tflite_path) + + output_details = interpreter.get_output_details() + + # Verify we have the expected inputs + self.assertEqual(len(input_details), 2) + + # Create test inputs with dtypes from the interpreter + test_token_ids = np.random.randint(0, 1000, (1, seq_len)).astype( + input_details[0]["dtype"] + ) + test_padding_mask = np.ones( + (1, seq_len), dtype=input_details[1]["dtype"] + ) + + # Set inputs and run inference + interpreter.set_tensor(input_details[0]["index"], test_token_ids) + interpreter.set_tensor(input_details[1]["index"], test_padding_mask) + interpreter.invoke() + + # Get output + output = interpreter.get_tensor(output_details[0]["index"]) + self.assertEqual(output.shape[0], 1) # Batch size + self.assertEqual(output.shape[1], seq_len) # Sequence length + self.assertEqual(output.shape[2], 1000) # Vocab size + + # Clean up interpreter, free memory + del interpreter + import gc + + gc.collect() + + def test_export_causal_lm_dynamic_shape_resize(self): + """Test exported CausalLM can resize inputs dynamically.""" + from keras_hub.src.models.causal_lm import CausalLM + + # Create a minimal mock CausalLM + class SimpleCausalLM(CausalLM): + def __init__(self): + super().__init__() + self.preprocessor = None + self.embedding = keras.layers.Embedding(1000, 64) + self.dense = keras.layers.Dense(1000) + + def call(self, inputs): + if isinstance(inputs, dict): + token_ids = inputs["token_ids"] + else: + token_ids = inputs + x = self.embedding(token_ids) + return self.dense(x) + + model = SimpleCausalLM() + model.build( + input_shape={ + "token_ids": (None, None), + "padding_mask": (None, None), + } + ) + + # Export using dynamic shapes (no max_sequence_length specified) + export_path = os.path.join(self.temp_dir, "test_causal_lm_dynamic") + model.export(export_path, format="litert") + + tflite_path = export_path + ".tflite" + self.assertTrue(os.path.exists(tflite_path)) + + # Test with different sequence lengths via resize_tensor_input + for seq_len in [32, 64, 128]: + interpreter = Interpreter(model_path=tflite_path) + + # Resize input tensors to desired sequence length + input_details = interpreter.get_input_details() + interpreter.resize_tensor_input( + input_details[0]["index"], [1, seq_len] + ) + interpreter.resize_tensor_input( + input_details[1]["index"], [1, seq_len] + ) + interpreter.allocate_tensors() + + # Create test inputs with the resized shape + test_token_ids = np.random.randint( + 0, 1000, (1, seq_len), dtype=input_details[0]["dtype"] + ) + test_padding_mask = np.ones( + (1, seq_len), dtype=input_details[1]["dtype"] + ) + + # Run inference + interpreter.set_tensor(input_details[0]["index"], test_token_ids) + interpreter.set_tensor(input_details[1]["index"], test_padding_mask) + interpreter.invoke() + + # Verify output shape matches input sequence length + output_details = interpreter.get_output_details() + output = interpreter.get_tensor(output_details[0]["index"]) + self.assertEqual(output.shape[1], seq_len) + + del interpreter + import gc + + gc.collect() + + # Clean up + if os.path.exists(tflite_path): + os.remove(tflite_path) + + +@pytest.mark.skipif( + keras.backend.backend() != "tensorflow", + reason="LiteRT export only supports TensorFlow backend.", +) +class ImageClassifierExportTest(TestCase): + """Tests for exporting ImageClassifier models to LiteRT.""" + + def setUp(self): + """Set up test fixtures.""" + super().setUp() + self.temp_dir = tempfile.mkdtemp() + + def tearDown(self): + """Clean up test fixtures.""" + super().tearDown() + import shutil + + if os.path.exists(self.temp_dir): + shutil.rmtree(self.temp_dir) + + def test_export_image_classifier_mock(self): + """Test exporting a mock ImageClassifier model.""" + from keras_hub.src.models.backbone import Backbone + from keras_hub.src.models.image_classifier import ImageClassifier + + # Create a minimal mock Backbone + class SimpleBackbone(Backbone): + def __init__(self): + inputs = keras.layers.Input(shape=(224, 224, 3)) + x = keras.layers.Conv2D(32, 3, padding="same")(inputs) + # Don't reduce dimensions - let ImageClassifier handle pooling + outputs = x + super().__init__(inputs=inputs, outputs=outputs) + + # Create ImageClassifier with the mock backbone + backbone = SimpleBackbone() + model = ImageClassifier(backbone=backbone, num_classes=10) + + # Export using the model's export method + export_path = os.path.join(self.temp_dir, "test_image_classifier") + model.export(export_path, format="litert") + + # Verify the file was created + tflite_path = export_path + ".tflite" + self.assertTrue(os.path.exists(tflite_path)) + + # Load and verify the exported model + interpreter = Interpreter(model_path=tflite_path) + interpreter.allocate_tensors() + + # Delete the TFLite file after loading to free disk space + if os.path.exists(tflite_path): + os.remove(tflite_path) + + input_details = interpreter.get_input_details() + output_details = interpreter.get_output_details() + + # Verify we have the expected input + self.assertEqual(len(input_details), 1) + + # Create test input with dtype from the interpreter + test_image = np.random.uniform(0.0, 1.0, (1, 224, 224, 3)).astype( + input_details[0]["dtype"] + ) + + # Set input and run inference + interpreter.set_tensor(input_details[0]["index"], test_image) + interpreter.invoke() + + # Get output + output = interpreter.get_tensor(output_details[0]["index"]) + self.assertEqual(output.shape[0], 1) # Batch size + self.assertEqual(output.shape[1], 10) # Number of classes + + # Clean up interpreter, free memory + del interpreter + import gc + + gc.collect() + + def test_signature_def_with_image_classifier(self): + """Test that SignatureDef preserves input names for + ImageClassifier models.""" + from keras_hub.src.models.backbone import Backbone + from keras_hub.src.models.image_classifier import ImageClassifier + + # Create a minimal mock Backbone with named input + class SimpleBackbone(Backbone): + def __init__(self): + inputs = keras.layers.Input( + shape=(224, 224, 3), name="image_input" + ) + x = keras.layers.Conv2D(32, 3, padding="same")(inputs) + outputs = x + super().__init__(inputs=inputs, outputs=outputs) + + # Create ImageClassifier with the mock backbone + backbone = SimpleBackbone() + model = ImageClassifier(backbone=backbone, num_classes=10) + + # Export using the model's export method + export_path = os.path.join(self.temp_dir, "image_classifier_signature") + model.export(export_path, format="litert") + + # Verify the file was created + tflite_path = export_path + ".tflite" + self.assertTrue(os.path.exists(tflite_path)) + + # Load and check SignatureDef + interpreter = Interpreter(model_path=tflite_path) + interpreter.allocate_tensors() + + # Get SignatureDef information + signature_defs = interpreter.get_signature_list() + self.assertIn("serving_default", signature_defs) + + serving_sig = signature_defs["serving_default"] + sig_inputs = serving_sig.get("inputs", []) + sig_outputs = serving_sig.get("outputs", []) + + # Verify SignatureDef has inputs and outputs + self.assertGreater( + len(sig_inputs), 0, "Should have at least one input in SignatureDef" + ) + self.assertGreater( + len(sig_outputs), + 0, + "Should have at least one output in SignatureDef", + ) + + # Verify that the named input is preserved in SignatureDef + # Note: ImageClassifier may use different input name, so we just verify + # that SignatureDef contains meaningful names, not generic ones + self.assertGreater( + len(sig_inputs), + 0, + f"Should have at least one input name in " + f"SignatureDef: {sig_inputs}", + ) + # sig_inputs is a list of input names + first_input_name = sig_inputs[0] if sig_inputs else "" + self.assertGreater( + len(first_input_name), + 0, + f"Input name should not be empty: {sig_inputs}", + ) + + # Verify inference works + input_details = interpreter.get_input_details() + output_details = interpreter.get_output_details() + + test_image = np.random.uniform(0.0, 1.0, (1, 224, 224, 3)).astype( + input_details[0]["dtype"] + ) + + interpreter.set_tensor(input_details[0]["index"], test_image) + interpreter.invoke() + + output = interpreter.get_tensor(output_details[0]["index"]) + self.assertEqual(output.shape[0], 1) # Batch size + self.assertEqual(output.shape[1], 10) # Number of classes + + # Clean up + del interpreter + import gc + + gc.collect() + + +@pytest.mark.skipif( + keras.backend.backend() != "tensorflow", + reason="LiteRT export only supports TensorFlow backend.", +) +class TextClassifierExportTest(TestCase): + """Tests for exporting TextClassifier models to LiteRT.""" + + def setUp(self): + """Set up test fixtures.""" + super().setUp() + self.temp_dir = tempfile.mkdtemp() + + def tearDown(self): + """Clean up test fixtures.""" + super().tearDown() + import shutil + + if os.path.exists(self.temp_dir): + shutil.rmtree(self.temp_dir) + + def test_export_text_classifier_mock(self): + """Test exporting a mock TextClassifier model.""" + from keras_hub.src.models.text_classifier import TextClassifier + + # Create a minimal mock TextClassifier + class SimpleTextClassifier(TextClassifier): + def __init__(self): + super().__init__() + self.preprocessor = None + self.embedding = keras.layers.Embedding(5000, 64) + self.pool = keras.layers.GlobalAveragePooling1D() + self.dense = keras.layers.Dense(5) # 5 classes + + def call(self, inputs): + if isinstance(inputs, dict): + token_ids = inputs["token_ids"] + else: + token_ids = inputs + x = self.embedding(token_ids) + x = self.pool(x) + return self.dense(x) + + model = SimpleTextClassifier() + model.build( + input_shape={ + "token_ids": (None, 128), + "padding_mask": (None, 128), + } + ) + + # Export using the model's export method + export_path = os.path.join(self.temp_dir, "test_text_classifier") + model.export(export_path, format="litert") + + # Verify the file was created + tflite_path = export_path + ".tflite" + self.assertTrue(os.path.exists(tflite_path)) + + # Load and verify the exported model + interpreter = Interpreter(model_path=tflite_path) + interpreter.allocate_tensors() + + # Delete the TFLite file after loading to free disk space + if os.path.exists(tflite_path): + os.remove(tflite_path) + + output_details = interpreter.get_output_details() + + # Verify output shape (batch, num_classes) + self.assertEqual(len(output_details), 1) + + # Clean up interpreter, free memory + del interpreter + import gc + + gc.collect() + + +@pytest.mark.skipif( + keras.backend.backend() != "tensorflow", + reason="LiteRT export only supports TensorFlow backend.", +) +class ExportNumericalVerificationTest(TestCase): + """Tests for numerical accuracy of exported models.""" + + def test_simple_model_numerical_accuracy(self): + """Test that exported model produces similar outputs to original.""" + # Create a simple sequential model with explicit Input layer + model = keras.Sequential( + [ + keras.layers.Input(shape=(5,)), + keras.layers.Dense(10, activation="relu"), + keras.layers.Dense(3, activation="softmax"), + ] + ) + + # Prepare test input + test_input = np.random.random((1, 5)).astype(np.float32) + + # Use standardized test from TestCase + # Note: This assumes the model has an export() method + # If not available, the test will be skipped + if not hasattr(model, "export"): + self.skipTest("model.export() not available") + + self.run_litert_export_test( + cls=keras.Sequential, + init_kwargs={ + "layers": [ + keras.layers.Input(shape=(5,)), + keras.layers.Dense(10, activation="relu"), + keras.layers.Dense(3, activation="softmax"), + ] + }, + input_data=test_input, + expected_output_shape=(1, 3), + comparison_mode="strict", + ) + + def test_dict_input_model_numerical_accuracy(self): + """Test numerical accuracy for models with dictionary inputs.""" + + # Define a custom model class for testing + class DictInputModel(keras.Model): + def __init__(self): + super().__init__() + self.concat = keras.layers.Concatenate() + self.dense = keras.layers.Dense(5) + + def call(self, inputs): + x = self.concat([inputs["input1"], inputs["input2"]]) + return self.dense(x) + + # Prepare test inputs + test_inputs = { + "input1": np.random.random((1, 10)).astype(np.float32), + "input2": np.random.random((1, 10)).astype(np.float32), + } + + # Use standardized test from TestCase + self.run_litert_export_test( + cls=DictInputModel, + init_kwargs={}, + input_data=test_inputs, + expected_output_shape=(1, 5), + comparison_mode="strict", + ) + + +@pytest.mark.skipif( + keras.backend.backend() != "tensorflow", + reason="LiteRT export only supports TensorFlow backend.", +) +class ExportErrorHandlingTest(TestCase): + """Tests for error handling in export process.""" + + def setUp(self): + """Set up test fixtures.""" + super().setUp() + self.temp_dir = tempfile.mkdtemp() + + def tearDown(self): + """Clean up test fixtures.""" + super().tearDown() + import shutil + + if os.path.exists(self.temp_dir): + shutil.rmtree(self.temp_dir) + + def test_export_to_invalid_path(self): + """Test that export with invalid path raises appropriate error.""" + if not hasattr(keras.Model, "export"): + self.skipTest("model.export() not available") + + model = keras.Sequential([keras.layers.Dense(10)]) + + # Try to export to a path that doesn't exist and can't be created + invalid_path = "/nonexistent/deeply/nested/path/model" + + with self.assertRaises(Exception): + model.export(invalid_path, format="litert") + + def test_export_unbuilt_model(self): + """Test exporting an unbuilt model.""" + if not hasattr(keras.Model, "export"): + self.skipTest("model.export() not available") + + model = keras.Sequential([keras.layers.Dense(10, input_shape=(5,))]) + + # Model is not built yet (no explicit build() call) + # Export should still work by building the model + export_path = os.path.join(self.temp_dir, "unbuilt_model.tflite") + model.export(export_path, format="litert") + + # Should succeed + self.assertTrue(os.path.exists(export_path)) + + def test_signature_def_with_causal_lm(self): + """Test that SignatureDef preserves input names for CausalLM models.""" + from keras_hub.src.models.causal_lm import CausalLM + + # Create a minimal mock CausalLM with named inputs + class SimpleCausalLM(CausalLM): + def __init__(self): + super().__init__() + + # Create a mock preprocessor with sequence_length + class MockPreprocessor: + def __init__(self): + self.sequence_length = 128 + + self.preprocessor = MockPreprocessor() + self.embedding = keras.layers.Embedding(1000, 64) + self.dense = keras.layers.Dense(1000) + + def call(self, inputs): + if isinstance(inputs, dict): + token_ids = inputs["token_ids"] + else: + token_ids = inputs + x = self.embedding(token_ids) + return self.dense(x) + + model = SimpleCausalLM() + model.build( + input_shape={ + "token_ids": (None, 128), + "padding_mask": (None, 128), + } + ) + + # Export the model + export_path = os.path.join(self.temp_dir, "causal_lm_signature") + model.export(export_path, format="litert") + + tflite_path = export_path + ".tflite" + self.assertTrue(os.path.exists(tflite_path)) + + # Load and check SignatureDef + interpreter = Interpreter(model_path=tflite_path) + interpreter.allocate_tensors() + + # Get SignatureDef information + signature_defs = interpreter.get_signature_list() + self.assertIn("serving_default", signature_defs) + + serving_sig = signature_defs["serving_default"] + sig_inputs = serving_sig.get("inputs", []) + sig_outputs = serving_sig.get("outputs", []) + + # Verify SignatureDef has inputs and outputs + self.assertGreater( + len(sig_inputs), 0, "Should have at least one input in SignatureDef" + ) + self.assertGreater( + len(sig_outputs), + 0, + "Should have at least one output in SignatureDef", + ) + + # Verify that dictionary input names are preserved + # For CausalLM models, we expect token_ids and padding_mask + # sig_inputs is a list of input names + self.assertIn( + "token_ids", + sig_inputs, + f"Input name 'token_ids' should be in SignatureDef " + f"inputs: {sig_inputs}", + ) + self.assertIn( + "padding_mask", + sig_inputs, + f"Input name 'padding_mask' should be in SignatureDef " + f"inputs: {sig_inputs}", + ) + + # Verify inference works with the named signature + input_details = interpreter.get_input_details() + output_details = interpreter.get_output_details() + + seq_len = 128 + test_token_ids = np.random.randint( + 0, 1000, (1, seq_len), dtype=input_details[0]["dtype"] + ) + test_padding_mask = np.ones( + (1, seq_len), dtype=input_details[1]["dtype"] + ) + + interpreter.set_tensor(input_details[0]["index"], test_token_ids) + interpreter.set_tensor(input_details[1]["index"], test_padding_mask) + interpreter.invoke() + + output = interpreter.get_tensor(output_details[0]["index"]) + self.assertEqual(output.shape[0], 1) # Batch size + self.assertEqual(output.shape[1], seq_len) # Sequence length + + # Clean up + del interpreter + import gc + + gc.collect() diff --git a/keras_hub/src/models/__init__.py b/keras_hub/src/models/__init__.py index e69de29bb2..1c02bb93f3 100644 --- a/keras_hub/src/models/__init__.py +++ b/keras_hub/src/models/__init__.py @@ -0,0 +1,4 @@ +"""Keras-Hub models module. + +This module contains all the task and backbone models available in Keras-Hub. +""" diff --git a/keras_hub/src/models/albert/albert_text_classifier_test.py b/keras_hub/src/models/albert/albert_text_classifier_test.py index 3d6413ff99..d9ab9c70d0 100644 --- a/keras_hub/src/models/albert/albert_text_classifier_test.py +++ b/keras_hub/src/models/albert/albert_text_classifier_test.py @@ -61,6 +61,14 @@ def test_saved_model(self): input_data=self.input_data, ) + @pytest.mark.large + def test_litert_export(self): + self.run_litert_export_test( + cls=AlbertTextClassifier, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + @pytest.mark.extra_large def test_all_presets(self): for preset in AlbertTextClassifier.presets: diff --git a/keras_hub/src/models/bart/bart_seq_2_seq_lm_test.py b/keras_hub/src/models/bart/bart_seq_2_seq_lm_test.py index f525908b67..983b71610f 100644 --- a/keras_hub/src/models/bart/bart_seq_2_seq_lm_test.py +++ b/keras_hub/src/models/bart/bart_seq_2_seq_lm_test.py @@ -149,6 +149,14 @@ def test_saved_model(self): input_data=self.input_data, ) + @pytest.mark.large + def test_litert_export(self): + self.run_litert_export_test( + cls=BartSeq2SeqLM, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + @pytest.mark.extra_large def test_all_presets(self): for preset in BartSeq2SeqLM.presets: diff --git a/keras_hub/src/models/basnet/basnet_test.py b/keras_hub/src/models/basnet/basnet_test.py index b5bbe405e2..7af901ffd8 100644 --- a/keras_hub/src/models/basnet/basnet_test.py +++ b/keras_hub/src/models/basnet/basnet_test.py @@ -3,6 +3,9 @@ from keras_hub.src.models.basnet.basnet import BASNetImageSegmenter from keras_hub.src.models.basnet.basnet_backbone import BASNetBackbone +from keras_hub.src.models.basnet.basnet_image_converter import ( + BASNetImageConverter, +) from keras_hub.src.models.basnet.basnet_preprocessor import BASNetPreprocessor from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone from keras_hub.src.tests.test_case import TestCase @@ -26,7 +29,9 @@ def setUp(self): image_encoder=self.image_encoder, num_classes=1, ) - self.preprocessor = BASNetPreprocessor() + self.preprocessor = BASNetPreprocessor( + image_converter=BASNetImageConverter(height=64, width=64) + ) self.init_kwargs = { "backbone": self.backbone, "preprocessor": self.preprocessor, @@ -49,6 +54,14 @@ def test_saved_model(self): input_data=self.images, ) + @pytest.mark.large + def test_litert_export(self): + self.run_litert_export_test( + cls=BASNetImageSegmenter, + init_kwargs=self.init_kwargs, + input_data=self.images, + ) + def test_end_to_end_model_predict(self): model = BASNetImageSegmenter(**self.init_kwargs) output = model.predict(self.images) diff --git a/keras_hub/src/models/bert/bert_text_classifier_test.py b/keras_hub/src/models/bert/bert_text_classifier_test.py index 606be7c839..2aacfa53d6 100644 --- a/keras_hub/src/models/bert/bert_text_classifier_test.py +++ b/keras_hub/src/models/bert/bert_text_classifier_test.py @@ -53,6 +53,14 @@ def test_saved_model(self): input_data=self.input_data, ) + @pytest.mark.large + def test_litert_export(self): + self.run_litert_export_test( + cls=BertTextClassifier, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + @pytest.mark.extra_large def test_smallest_preset(self): self.run_preset_test( diff --git a/keras_hub/src/models/bloom/bloom_causal_lm_test.py b/keras_hub/src/models/bloom/bloom_causal_lm_test.py index ada3d8eeb1..c6fc6de3e9 100644 --- a/keras_hub/src/models/bloom/bloom_causal_lm_test.py +++ b/keras_hub/src/models/bloom/bloom_causal_lm_test.py @@ -164,6 +164,14 @@ def test_saved_model(self): input_data=self.input_data, ) + @pytest.mark.large + def test_litert_export(self): + self.run_litert_export_test( + cls=BloomCausalLM, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + @pytest.mark.extra_large def test_all_presets(self): for preset in BloomCausalLM.presets: diff --git a/keras_hub/src/models/cspnet/cspnet_image_classifier_test.py b/keras_hub/src/models/cspnet/cspnet_image_classifier_test.py index 6a5ee517e1..016d6ad478 100644 --- a/keras_hub/src/models/cspnet/cspnet_image_classifier_test.py +++ b/keras_hub/src/models/cspnet/cspnet_image_classifier_test.py @@ -76,3 +76,11 @@ def test_saved_model(self): init_kwargs=self.init_kwargs, input_data=self.images, ) + + @pytest.mark.large + def test_litert_export(self): + self.run_litert_export_test( + cls=CSPNetImageClassifier, + init_kwargs=self.init_kwargs, + input_data=self.images, + ) diff --git a/keras_hub/src/models/d_fine/d_fine_object_detector_test.py b/keras_hub/src/models/d_fine/d_fine_object_detector_test.py index 3b3bfe14c0..414701cd0b 100644 --- a/keras_hub/src/models/d_fine/d_fine_object_detector_test.py +++ b/keras_hub/src/models/d_fine/d_fine_object_detector_test.py @@ -152,3 +152,30 @@ def test_saved_model(self): init_kwargs=init_kwargs, input_data=self.images, ) + + @pytest.mark.large + def test_litert_export(self): + backbone = DFineBackbone(**self.base_backbone_kwargs) + init_kwargs = { + "backbone": backbone, + "num_classes": 4, + "bounding_box_format": self.bounding_box_format, + "preprocessor": self.preprocessor, + } + + # D-Fine ObjectDetector only takes images as input + input_data = self.images + + self.run_litert_export_test( + cls=DFineObjectDetector, + init_kwargs=init_kwargs, + input_data=input_data, + comparison_mode="statistical", + output_thresholds={ + "intermediate_predicted_corners": {"max": 5.0, "mean": 0.05}, + "intermediate_logits": {"max": 5.0, "mean": 0.1}, + "enc_topk_logits": {"max": 5.0, "mean": 0.03}, + "logits": {"max": 2.0, "mean": 0.03}, + "*": {"max": 1.0, "mean": 0.03}, + }, + ) diff --git a/keras_hub/src/models/deberta_v3/deberta_v3_text_classifier_test.py b/keras_hub/src/models/deberta_v3/deberta_v3_text_classifier_test.py index 11f3d139ee..3f443ae366 100644 --- a/keras_hub/src/models/deberta_v3/deberta_v3_text_classifier_test.py +++ b/keras_hub/src/models/deberta_v3/deberta_v3_text_classifier_test.py @@ -64,6 +64,14 @@ def test_saved_model(self): input_data=self.input_data, ) + @pytest.mark.large + def test_litert_export(self): + self.run_litert_export_test( + cls=DebertaV3TextClassifier, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + @pytest.mark.extra_large def test_all_presets(self): for preset in DebertaV3TextClassifier.presets: diff --git a/keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter_test.py b/keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter_test.py index 065bed3caa..5a352ad021 100644 --- a/keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter_test.py +++ b/keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter_test.py @@ -1,3 +1,4 @@ +import keras import numpy as np import pytest @@ -70,3 +71,19 @@ def test_saved_model(self): init_kwargs=self.init_kwargs, input_data=self.images, ) + + @pytest.mark.large + @pytest.mark.skipif( + keras.backend.backend() != "tensorflow", + reason="LiteRT export only supports TensorFlow backend.", + ) + def test_litert_export(self): + self.run_litert_export_test( + cls=DeepLabV3ImageSegmenter, + init_kwargs=self.init_kwargs, + input_data=self.images, + comparison_mode="statistical", + output_thresholds={ + "*": {"max": 0.6, "mean": 0.3}, + }, + ) diff --git a/keras_hub/src/models/deit/deit_image_classifier_test.py b/keras_hub/src/models/deit/deit_image_classifier_test.py index d64a956cdc..5c784ccf19 100644 --- a/keras_hub/src/models/deit/deit_image_classifier_test.py +++ b/keras_hub/src/models/deit/deit_image_classifier_test.py @@ -55,3 +55,11 @@ def test_saved_model(self): init_kwargs=self.init_kwargs, input_data=self.images, ) + + @pytest.mark.large + def test_litert_export(self): + self.run_litert_export_test( + cls=DeiTImageClassifier, + init_kwargs=self.init_kwargs, + input_data=self.images, + ) diff --git a/keras_hub/src/models/densenet/densenet_image_classifier_test.py b/keras_hub/src/models/densenet/densenet_image_classifier_test.py index 481005ba7e..18d622d79c 100644 --- a/keras_hub/src/models/densenet/densenet_image_classifier_test.py +++ b/keras_hub/src/models/densenet/densenet_image_classifier_test.py @@ -61,6 +61,14 @@ def test_saved_model(self): input_data=self.images, ) + @pytest.mark.large + def test_litert_export(self): + self.run_litert_export_test( + cls=DenseNetImageClassifier, + init_kwargs=self.init_kwargs, + input_data=self.images, + ) + @pytest.mark.extra_large def test_all_presets(self): for preset in DenseNetImageClassifier.presets: diff --git a/keras_hub/src/models/depth_anything/depth_anything_depth_estimator_test.py b/keras_hub/src/models/depth_anything/depth_anything_depth_estimator_test.py index 6277078488..5fedad8131 100644 --- a/keras_hub/src/models/depth_anything/depth_anything_depth_estimator_test.py +++ b/keras_hub/src/models/depth_anything/depth_anything_depth_estimator_test.py @@ -85,6 +85,16 @@ def test_saved_model(self): input_data=self.images, ) + @pytest.mark.large + def test_litert_export(self): + self.run_litert_export_test( + cls=DepthAnythingDepthEstimator, + init_kwargs=self.init_kwargs, + input_data=self.images, + comparison_mode="statistical", + output_thresholds={"depths": {"max": 2e-4, "mean": 1e-5}}, + ) + @pytest.mark.extra_large def test_all_presets(self): images = np.ones((2, 518, 518, 3), dtype="float32") diff --git a/keras_hub/src/models/distil_bert/distil_bert_text_classifier_test.py b/keras_hub/src/models/distil_bert/distil_bert_text_classifier_test.py index 71fdfc52b4..db57d21d0e 100644 --- a/keras_hub/src/models/distil_bert/distil_bert_text_classifier_test.py +++ b/keras_hub/src/models/distil_bert/distil_bert_text_classifier_test.py @@ -59,6 +59,14 @@ def test_saved_model(self): input_data=self.input_data, ) + @pytest.mark.large + def test_litert_export(self): + self.run_litert_export_test( + cls=DistilBertTextClassifier, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + @pytest.mark.extra_large def test_all_presets(self): for preset in DistilBertTextClassifier.presets: diff --git a/keras_hub/src/models/efficientnet/efficientnet_image_classifier_test.py b/keras_hub/src/models/efficientnet/efficientnet_image_classifier_test.py index d2b5717b68..18f13e5505 100644 --- a/keras_hub/src/models/efficientnet/efficientnet_image_classifier_test.py +++ b/keras_hub/src/models/efficientnet/efficientnet_image_classifier_test.py @@ -7,6 +7,12 @@ from keras_hub.src.models.efficientnet.efficientnet_image_classifier import ( EfficientNetImageClassifier, ) +from keras_hub.src.models.efficientnet.efficientnet_image_classifier_preprocessor import ( # noqa: E501 + EfficientNetImageClassifierPreprocessor, +) +from keras_hub.src.models.efficientnet.efficientnet_image_converter import ( + EfficientNetImageConverter, +) from keras_hub.src.tests.test_case import TestCase @@ -38,6 +44,9 @@ def setUp(self): self.init_kwargs = { "backbone": backbone, "num_classes": 1000, + "preprocessor": EfficientNetImageClassifierPreprocessor( + image_converter=EfficientNetImageConverter(image_size=(16, 16)) + ), } self.train_data = (self.images, self.labels) @@ -82,3 +91,11 @@ def test_all_presets(self): input_data=self.images, expected_output_shape=(2, 2), ) + + @pytest.mark.large + def test_litert_export(self): + self.run_litert_export_test( + cls=EfficientNetImageClassifier, + init_kwargs=self.init_kwargs, + input_data=self.images, + ) diff --git a/keras_hub/src/models/esm/esm_classifier_test.py b/keras_hub/src/models/esm/esm_classifier_test.py index 8eeec2b40d..58103a448e 100644 --- a/keras_hub/src/models/esm/esm_classifier_test.py +++ b/keras_hub/src/models/esm/esm_classifier_test.py @@ -51,3 +51,11 @@ def test_saved_model(self): init_kwargs=self.init_kwargs, input_data=self.input_data, ) + + @pytest.mark.large + def test_litert_export(self): + self.run_litert_export_test( + cls=ESMProteinClassifier, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) diff --git a/keras_hub/src/models/f_net/f_net_text_classifier_test.py b/keras_hub/src/models/f_net/f_net_text_classifier_test.py index 4658e795f6..fab0ed1650 100644 --- a/keras_hub/src/models/f_net/f_net_text_classifier_test.py +++ b/keras_hub/src/models/f_net/f_net_text_classifier_test.py @@ -1,6 +1,7 @@ import os import pytest +import tensorflow as tf from keras_hub.src.models.f_net.f_net_backbone import FNetBackbone from keras_hub.src.models.f_net.f_net_text_classifier import FNetTextClassifier @@ -57,6 +58,25 @@ def test_saved_model(self): input_data=self.input_data, ) + @pytest.mark.large + def test_litert_export(self): + # Add padding_mask to input_data for LiteRT export compatibility + input_data = self.input_data.copy() + batch_size, seq_length = input_data["token_ids"].shape + input_data["padding_mask"] = tf.zeros( + (batch_size, seq_length), dtype=tf.int32 + ) + + self.run_litert_export_test( + cls=FNetTextClassifier, + init_kwargs=self.init_kwargs, + input_data=input_data, + comparison_mode="statistical", + output_thresholds={ + "*": {"max": 0.01, "mean": 0.005}, + }, + ) + @pytest.mark.extra_large def test_all_presets(self): for preset in FNetTextClassifier.presets: diff --git a/keras_hub/src/models/falcon/falcon_causal_lm_test.py b/keras_hub/src/models/falcon/falcon_causal_lm_test.py index 393f8a8e97..c8b699b818 100644 --- a/keras_hub/src/models/falcon/falcon_causal_lm_test.py +++ b/keras_hub/src/models/falcon/falcon_causal_lm_test.py @@ -164,6 +164,14 @@ def test_saved_model(self): input_data=self.input_data, ) + @pytest.mark.large + def test_litert_export(self): + self.run_litert_export_test( + cls=FalconCausalLM, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + @pytest.mark.extra_large def test_all_presets(self): for preset in FalconCausalLM.presets: diff --git a/keras_hub/src/models/gemma/gemma_causal_lm_test.py b/keras_hub/src/models/gemma/gemma_causal_lm_test.py index 7885d502cc..484140debd 100644 --- a/keras_hub/src/models/gemma/gemma_causal_lm_test.py +++ b/keras_hub/src/models/gemma/gemma_causal_lm_test.py @@ -201,6 +201,36 @@ def test_saved_model(self): input_data=self.input_data, ) + @pytest.mark.large + @pytest.mark.skipif( + keras.backend.backend() != "tensorflow", + reason="LiteRT export only supports TensorFlow backend.", + ) + def test_litert_export(self): + """Test LiteRT export for GemmaCausalLM with small test model.""" + model = GemmaCausalLM(**self.init_kwargs) + + # Convert boolean padding_mask to int32 for LiteRT compatibility + input_data = self.input_data.copy() + if "padding_mask" in input_data: + input_data["padding_mask"] = ops.cast( + input_data["padding_mask"], "int32" + ) + + expected_output_shape = ( + 2, + 8, + self.preprocessor.tokenizer.vocabulary_size(), + ) + + self.run_litert_export_test( + model=model, + input_data=input_data, + expected_output_shape=expected_output_shape, + comparison_mode="statistical", + output_thresholds={"*": {"max": 1e-3, "mean": 1e-5}}, + ) + @pytest.mark.kaggle_key_required @pytest.mark.extra_large def test_all_presets(self): diff --git a/keras_hub/src/models/gemma3/gemma3_causal_lm_test.py b/keras_hub/src/models/gemma3/gemma3_causal_lm_test.py index ad37403752..63633b9164 100644 --- a/keras_hub/src/models/gemma3/gemma3_causal_lm_test.py +++ b/keras_hub/src/models/gemma3/gemma3_causal_lm_test.py @@ -226,6 +226,71 @@ def test_saved_model(self, modality_type): input_data=input_data, ) + @pytest.mark.large + @pytest.mark.skipif( + keras.backend.backend() != "tensorflow", + reason="LiteRT export only supports TensorFlow backend.", + ) + def test_litert_export(self): + """Test LiteRT export for Gemma3CausalLM with small test model.""" + # Use the small text-only model for fast testing + model = Gemma3CausalLM(**self.text_init_kwargs) + + # Test with text input data + input_data = self.text_input_data.copy() + # Convert boolean padding_mask to int32 for LiteRT compatibility + if "padding_mask" in input_data: + input_data["padding_mask"] = tf.cast( + input_data["padding_mask"], tf.int32 + ) + + expected_output_shape = ( + 2, + 20, + self.text_preprocessor.tokenizer.vocabulary_size(), + ) + + self.run_litert_export_test( + model=model, + input_data=input_data, + expected_output_shape=expected_output_shape, + comparison_mode="statistical", + output_thresholds={"*": {"max": 1e-2, "mean": 1e-4}}, + ) + + @pytest.mark.large + @pytest.mark.skipif( + keras.backend.backend() != "tensorflow", + reason="LiteRT export only supports TensorFlow backend.", + ) + def test_litert_export_multimodal(self): + """Test LiteRT export for multimodal Gemma3CausalLM with small test + model.""" + # Use the small multimodal model for testing + model = Gemma3CausalLM(**self.init_kwargs) + + # Test with multimodal input data + input_data = self.input_data.copy() + # Convert boolean padding_mask to int32 for LiteRT compatibility + if "padding_mask" in input_data: + input_data["padding_mask"] = tf.cast( + input_data["padding_mask"], tf.int32 + ) + + expected_output_shape = ( + 2, + 20, + self.preprocessor.tokenizer.vocabulary_size(), + ) + + self.run_litert_export_test( + model=model, + input_data=input_data, + expected_output_shape=expected_output_shape, + comparison_mode="statistical", + output_thresholds={"*": {"max": 1e-2, "mean": 1e-4}}, + ) + @pytest.mark.kaggle_key_required @pytest.mark.extra_large def test_all_presets(self): diff --git a/keras_hub/src/models/gpt2/gpt2_causal_lm_test.py b/keras_hub/src/models/gpt2/gpt2_causal_lm_test.py index 0f6315bea6..7cf83aa5e9 100644 --- a/keras_hub/src/models/gpt2/gpt2_causal_lm_test.py +++ b/keras_hub/src/models/gpt2/gpt2_causal_lm_test.py @@ -1,5 +1,6 @@ from unittest.mock import patch +import keras import pytest from keras import ops @@ -106,6 +107,36 @@ def test_saved_model(self): input_data=self.input_data, ) + @pytest.mark.large + @pytest.mark.skipif( + keras.backend.backend() != "tensorflow", + reason="LiteRT export only supports TensorFlow backend.", + ) + def test_litert_export(self): + """Test LiteRT export for GPT2CausalLM with small test model.""" + model = GPT2CausalLM(**self.init_kwargs) + + # Convert boolean padding_mask to int32 for LiteRT compatibility + input_data = self.input_data.copy() + if "padding_mask" in input_data: + input_data["padding_mask"] = ops.cast( + input_data["padding_mask"], "int32" + ) + + expected_output_shape = ( + 2, + 8, + self.preprocessor.tokenizer.vocabulary_size(), + ) + + self.run_litert_export_test( + model=model, + input_data=input_data, + expected_output_shape=expected_output_shape, + comparison_mode="statistical", + output_thresholds={"*": {"max": 1e-3, "mean": 1e-5}}, + ) + @pytest.mark.extra_large def test_all_presets(self): for preset in GPT2CausalLM.presets: diff --git a/keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_test.py b/keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_test.py index f66c748b9e..305e6dc267 100644 --- a/keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_test.py +++ b/keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_test.py @@ -1,5 +1,6 @@ from unittest.mock import patch +import keras import pytest from keras import ops @@ -105,3 +106,19 @@ def test_saved_model(self): init_kwargs=self.init_kwargs, input_data=self.input_data, ) + + @pytest.mark.large + @pytest.mark.skipif( + keras.backend.backend() != "tensorflow", + reason="LiteRT export only supports TensorFlow backend.", + ) + def test_litert_export(self): + self.run_litert_export_test( + cls=GPTNeoXCausalLM, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + output_thresholds={ + "max": 1e-3, + "mean": 1e-4, + }, # More lenient thresholds for numerical differences + ) diff --git a/keras_hub/src/models/hgnetv2/hgnetv2_image_classifier_test.py b/keras_hub/src/models/hgnetv2/hgnetv2_image_classifier_test.py index f294a23b72..8eb16b3cad 100644 --- a/keras_hub/src/models/hgnetv2/hgnetv2_image_classifier_test.py +++ b/keras_hub/src/models/hgnetv2/hgnetv2_image_classifier_test.py @@ -89,3 +89,11 @@ def test_saved_model(self): init_kwargs=self.init_kwargs, input_data=self.images, ) + + @pytest.mark.large + def test_litert_export(self): + self.run_litert_export_test( + cls=HGNetV2ImageClassifier, + init_kwargs=self.init_kwargs, + input_data=self.images, + ) diff --git a/keras_hub/src/models/llama/llama_causal_lm_test.py b/keras_hub/src/models/llama/llama_causal_lm_test.py index 1ff5a3a987..0e14faa34e 100644 --- a/keras_hub/src/models/llama/llama_causal_lm_test.py +++ b/keras_hub/src/models/llama/llama_causal_lm_test.py @@ -1,6 +1,7 @@ import os from unittest.mock import patch +import keras import pytest from keras import ops @@ -106,6 +107,18 @@ def test_saved_model(self): input_data=self.input_data, ) + @pytest.mark.large + @pytest.mark.skipif( + keras.backend.backend() != "tensorflow", + reason="LiteRT export only supports TensorFlow backend.", + ) + def test_litert_export(self): + self.run_litert_export_test( + cls=LlamaCausalLM, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + @pytest.mark.extra_large def test_all_presets(self): for preset in LlamaCausalLM.presets: diff --git a/keras_hub/src/models/llama3/llama3_causal_lm_test.py b/keras_hub/src/models/llama3/llama3_causal_lm_test.py index a054b8ae14..346d1cf500 100644 --- a/keras_hub/src/models/llama3/llama3_causal_lm_test.py +++ b/keras_hub/src/models/llama3/llama3_causal_lm_test.py @@ -1,6 +1,8 @@ from unittest.mock import patch +import keras import pytest +import tensorflow as tf from keras import ops from keras_hub.src.models.llama3.llama3_backbone import Llama3Backbone @@ -114,6 +116,36 @@ def test_saved_model(self): input_data=self.input_data, ) + @pytest.mark.large + @pytest.mark.skipif( + keras.backend.backend() != "tensorflow", + reason="LiteRT export only supports TensorFlow backend.", + ) + def test_litert_export(self): + """Test LiteRT export for Llama3CausalLM with small test model.""" + model = Llama3CausalLM(**self.init_kwargs) + + # Convert boolean padding_mask to int32 for LiteRT compatibility + input_data = self.input_data.copy() + if "padding_mask" in input_data: + input_data["padding_mask"] = tf.cast( + input_data["padding_mask"], tf.int32 + ) + + expected_output_shape = ( + 2, + 7, + self.preprocessor.tokenizer.vocabulary_size(), + ) + + self.run_litert_export_test( + model=model, + input_data=input_data, + expected_output_shape=expected_output_shape, + comparison_mode="statistical", + output_thresholds={"*": {"max": 1e-3, "mean": 1e-5}}, + ) + @pytest.mark.extra_large def test_all_presets(self): for preset in Llama3CausalLM.presets: diff --git a/keras_hub/src/models/mistral/mistral_causal_lm_test.py b/keras_hub/src/models/mistral/mistral_causal_lm_test.py index 8a6bd42434..05d82f1e69 100644 --- a/keras_hub/src/models/mistral/mistral_causal_lm_test.py +++ b/keras_hub/src/models/mistral/mistral_causal_lm_test.py @@ -1,6 +1,7 @@ import os from unittest.mock import patch +import keras import pytest from keras import ops @@ -106,6 +107,36 @@ def test_saved_model(self): input_data=self.input_data, ) + @pytest.mark.large + @pytest.mark.skipif( + keras.backend.backend() != "tensorflow", + reason="LiteRT export only supports TensorFlow backend.", + ) + def test_litert_export(self): + """Test LiteRT export for MistralCausalLM with small test model.""" + model = MistralCausalLM(**self.init_kwargs) + + # Convert boolean padding_mask to int32 for LiteRT compatibility + input_data = self.input_data.copy() + if "padding_mask" in input_data: + input_data["padding_mask"] = ops.cast( + input_data["padding_mask"], "int32" + ) + + expected_output_shape = ( + 2, + 8, + self.preprocessor.tokenizer.vocabulary_size(), + ) + + self.run_litert_export_test( + model=model, + input_data=input_data, + expected_output_shape=expected_output_shape, + comparison_mode="statistical", + output_thresholds={"*": {"max": 1e-3, "mean": 1e-5}}, + ) + @pytest.mark.extra_large def test_all_presets(self): for preset in MistralCausalLM.presets: diff --git a/keras_hub/src/models/mit/mit_image_classifier_test.py b/keras_hub/src/models/mit/mit_image_classifier_test.py index c63a456311..a0c621b2d2 100644 --- a/keras_hub/src/models/mit/mit_image_classifier_test.py +++ b/keras_hub/src/models/mit/mit_image_classifier_test.py @@ -1,3 +1,4 @@ +import keras import numpy as np import pytest @@ -50,3 +51,15 @@ def test_saved_model(self): init_kwargs=self.init_kwargs, input_data=self.images, ) + + @pytest.mark.large + @pytest.mark.skipif( + keras.backend.backend() != "tensorflow", + reason="LiteRT export only supports TensorFlow backend.", + ) + def test_litert_export(self): + self.run_litert_export_test( + cls=MiTImageClassifier, + init_kwargs=self.init_kwargs, + input_data=self.images, + ) diff --git a/keras_hub/src/models/mixtral/mixtral_causal_lm_test.py b/keras_hub/src/models/mixtral/mixtral_causal_lm_test.py index a711a06b0e..14c0e1f84f 100644 --- a/keras_hub/src/models/mixtral/mixtral_causal_lm_test.py +++ b/keras_hub/src/models/mixtral/mixtral_causal_lm_test.py @@ -1,6 +1,7 @@ import os from unittest.mock import patch +import keras import pytest from keras import ops @@ -107,6 +108,18 @@ def test_saved_model(self): input_data=self.input_data, ) + @pytest.mark.large + @pytest.mark.skipif( + keras.backend.backend() != "tensorflow", + reason="LiteRT export only supports TensorFlow backend.", + ) + def test_litert_export(self): + self.run_litert_export_test( + cls=MixtralCausalLM, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + @pytest.mark.extra_large def test_all_presets(self): for preset in MixtralCausalLM.presets: diff --git a/keras_hub/src/models/mobilenet/mobilenet_image_classifier_test.py b/keras_hub/src/models/mobilenet/mobilenet_image_classifier_test.py index c996122fa5..27e41bcff9 100644 --- a/keras_hub/src/models/mobilenet/mobilenet_image_classifier_test.py +++ b/keras_hub/src/models/mobilenet/mobilenet_image_classifier_test.py @@ -101,3 +101,11 @@ def test_saved_model(self): init_kwargs=self.init_kwargs, input_data=self.images, ) + + @pytest.mark.large + def test_litert_export(self): + self.run_litert_export_test( + cls=MobileNetImageClassifier, + init_kwargs=self.init_kwargs, + input_data=self.images, + ) diff --git a/keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier_test.py b/keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier_test.py index 219cb6f285..494e1dab84 100644 --- a/keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier_test.py +++ b/keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier_test.py @@ -74,3 +74,11 @@ def test_saved_model(self): init_kwargs=self.init_kwargs, input_data=self.images, ) + + @pytest.mark.large + def test_litert_export(self): + self.run_litert_export_test( + cls=MobileNetV5ImageClassifier, + init_kwargs=self.init_kwargs, + input_data=self.images, + ) diff --git a/keras_hub/src/models/moonshine/moonshine_audio_to_text_test.py b/keras_hub/src/models/moonshine/moonshine_audio_to_text_test.py index 5d0a7dbe7a..a34bfe8ba1 100644 --- a/keras_hub/src/models/moonshine/moonshine_audio_to_text_test.py +++ b/keras_hub/src/models/moonshine/moonshine_audio_to_text_test.py @@ -145,6 +145,18 @@ def test_saved_model(self): input_data=self.input_data, ) + @pytest.mark.large + @pytest.mark.skipif( + keras.backend.backend() != "tensorflow", + reason="LiteRT export only supports TensorFlow backend.", + ) + def test_litert_export(self): + self.run_litert_export_test( + cls=MoonshineAudioToText, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + @pytest.mark.extra_large def test_all_presets(self): for preset in MoonshineAudioToText.presets: diff --git a/keras_hub/src/models/opt/opt_causal_lm_test.py b/keras_hub/src/models/opt/opt_causal_lm_test.py index 138c5a5180..6a9aa12262 100644 --- a/keras_hub/src/models/opt/opt_causal_lm_test.py +++ b/keras_hub/src/models/opt/opt_causal_lm_test.py @@ -105,6 +105,14 @@ def test_saved_model(self): input_data=self.input_data, ) + @pytest.mark.large + def test_litert_export(self): + self.run_litert_export_test( + cls=OPTCausalLM, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + @pytest.mark.extra_large def test_all_presets(self): for preset in OPTCausalLM.presets: diff --git a/keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_test.py b/keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_test.py index 1f53cdef04..d6e28f9cdf 100644 --- a/keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_test.py +++ b/keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_test.py @@ -1,5 +1,6 @@ import os.path +import keras import numpy as np import pytest @@ -106,6 +107,39 @@ def test_saved_model(self): input_data=input_data, ) + @pytest.mark.large + @pytest.mark.skipif( + keras.backend.backend() != "tensorflow", + reason="LiteRT export only supports TensorFlow backend.", + ) + def test_litert_export(self): + input_data = { + "token_ids": np.random.randint( + 0, + self.vocabulary_size, + size=(self.batch_size, self.text_sequence_length), + dtype="int32", + ), + "images": np.ones( + (self.batch_size, self.image_size, self.image_size, 3) + ), + "padding_mask": np.ones( + (self.batch_size, self.text_sequence_length), + dtype="int32", + ), + "response_mask": np.zeros( + (self.batch_size, self.text_sequence_length), + dtype="int32", + ), + } + self.run_litert_export_test( + cls=PaliGemmaCausalLM, + init_kwargs=self.init_kwargs, + input_data=input_data, + comparison_mode="statistical", + output_thresholds={"*": {"max": 2e-6, "mean": 1e-6}}, + ) + def test_pali_gemma_causal_model(self): preprocessed, _, _ = self.preprocessor( { diff --git a/keras_hub/src/models/parseq/parseq_causal_lm_test.py b/keras_hub/src/models/parseq/parseq_causal_lm_test.py index 177c596521..ba2ebb0117 100644 --- a/keras_hub/src/models/parseq/parseq_causal_lm_test.py +++ b/keras_hub/src/models/parseq/parseq_causal_lm_test.py @@ -1,3 +1,4 @@ +import keras import numpy as np import pytest @@ -101,3 +102,34 @@ def test_causal_lm_basics(self): train_data=self.train_data, expected_output_shape=expected_shape_full, ) + + @pytest.mark.large + @pytest.mark.skipif( + keras.backend.backend() != "tensorflow", + reason="LiteRT export only supports TensorFlow backend.", + ) + def test_litert_export(self): + # Create input data for export test + input_data = { + "images": np.random.randn( + self.batch_size, + self.image_height, + self.image_width, + self.num_channels, + ), + "token_ids": np.random.randint( + 0, + self.vocabulary_size, + (self.batch_size, self.max_label_length), + ), + "padding_mask": np.ones( + (self.batch_size, self.max_label_length), dtype="int32" + ), + } + self.run_litert_export_test( + cls=PARSeqCausalLM, + init_kwargs=self.init_kwargs, + input_data=input_data, + comparison_mode="statistical", + output_thresholds={"*": {"max": 1e-3, "mean": 1e-4}}, + ) diff --git a/keras_hub/src/models/phi3/phi3_causal_lm_test.py b/keras_hub/src/models/phi3/phi3_causal_lm_test.py index fc6f6aabe5..2f7df336f2 100644 --- a/keras_hub/src/models/phi3/phi3_causal_lm_test.py +++ b/keras_hub/src/models/phi3/phi3_causal_lm_test.py @@ -1,6 +1,7 @@ import os from unittest.mock import patch +import keras import pytest from keras import ops @@ -107,6 +108,36 @@ def test_saved_model(self): input_data=self.input_data, ) + @pytest.mark.large + @pytest.mark.skipif( + keras.backend.backend() != "tensorflow", + reason="LiteRT export only supports TensorFlow backend.", + ) + def test_litert_export(self): + """Test LiteRT export for Phi3CausalLM with small test model.""" + model = Phi3CausalLM(**self.init_kwargs) + + # Convert boolean padding_mask to int32 for LiteRT compatibility + input_data = self.input_data.copy() + if "padding_mask" in input_data: + input_data["padding_mask"] = ops.cast( + input_data["padding_mask"], "int32" + ) + + expected_output_shape = ( + 2, + 12, + self.preprocessor.tokenizer.vocabulary_size(), + ) + + self.run_litert_export_test( + model=model, + input_data=input_data, + expected_output_shape=expected_output_shape, + comparison_mode="statistical", + output_thresholds={"*": {"max": 1e-3, "mean": 1e-5}}, + ) + @pytest.mark.extra_large def test_all_presets(self): for preset in Phi3CausalLM.presets: diff --git a/keras_hub/src/models/qwen/qwen_causal_lm_test.py b/keras_hub/src/models/qwen/qwen_causal_lm_test.py index b1a715646e..ab363de0de 100644 --- a/keras_hub/src/models/qwen/qwen_causal_lm_test.py +++ b/keras_hub/src/models/qwen/qwen_causal_lm_test.py @@ -1,5 +1,6 @@ from unittest.mock import patch +import keras import pytest from keras import ops @@ -113,6 +114,18 @@ def test_saved_model(self): input_data=self.input_data, ) + @pytest.mark.large + @pytest.mark.skipif( + keras.backend.backend() != "tensorflow", + reason="LiteRT export only supports TensorFlow backend.", + ) + def test_litert_export(self): + self.run_litert_export_test( + cls=QwenCausalLM, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + @pytest.mark.extra_large def test_all_presets(self): for preset in QwenCausalLM.presets: diff --git a/keras_hub/src/models/qwen3/qwen3_causal_lm_test.py b/keras_hub/src/models/qwen3/qwen3_causal_lm_test.py index 5e0456b521..f4e1b44ce3 100644 --- a/keras_hub/src/models/qwen3/qwen3_causal_lm_test.py +++ b/keras_hub/src/models/qwen3/qwen3_causal_lm_test.py @@ -1,5 +1,6 @@ from unittest.mock import patch +import keras import pytest from keras import ops @@ -114,6 +115,36 @@ def test_saved_model(self): input_data=self.input_data, ) + @pytest.mark.large + @pytest.mark.skipif( + keras.backend.backend() != "tensorflow", + reason="LiteRT export only supports TensorFlow backend.", + ) + def test_litert_export(self): + """Test LiteRT export for Qwen3CausalLM with small test model.""" + model = Qwen3CausalLM(**self.init_kwargs) + + # Convert boolean padding_mask to int32 for LiteRT compatibility + input_data = self.input_data.copy() + if "padding_mask" in input_data: + input_data["padding_mask"] = ops.cast( + input_data["padding_mask"], "int32" + ) + + expected_output_shape = ( + 2, + 7, + self.preprocessor.tokenizer.vocabulary_size(), + ) + + self.run_litert_export_test( + model=model, + input_data=input_data, + expected_output_shape=expected_output_shape, + comparison_mode="statistical", + output_thresholds={"*": {"max": 1e-3, "mean": 1e-5}}, + ) + @pytest.mark.extra_large def test_all_presets(self): for preset in Qwen3CausalLM.presets: diff --git a/keras_hub/src/models/qwen3_moe/qwen3_moe_causal_lm_test.py b/keras_hub/src/models/qwen3_moe/qwen3_moe_causal_lm_test.py index d342c1e165..c9282563f4 100644 --- a/keras_hub/src/models/qwen3_moe/qwen3_moe_causal_lm_test.py +++ b/keras_hub/src/models/qwen3_moe/qwen3_moe_causal_lm_test.py @@ -3,6 +3,7 @@ os.environ["KERAS_BACKEND"] = "jax" +import keras import pytest from keras import ops @@ -120,6 +121,19 @@ def test_saved_model(self): input_data=self.input_data, ) + @pytest.mark.large + @pytest.mark.large + @pytest.mark.skipif( + keras.backend.backend() != "tensorflow", + reason="LiteRT export only supports TensorFlow backend.", + ) + def test_litert_export(self): + self.run_litert_export_test( + cls=Qwen3MoeCausalLM, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + @pytest.mark.extra_large def test_all_presets(self): for preset in Qwen3MoeCausalLM.presets: diff --git a/keras_hub/src/models/qwen_moe/qwen_moe_causal_lm_test.py b/keras_hub/src/models/qwen_moe/qwen_moe_causal_lm_test.py index ad1b8c3113..f9f3d9a1d0 100644 --- a/keras_hub/src/models/qwen_moe/qwen_moe_causal_lm_test.py +++ b/keras_hub/src/models/qwen_moe/qwen_moe_causal_lm_test.py @@ -139,6 +139,18 @@ def test_saved_model(self): input_data=self.input_data, ) + @pytest.mark.large + @pytest.mark.skipif( + keras.backend.backend() != "tensorflow", + reason="LiteRT export only supports TensorFlow backend.", + ) + def test_litert_export(self): + self.run_litert_export_test( + cls=QwenMoeCausalLM, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + @pytest.mark.extra_large def test_all_presets(self): for preset in QwenMoeCausalLM.presets: diff --git a/keras_hub/src/models/resnet/resnet_image_classifier_test.py b/keras_hub/src/models/resnet/resnet_image_classifier_test.py index 9bc5897fee..d963a51f99 100644 --- a/keras_hub/src/models/resnet/resnet_image_classifier_test.py +++ b/keras_hub/src/models/resnet/resnet_image_classifier_test.py @@ -1,3 +1,4 @@ +import keras import pytest from keras import ops @@ -65,6 +66,25 @@ def test_saved_model(self): input_data=self.images, ) + @pytest.mark.large + @pytest.mark.skipif( + keras.backend.backend() != "tensorflow", + reason="LiteRT export only supports TensorFlow backend.", + ) + def test_litert_export(self): + """Test LiteRT export for ResNetImageClassifier with small test + model.""" + model = ResNetImageClassifier(**self.init_kwargs) + expected_output_shape = (2, 2) # 2 images, 2 classes + + self.run_litert_export_test( + model=model, + input_data=self.images, + expected_output_shape=expected_output_shape, + comparison_mode="statistical", + output_thresholds={"*": {"max": 5e-5, "mean": 1e-5}}, + ) + @pytest.mark.extra_large def test_all_presets(self): for preset in ResNetImageClassifier.presets: diff --git a/keras_hub/src/models/retinanet/retinanet_object_detector_test.py b/keras_hub/src/models/retinanet/retinanet_object_detector_test.py index 5e01c802a5..e1987a2435 100644 --- a/keras_hub/src/models/retinanet/retinanet_object_detector_test.py +++ b/keras_hub/src/models/retinanet/retinanet_object_detector_test.py @@ -76,7 +76,7 @@ def setUp(self): "preprocessor": preprocessor, } - self.input_size = 512 + self.input_size = 800 self.images = np.random.uniform( low=0, high=255, size=(1, self.input_size, self.input_size, 3) ).astype("float32") @@ -108,3 +108,25 @@ def test_saved_model(self): init_kwargs=self.init_kwargs, input_data=self.images, ) + + @pytest.mark.large + @pytest.mark.skipif( + keras.backend.backend() != "tensorflow", + reason="LiteRT export only supports TensorFlow backend.", + ) + def test_litert_export(self): + # ObjectDetector models need both images and image_shape as inputs + # ObjectDetector only needs images input (not image_shape) + input_data = self.images + + self.run_litert_export_test( + cls=RetinaNetObjectDetector, + init_kwargs=self.init_kwargs, + input_data=input_data, + comparison_mode="statistical", + output_thresholds={ + "enc_topk_logits": {"max": 5.0, "mean": 0.05}, + "logits": {"max": 2.0, "mean": 0.05}, + "*": {"max": 1.5, "mean": 0.05}, + }, + ) diff --git a/keras_hub/src/models/roberta/roberta_text_classifier_test.py b/keras_hub/src/models/roberta/roberta_text_classifier_test.py index c5534a0dc4..adc3daa3ba 100644 --- a/keras_hub/src/models/roberta/roberta_text_classifier_test.py +++ b/keras_hub/src/models/roberta/roberta_text_classifier_test.py @@ -59,6 +59,14 @@ def test_saved_model(self): input_data=self.input_data, ) + @pytest.mark.large + def test_litert_export(self): + self.run_litert_export_test( + cls=RobertaTextClassifier, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + @pytest.mark.extra_large def test_all_presets(self): for preset in RobertaTextClassifier.presets: diff --git a/keras_hub/src/models/roformer_v2/roformer_v2_text_classifier_test.py b/keras_hub/src/models/roformer_v2/roformer_v2_text_classifier_test.py index b24395c574..22a038c538 100644 --- a/keras_hub/src/models/roformer_v2/roformer_v2_text_classifier_test.py +++ b/keras_hub/src/models/roformer_v2/roformer_v2_text_classifier_test.py @@ -1,3 +1,5 @@ +import pytest + from keras_hub.src.models.roformer_v2 import ( roformer_v2_text_classifier_preprocessor as r, ) @@ -50,3 +52,30 @@ def test_classifier_basics(self): train_data=self.train_data, expected_output_shape=(2, 2), ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=RoformerV2TextClassifier, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + + @pytest.mark.large + def test_litert_export(self): + self.run_litert_export_test( + cls=RoformerV2TextClassifier, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in RoformerV2TextClassifier.presets: + self.run_preset_test( + cls=RoformerV2TextClassifier, + preset=preset, + init_kwargs={"num_classes": 2}, + input_data=self.input_data, + expected_output_shape=(2, 2), + ) diff --git a/keras_hub/src/models/sam/sam_image_segmenter_test.py b/keras_hub/src/models/sam/sam_image_segmenter_test.py index 0d36c31db2..d2c3fa88a1 100644 --- a/keras_hub/src/models/sam/sam_image_segmenter_test.py +++ b/keras_hub/src/models/sam/sam_image_segmenter_test.py @@ -1,3 +1,4 @@ +import keras import numpy as np import pytest @@ -22,6 +23,12 @@ def setUp(self): (self.batch_size, self.image_size, self.image_size, 3), dtype="float32", ) + # Use more realistic SAM configuration for export testing + # Real SAM uses 64x64 embeddings for 1024x1024 images + # Scale down proportionally: 128/1024 = 1/8, + # so embeddings should be 64/8 = 8 + # But keep it simple for testing + embedding_size = self.image_size // 16 # 128/16 = 8 self.image_encoder = ViTDetBackbone( hidden_size=16, num_layers=16, @@ -35,7 +42,10 @@ def setUp(self): ) self.prompt_encoder = SAMPromptEncoder( hidden_size=8, - image_embedding_size=(8, 8), + image_embedding_size=( + embedding_size, + embedding_size, + ), # Match image encoder output input_image_size=( self.image_size, self.image_size, @@ -70,9 +80,10 @@ def setUp(self): "points": np.ones((self.batch_size, 1, 2), dtype="float32"), "labels": np.ones((self.batch_size, 1), dtype="float32"), "boxes": np.ones((self.batch_size, 1, 2, 2), dtype="float32"), - "masks": np.zeros( - (self.batch_size, 0, self.image_size, self.image_size, 1) - ), + # For TFLite export, use 1 mask filled with + # zeros (interpreted as "no mask") + # Use the expected mask size of 4 * image_embedding_size = 32 + "masks": np.zeros((self.batch_size, 1, 32, 32, 1), dtype="float32"), } self.labels = { "masks": np.ones((self.batch_size, 2), dtype="float32"), @@ -124,3 +135,23 @@ def test_all_presets(self): "iou_pred": [2], }, ) + + @pytest.mark.large + @pytest.mark.skipif( + keras.backend.backend() != "tensorflow", + reason="LiteRT export only supports TensorFlow backend.", + ) + def test_litert_export(self): + self.run_litert_export_test( + cls=SAMImageSegmenter, + init_kwargs=self.init_kwargs, + input_data=self.inputs, + comparison_mode="statistical", + output_thresholds={ + "masks": {"max": 1e-3, "mean": 1e-4}, + "iou_pred": {"max": 1e-3, "mean": 1e-4}, + }, + target_spec={ + "experimental_disable_xnnpack": True + }, # Disable XNNPack delegate to avoid runtime issues + ) diff --git a/keras_hub/src/models/sam/sam_prompt_encoder.py b/keras_hub/src/models/sam/sam_prompt_encoder.py index 12b77f4a7d..883903415c 100644 --- a/keras_hub/src/models/sam/sam_prompt_encoder.py +++ b/keras_hub/src/models/sam/sam_prompt_encoder.py @@ -292,7 +292,7 @@ def _maybe_input_mask_embed(): ) dense_embeddings = ops.cond( - ops.equal(ops.size(masks), 0), + ops.equal(ops.shape(masks)[1], 0), _no_mask_embed, _maybe_input_mask_embed, ) diff --git a/keras_hub/src/models/segformer/segformer_image_segmenter_tests.py b/keras_hub/src/models/segformer/segformer_image_segmenter_tests.py index 136351e386..8227399b57 100644 --- a/keras_hub/src/models/segformer/segformer_image_segmenter_tests.py +++ b/keras_hub/src/models/segformer/segformer_image_segmenter_tests.py @@ -72,3 +72,13 @@ def test_saved_model(self): init_kwargs={**self.init_kwargs}, input_data=self.input_data, ) + + @pytest.mark.large + def test_litert_export(self): + self.run_litert_export_test( + cls=SegFormerImageSegmenter, + init_kwargs={**self.init_kwargs}, + input_data=self.input_data, + comparison_mode="statistical", + output_thresholds={"*": {"max": 10.0, "mean": 2.0}}, + ) diff --git a/keras_hub/src/models/smollm3/smollm3_causal_lm_test.py b/keras_hub/src/models/smollm3/smollm3_causal_lm_test.py index cbf9b3f88e..f23fda0dc0 100644 --- a/keras_hub/src/models/smollm3/smollm3_causal_lm_test.py +++ b/keras_hub/src/models/smollm3/smollm3_causal_lm_test.py @@ -1,5 +1,6 @@ from unittest.mock import patch +import keras import pytest from keras import ops @@ -122,6 +123,18 @@ def test_saved_model(self): input_data=self.input_data, ) + @pytest.mark.large + @pytest.mark.skipif( + keras.backend.backend() != "tensorflow", + reason="LiteRT export only supports TensorFlow backend.", + ) + def test_litert_export(self): + self.run_litert_export_test( + cls=SmolLM3CausalLM, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + @pytest.mark.extra_large def test_all_presets(self): for preset in SmolLM3CausalLM.presets: diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_test.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_test.py index 10ba8c5149..b9bf784811 100644 --- a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_test.py +++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_test.py @@ -196,3 +196,22 @@ def test_saved_model(self): init_kwargs=self.init_kwargs, input_data=self.input_data, ) + + @pytest.mark.large + @pytest.mark.skipif( + keras.backend.backend() != "tensorflow", + reason="LiteRT export only supports TensorFlow backend.", + ) + def test_litert_export(self): + self.run_litert_export_test( + cls=StableDiffusion3TextToImage, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + allow_custom_ops=True, # Allow custom ops like Erfc + target_spec={ + "supported_ops": [ + "tf.lite.OpsSet.TFLITE_BUILTINS", + "tf.lite.OpsSet.SELECT_TF_OPS", + ] + }, # Also specify supported ops + ) diff --git a/keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm_test.py b/keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm_test.py index 0a4cb0ef4e..fe258524ad 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm_test.py +++ b/keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm_test.py @@ -156,6 +156,14 @@ def test_saved_model(self): input_data=self.input_data, ) + @pytest.mark.large + def test_litert_export(self): + self.run_litert_export_test( + cls=T5GemmaSeq2SeqLM, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + @pytest.mark.extra_large def test_all_presets(self): for preset in T5GemmaSeq2SeqLM.presets: diff --git a/keras_hub/src/models/task.py b/keras_hub/src/models/task.py index d273759b46..30a8e80d6a 100644 --- a/keras_hub/src/models/task.py +++ b/keras_hub/src/models/task.py @@ -369,3 +369,90 @@ def add_layer(layer, info): print_fn=print_fn, **kwargs, ) + + def export(self, filepath, format="litert", verbose=False, **kwargs): + """Export the Keras-Hub model to the specified format. + + This method overrides `keras.Model.export()` to provide specialized + handling for Keras-Hub models with dictionary inputs. + + Args: + filepath: `str`. Path where to save the exported model. + format: `str`. Export format. Currently supports "litert" for + TensorFlow Lite export, as well as other formats supported by + the parent `keras.Model.export()` method (e.g., + "tf_saved_model"). + verbose: `bool`. Whether to print verbose output during export. + Defaults to `False`. + **kwargs: Additional arguments passed to the exporter. For LiteRT + export, common options include: + - `optimizations`: List of TFLite optimizations (e.g., + `[tf.lite.Optimize.DEFAULT]`) + - `allow_custom_ops`: Whether to allow custom TFLite operations. + Set to `True` for models using unsupported ops. Defaults + to `False`. + - `enable_select_tf_ops`: Whether to enable TensorFlow Select + ops (Flex delegate). Set to `True` for models using certain + TF operations not natively supported in TFLite. Defaults + to `False`. + + Examples: + + ```python + # Export a text model to TensorFlow Lite + model = keras_hub.models.GemmaCausalLM.from_preset("gemma_2b_en") + model.export("gemma_model.tflite", format="litert") + + # Export with quantization + import tensorflow as tf + model.export( + "gemma_model_quantized.tflite", + format="litert", + optimizations=[tf.lite.Optimize.DEFAULT] + ) + + # Export model with custom TFLite operations + # (e.g., StableDiffusion3 with Erfc op) + model.export( + "sd3_model.tflite", + format="litert", + allow_custom_ops=True + ) + + # Export model with TensorFlow Select ops (Flex delegate) + model.export( + "model_with_flex.tflite", + format="litert", + enable_select_tf_ops=True + ) + ``` + """ + if format == "litert": + # Ensure filepath ends with .tflite + if not filepath.endswith(".tflite"): + filepath = filepath + ".tflite" + + from keras.src.export.litert import export_litert + + from keras_hub.src.export.configs import get_exporter_config + + # Get the appropriate configuration for this model type + config = get_exporter_config(self) + + # Get domain-specific input signature from config + input_signature = config.get_input_signature() + + export_kwargs = kwargs.copy() + # Note: verbose is handled at the keras-hub level, + # not passed to core export + + # Call Keras Core's export_litert directly + export_litert( + self, + filepath, + input_signature=input_signature, + **export_kwargs, + ) + else: + # Fall back to parent class (keras.Model) export for other formats + super().export(filepath, format=format, verbose=verbose, **kwargs) diff --git a/keras_hub/src/models/vgg/vgg_image_classifier_test.py b/keras_hub/src/models/vgg/vgg_image_classifier_test.py index 16c3fa4453..485b6fff43 100644 --- a/keras_hub/src/models/vgg/vgg_image_classifier_test.py +++ b/keras_hub/src/models/vgg/vgg_image_classifier_test.py @@ -1,3 +1,4 @@ +import keras import numpy as np import pytest @@ -52,6 +53,18 @@ def test_saved_model(self): input_data=self.images, ) + @pytest.mark.large + @pytest.mark.skipif( + keras.backend.backend() != "tensorflow", + reason="LiteRT export only supports TensorFlow backend.", + ) + def test_litert_export(self): + self.run_litert_export_test( + cls=VGGImageClassifier, + init_kwargs=self.init_kwargs, + input_data=self.images, + ) + @pytest.mark.extra_large def test_all_presets(self): # we need at least 32x32 image resolution here to satisfy the presets' diff --git a/keras_hub/src/models/vit/vit_image_classifier_test.py b/keras_hub/src/models/vit/vit_image_classifier_test.py index 1734642bd6..8dfd7a34e2 100644 --- a/keras_hub/src/models/vit/vit_image_classifier_test.py +++ b/keras_hub/src/models/vit/vit_image_classifier_test.py @@ -55,3 +55,11 @@ def test_saved_model(self): init_kwargs=self.init_kwargs, input_data=self.images, ) + + @pytest.mark.large + def test_litert_export(self): + self.run_litert_export_test( + cls=ViTImageClassifier, + init_kwargs=self.init_kwargs, + input_data=self.images, + ) diff --git a/keras_hub/src/models/xception/xception_image_classifier_test.py b/keras_hub/src/models/xception/xception_image_classifier_test.py index c042ecf2d7..a20308fb8a 100644 --- a/keras_hub/src/models/xception/xception_image_classifier_test.py +++ b/keras_hub/src/models/xception/xception_image_classifier_test.py @@ -74,6 +74,14 @@ def test_saved_model(self): input_data=self.images, ) + @pytest.mark.large + def test_litert_export(self): + self.run_litert_export_test( + cls=XceptionImageClassifier, + init_kwargs=self.init_kwargs, + input_data=self.images, + ) + @pytest.mark.extra_large def test_all_presets(self): for preset in XceptionImageClassifier.presets: diff --git a/keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier_test.py b/keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier_test.py index 386d807917..d56f144f0e 100644 --- a/keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier_test.py +++ b/keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier_test.py @@ -64,6 +64,14 @@ def test_saved_model(self): input_data=self.input_data, ) + @pytest.mark.large + def test_litert_export(self): + self.run_litert_export_test( + cls=XLMRobertaTextClassifier, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + @pytest.mark.extra_large def test_all_presets(self): for preset in XLMRobertaTextClassifier.presets: diff --git a/keras_hub/src/tests/test_case.py b/keras_hub/src/tests/test_case.py index 633f32cd5b..bce71ca20b 100644 --- a/keras_hub/src/tests/test_case.py +++ b/keras_hub/src/tests/test_case.py @@ -1,7 +1,9 @@ +import gc import json import os import pathlib import re +import tempfile import keras import numpy as np @@ -433,6 +435,387 @@ def run_model_saving_test( restored_output = restored_model(input_data) self.assertAllClose(model_output, restored_output, atol=atol, rtol=rtol) + def _verify_litert_outputs( + self, + keras_output, + litert_output, + sig_outputs, + expected_output_shape=None, + verify_numerics=True, + comparison_mode="strict", + output_thresholds=None, + ): + """Verify LiteRT outputs against expected shape and Keras outputs. + + Args: + keras_output: Keras model output (can be None if not verifying + numerics) + litert_output: LiteRT interpreter output + sig_outputs: Output names from SignatureDef + expected_output_shape: Expected output shape (optional) + verify_numerics: Whether to verify numerical correctness + comparison_mode: "strict" or "statistical" + output_thresholds: Thresholds for statistical comparison + """ + # Handle single output case: if Keras has single output but LiteRT + # returns dict + if ( + not isinstance(keras_output, dict) + and isinstance(litert_output, dict) + and len(litert_output) == 1 + ): + litert_output = list(litert_output.values())[0] + + # Verify output shape if specified + if expected_output_shape is not None: + self.assertEqual(litert_output.shape, expected_output_shape) + + # Verify numerical correctness if requested + if verify_numerics: + self._verify_outputs( + keras_output, + litert_output, + sig_outputs, + output_thresholds, + comparison_mode, + ) + + def _verify_outputs( + self, + keras_output, + litert_output, + sig_outputs, + output_thresholds, + comparison_mode, + ): + """Verify numerical accuracy between Keras and LiteRT outputs. + + This method compares outputs using the SignatureDef output names to + match Keras outputs with LiteRT outputs properly. + + Args: + keras_output: Keras model output (tensor or dict) + litert_output: LiteRT interpreter output (tensor or dict) + sig_outputs: List of output names from SignatureDef + output_thresholds: Dict of thresholds for comparison + comparison_mode: "strict" or "statistical" + """ + if isinstance(keras_output, dict) and isinstance(litert_output, dict): + # Both outputs are dicts - compare using SignatureDef output names + for output_name in sig_outputs: + if output_name not in keras_output: + self.fail( + f"SignatureDef output '{output_name}' not found in " + f"Keras outputs.\n" + f"Keras keys: {list(keras_output.keys())}" + ) + if output_name not in litert_output: + self.fail( + f"SignatureDef output '{output_name}' not found in " + f"LiteRT outputs.\n" + f"LiteRT keys: {list(litert_output.keys())}" + ) + + keras_val_np = ops.convert_to_numpy(keras_output[output_name]) + litert_val = litert_output[output_name] + output_threshold = output_thresholds.get( + output_name, + output_thresholds.get("*", {"max": 10.0, "mean": 0.1}), + ) + self._compare_outputs( + keras_val_np, + litert_val, + comparison_mode, + output_name, + output_threshold["max"], + output_threshold["mean"], + ) + elif not isinstance(keras_output, dict) and not isinstance( + litert_output, dict + ): + # Both outputs are single tensors - direct comparison + keras_output_np = ops.convert_to_numpy(keras_output) + output_threshold = output_thresholds.get( + "*", {"max": 10.0, "mean": 0.1} + ) + self._compare_outputs( + keras_output_np, + litert_output, + comparison_mode, + key=None, + max_threshold=output_threshold["max"], + mean_threshold=output_threshold["mean"], + ) + else: + keras_type = type(keras_output).__name__ + litert_type = type(litert_output).__name__ + self.fail( + f"Output structure mismatch: Keras returns " + f"{keras_type}, LiteRT returns {litert_type}" + ) + + def run_litert_export_test( + self, + cls=None, + init_kwargs=None, + input_data=None, + expected_output_shape=None, + model=None, + verify_numerics=True, + # No LiteRT output in model saving test; remove undefined return + output_thresholds=None, + **export_kwargs, + ): + """Export model to LiteRT format and verify outputs. + + Args: + cls: Model class to test (optional if model is provided) + init_kwargs: Initialization arguments for the model (optional + if model is provided) + input_data: Input data to test with (dict or tensor) + expected_output_shape: Expected output shape from LiteRT inference + model: Pre-created model instance (optional, if provided cls and + init_kwargs are ignored) + verify_numerics: Whether to verify numerical correctness + between Keras and LiteRT outputs. Set to False for preset + models with load_weights=False where outputs are random. + comparison_mode: "strict" (default) or "statistical". + - "strict": All elements must be within default tolerances + (1e-6) + - "statistical": Check mean/max absolute differences against + provided thresholds + output_thresholds: Dict mapping output names to threshold dicts + with "max" and "mean" keys. Use "*" as wildcard for defaults. + Example: {"output1": {"max": 1e-4, "mean": 1e-5}, + "*": {"max": 1e-3, "mean": 1e-4}} + **export_kwargs: Additional keyword arguments to pass to + model.export(), such as allow_custom_ops=True or + enable_select_tf_ops=True. + """ + # Extract comparison_mode from export_kwargs if provided + comparison_mode = export_kwargs.pop("comparison_mode", "strict") + if keras.backend.backend() != "tensorflow": + self.skipTest("LiteRT export only supports TensorFlow backend") + + try: + from ai_edge_litert.interpreter import Interpreter + except ImportError: + import tensorflow as tf + + Interpreter = tf.lite.Interpreter + + if output_thresholds is None: + output_thresholds = {"*": {"max": 10.0, "mean": 0.1}} + + if model is None: + if cls is None or init_kwargs is None: + raise ValueError( + "Either 'model' or 'cls' and 'init_kwargs' must be provided" + ) + model = cls(**init_kwargs) + _ = model(input_data) + + interpreter = None + try: + with tempfile.TemporaryDirectory() as temp_dir: + export_path = os.path.join(temp_dir, "model.tflite") + + # Step 1: Export model and get Keras output + model.export(export_path, format="litert", **export_kwargs) + self.assertTrue(os.path.exists(export_path)) + self.assertGreater(os.path.getsize(export_path), 0) + + keras_output = model(input_data) if verify_numerics else None + + # Step 2: Load interpreter and verify SignatureDef + interpreter = Interpreter(model_path=export_path) + signature_defs = interpreter.get_signature_list() + self.assertIn( + "serving_default", + signature_defs, + "Missing serving_default signature", + ) + + serving_sig = signature_defs["serving_default"] + sig_inputs = serving_sig.get("inputs", []) + sig_outputs = serving_sig.get("outputs", []) + + self.assertGreater( + len(sig_inputs), + 0, + "Should have at least one input in SignatureDef", + ) + self.assertGreater( + len(sig_outputs), + 0, + "Should have at least one output in SignatureDef", + ) + + # Verify input signature + if isinstance(input_data, dict): + expected_inputs = set(input_data.keys()) + actual_inputs = set(sig_inputs) + # Check that all expected inputs are in the signature + # (allow signature to have additional optional inputs) + missing_inputs = expected_inputs - actual_inputs + if missing_inputs: + self.fail( + f"Missing inputs in SignatureDef: " + f"{sorted(missing_inputs)}. " + f"Expected: {sorted(expected_inputs)}, " + f"SignatureDef has: {sorted(actual_inputs)}" + ) + else: + # For numpy arrays, just verify we have exactly one input + # (since we're passing a single tensor) + if len(sig_inputs) != 1: + self.fail( + "Expected 1 input for numpy array input_data, " + f"but SignatureDef has {len(sig_inputs)}: " + f"{sig_inputs}" + ) + + # Verify output signature + if verify_numerics and isinstance(keras_output, dict): + expected_outputs = set(keras_output.keys()) + actual_outputs = set(sig_outputs) + if expected_outputs != actual_outputs: + self.fail( + f"Output name mismatch: Expected " + f"{sorted(expected_outputs)}, " + f"but SignatureDef has {sorted(actual_outputs)}" + ) + + # Step 3: Run LiteRT inference + os.remove(export_path) + # Simple inference implementation + runner = interpreter.get_signature_runner("serving_default") + + # Convert input data dtypes to match TFLite expectations + def convert_for_tflite(x): + """Convert tensor/array to TFLite-compatible dtypes.""" + if hasattr(x, "dtype"): + if isinstance(x, np.ndarray): + if x.dtype == bool: + return x.astype(np.int32) + elif x.dtype == np.float64: + return x.astype(np.float32) + elif x.dtype == np.int64: + return x.astype(np.int32) + else: # TensorFlow tensor + if x.dtype == tf.bool: + return tf.cast(x, tf.int32).numpy() + elif x.dtype == tf.float64: + return tf.cast(x, tf.float32).numpy() + elif x.dtype == tf.int64: + return tf.cast(x, tf.int32).numpy() + else: + return x.numpy() if hasattr(x, "numpy") else x + elif hasattr(x, "numpy"): + return x.numpy() + return x + + if isinstance(input_data, dict): + converted_input_data = tree.map_structure( + convert_for_tflite, input_data + ) + litert_output = runner(**converted_input_data) + else: + # For single tensor inputs, get the input name + sig_inputs = serving_sig.get("inputs", []) + input_name = sig_inputs[ + 0 + ] # We verified len(sig_inputs) == 1 above + converted_input = convert_for_tflite(input_data) + litert_output = runner(**{input_name: converted_input}) + + # Step 4: Verify outputs + self._verify_litert_outputs( + keras_output, + litert_output, + sig_outputs, + expected_output_shape=expected_output_shape, + verify_numerics=verify_numerics, + comparison_mode=comparison_mode, + output_thresholds=output_thresholds, + ) + finally: + if interpreter is not None: + del interpreter + if model is not None and cls is not None: + del model + gc.collect() + + def _compare_outputs( + self, + keras_val, + litert_val, + comparison_mode, + key=None, + max_threshold=10.0, + mean_threshold=0.1, + ): + """Compare Keras and LiteRT outputs using specified comparison mode. + + Args: + keras_val: Keras model output (numpy array) + litert_val: LiteRT model output (numpy array) + comparison_mode: "strict" or "statistical" + key: Output key name for error messages (optional) + max_threshold: Maximum absolute difference threshold for statistical + mode + mean_threshold: Mean absolute difference threshold for statistical + mode + """ + key_msg = f" for output key '{key}'" if key else "" + + # Check if shapes are compatible for comparison + self.assertEqual( + keras_val.shape, + litert_val.shape, + f"Shape mismatch{key_msg}: Keras shape " + f"{keras_val.shape}, LiteRT shape {litert_val.shape}. " + "Numerical comparison cannot proceed due to incompatible shapes.", + ) + + if comparison_mode == "strict": + # Original strict element-wise comparison with default tolerances + self.assertAllClose( + keras_val, + litert_val, + atol=1e-6, + rtol=1e-6, + msg=f"Mismatch{key_msg}", + ) + elif comparison_mode == "statistical": + # Statistical comparison + + # Calculate element-wise absolute differences + abs_diff = np.abs(keras_val - litert_val) + + # Element-wise statistics + mean_abs_diff = np.mean(abs_diff) + max_abs_diff = np.max(abs_diff) + + # Assert reasonable bounds on statistical differences + self.assertLessEqual( + mean_abs_diff, + mean_threshold, + f"Mean absolute difference too high: {mean_abs_diff:.6e}" + f"{key_msg} (threshold: {mean_threshold})", + ) + self.assertLessEqual( + max_abs_diff, + max_threshold, + f"Max absolute difference too high: {max_abs_diff:.6e}" + f"{key_msg} (threshold: {max_threshold})", + ) + else: + raise ValueError( + f"Unknown comparison_mode: {comparison_mode}. Must be " + "'strict' or 'statistical'" + ) + def run_backbone_test( self, cls, diff --git a/requirements-tensorflow-cuda.txt b/requirements-tensorflow-cuda.txt index 94ab86d63f..5b366b8734 100644 --- a/requirements-tensorflow-cuda.txt +++ b/requirements-tensorflow-cuda.txt @@ -11,3 +11,6 @@ torchvision>=0.16.0 jax[cpu] -r requirements-common.txt + +# for litert export feature +ai-edge-litert