18 changes: 16 additions & 2 deletions src/transformers/integrations/hub_kernels.py
@@ -11,14 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
from collections.abc import Callable
from functools import partial
from types import ModuleType
from typing import Optional, Union

from ..modeling_flash_attention_utils import lazy_import_flash_attention
from ..utils import logging
from ..utils import ENV_VARS_TRUE_VALUES, logging
from .flash_attention import flash_attention_forward


@@ -32,10 +33,22 @@
get_kernel,
register_kernel_mapping,
replace_kernel_forward_from_hub,
use_kernel_forward_from_hub,
)

_TRANSFORMERS_USE_HUB_KERNELS = os.environ.get("USE_HUB_KERNELS", "YES").upper()
_kernels_available = True
_kernels_enabled = _TRANSFORMERS_USE_HUB_KERNELS in ENV_VARS_TRUE_VALUES

def use_kernel_forward_from_hub(layer_name: str):
if _kernels_enabled:
from kernels import use_kernel_forward_from_hub as _kernels_use_kernel_forward_from_hub

return _kernels_use_kernel_forward_from_hub(layer_name)
else:
logger.warning_once(
f"kernels hub usage is disabled through the environment USE_HUB_KERNELS={_TRANSFORMERS_USE_HUB_KERNELS}"
)
return lambda cls: cls

_KERNEL_MAPPING: dict[str, dict[Union[Device, str], LayerRepository]] = {
"MultiScaleDeformableAttention": {
@@ -143,6 +156,7 @@

except ImportError:
_kernels_available = False
_kernels_enabled = False

# Stub to make decorators in transformers work when `kernels`
# is not installed.
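For context (not part of the diff): a minimal sketch of how the new environment gate is expected to behave from user code. The layer name and class below are hypothetical and used only for illustration; the variable must be set before transformers reads it, since the module evaluates USE_HUB_KERNELS once at import time.

import os

# Disable hub kernels before transformers.integrations.hub_kernels is imported.
os.environ["USE_HUB_KERNELS"] = "OFF"

from transformers.integrations.hub_kernels import use_kernel_forward_from_hub

# "MyLayer" is a hypothetical layer name/class, not one of the mapped kernels.
@use_kernel_forward_from_hub("MyLayer")
class MyLayer:
    pass

# With USE_HUB_KERNELS=OFF the decorator is a no-op and the class is returned
# unchanged; with the default ("YES") the call is forwarded to
# kernels.use_kernel_forward_from_hub.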
122 changes: 122 additions & 0 deletions tests/integrations/test_hub_kernels.py
@@ -0,0 +1,122 @@
import os
import unittest
from unittest.mock import patch

from transformers.testing_utils import require_kernels


@require_kernels
class HubKernelsTests(unittest.TestCase):
def test_disable_hub_kernels(self):
"""
Test that _kernels_enabled is False when USE_HUB_KERNELS=OFF
"""
with patch.dict(os.environ, {"USE_HUB_KERNELS": "ON"}):
# Re-import to ensure the environment variable takes effect
import importlib

from transformers.integrations import hub_kernels

importlib.reload(hub_kernels)

# Verify that kernels are disabled
self.assertFalse(hub_kernels._kernels_enabled)

def test_enable_hub_kernels_default(self):
"""
Test that _kernels_enabled is True when USE_HUB_KERNELS is not provided (default behavior)
"""
# Remove USE_HUB_KERNELS from the environment if it exists
env_without_hub_kernels = {k: v for k, v in os.environ.items() if k != "USE_HUB_KERNELS"}
with patch.dict(os.environ, env_without_hub_kernels, clear=True):
# Re-import to ensure the environment variable change takes effect
import importlib

from transformers.integrations import hub_kernels

importlib.reload(hub_kernels)

# Verify that kernels are enabled by default
self.assertTrue(hub_kernels._kernels_enabled)

def test_enable_hub_kernels_on(self):
"""
Test that _kernels_enabled is True when USE_HUB_KERNELS=ON
"""
with patch.dict(os.environ, {"USE_HUB_KERNELS": "ON"}):
# Re-import to ensure the environment variable takes effect
import importlib

from transformers.integrations import hub_kernels

importlib.reload(hub_kernels)

# Verify that kernels are enabled
self.assertTrue(hub_kernels._kernels_enabled)

@patch("kernels.use_kernel_forward_from_hub")
def test_use_kernel_forward_from_hub_not_called_when_disabled(self, mocked_use_kernel_forward):
"""
Test that kernels.use_kernel_forward_from_hub is not called when USE_HUB_KERNELS is disabled
"""
# Set environment variable to disable hub kernels
with patch.dict(os.environ, {"USE_HUB_KERNELS": "OFF"}):
# Re-import to ensure the environment variable takes effect
import importlib

from transformers.integrations import hub_kernels

importlib.reload(hub_kernels)

# Call the function with a test layer name
decorator = hub_kernels.use_kernel_forward_from_hub("DummyLayer")

# Verify that the kernels function was never called
mocked_use_kernel_forward.assert_not_called()

# Verify that we get a no-op decorator
class FooClass:
pass

result = decorator(FooClass)
self.assertIs(result, FooClass)

@patch("kernels.use_kernel_forward_from_hub")
def test_use_kernel_forward_from_hub_called_when_enabled_default(self, mocked_use_kernel_forward):
"""
Test that kernels.use_kernel_forward_from_hub is called when USE_HUB_KERNELS is not set (default)
"""
# Remove USE_HUB_KERNELS from the environment if it exists
env_without_hub_kernels = {k: v for k, v in os.environ.items() if k != "USE_HUB_KERNELS"}
with patch.dict(os.environ, env_without_hub_kernels, clear=True):
# Re-import to ensure the environment variable change takes effect
import importlib

from transformers.integrations import hub_kernels

importlib.reload(hub_kernels)

# Call the function with a test layer name
hub_kernels.use_kernel_forward_from_hub("FooLayer")

# Verify that the kernels function was called once with the correct argument
mocked_use_kernel_forward.assert_called_once_with("FooLayer")

@patch("kernels.use_kernel_forward_from_hub")
def test_use_kernel_forward_from_hub_called_when_enabled_on(self, mocked_use_kernel_forward):
"""
Test that kernels.use_kernel_forward_from_hub is called when USE_HUB_KERNELS=ON
"""
with patch.dict(os.environ, {"USE_HUB_KERNELS": "ON"}):
# Re-import to ensure the environment variable change takes effect
import importlib

from transformers.integrations import hub_kernels

importlib.reload(hub_kernels)

# Call the function with a test layer name
hub_kernels.use_kernel_forward_from_hub("FooLayer")

# Verify that the kernels function was called once with the correct argument
mocked_use_kernel_forward.assert_called_once_with("FooLayer")