11import os
22import unittest
3-
43from unittest .mock import patch
4+
55from transformers .testing_utils import require_kernels
66
77
@@ -15,6 +15,7 @@ def test_disable_hub_kernels(self):
1515 with patch .dict (os .environ , {'USE_HUB_KERNELS' : 'ON' }):
1616 # Re-import to ensure the environment variable takes effect
1717 import importlib
18+
1819 from transformers .integrations import hub_kernels
1920 importlib .reload (hub_kernels )
2021
@@ -30,6 +31,7 @@ def test_enable_hub_kernels_default(self):
3031 with patch .dict (os .environ , env_without_hub_kernels , clear = True ):
3132 # Re-import to ensure the environment variable change takes effect
3233 import importlib
34+
3335 from transformers .integrations import hub_kernels
3436 importlib .reload (hub_kernels )
3537
@@ -43,6 +45,7 @@ def test_enable_hub_kernels_on(self):
4345 with patch .dict (os .environ , {'USE_HUB_KERNELS' : 'ON' }):
4446 # Re-import to ensure the environment variable takes effect
4547 import importlib
48+
4649 from transformers .integrations import hub_kernels
4750 importlib .reload (hub_kernels )
4851
@@ -58,19 +61,20 @@ def test_use_kernel_forward_from_hub_not_called_when_disabled(self, mocked_use_k
5861 with patch .dict (os .environ , {'USE_HUB_KERNELS' : 'OFF' }):
5962 # Re-import to ensure the environment variable takes effect
6063 import importlib
64+
6165 from transformers .integrations import hub_kernels
6266 importlib .reload (hub_kernels )
63-
67+
6468 # Call the function with a test layer name
6569 decorator = hub_kernels .use_kernel_forward_from_hub ("DummyLayer" )
66-
70+
6771 # Verify that the kernels function was never called
6872 mocked_use_kernel_forward .assert_not_called ()
69-
73+
7074 # Verify that we get a no-op decorator
7175 class FooClass :
7276 pass
73-
77+
7478 result = decorator (FooClass )
7579 self .assertIs (result , FooClass )
7680
@@ -84,12 +88,13 @@ def test_use_kernel_forward_from_hub_called_when_enabled_default(self, mocked_us
8488 with patch .dict (os .environ , env_without_hub_kernels , clear = True ):
8589 # Re-import to ensure the environment variable change takes effect
8690 import importlib
91+
8792 from transformers .integrations import hub_kernels
8893 importlib .reload (hub_kernels )
89-
94+
9095 # Call the function with a test layer name
9196 hub_kernels .use_kernel_forward_from_hub ("FooLayer" )
92-
97+
9398 # Verify that the kernels function was called once with the correct argument
9499 mocked_use_kernel_forward .assert_called_once_with ("FooLayer" )
95100
@@ -101,12 +106,13 @@ def test_use_kernel_forward_from_hub_called_when_enabled_on(self, mocked_use_ker
101106 with patch .dict (os .environ , {'USE_HUB_KERNELS' : 'ON' }):
102107 # Re-import to ensure the environment variable change takes effect
103108 import importlib
109+
104110 from transformers .integrations import hub_kernels
105111 importlib .reload (hub_kernels )
106-
112+
107113 # Call the function with a test layer name
108114 hub_kernels .use_kernel_forward_from_hub ("FooLayer" )
109-
115+
110116 # Verify that the kernels function was called once with the correct argument
111117 mocked_use_kernel_forward .assert_called_once_with ("FooLayer" )
112118
0 commit comments