77
88@require_kernels
99class HubKernelsTests (unittest .TestCase ):
10-
1110 def test_disable_hub_kernels (self ):
1211 """
1312 Test that _kernels_enabled is False when USE_HUB_KERNELS when USE_HUB_KERNELS=OFF
1413 """
15- with patch .dict (os .environ , {' USE_HUB_KERNELS' : 'ON' }):
14+ with patch .dict (os .environ , {" USE_HUB_KERNELS" : "ON" }):
1615 # Re-import to ensure the environment variable takes effect
1716 import importlib
1817
1918 from transformers .integrations import hub_kernels
19+
2020 importlib .reload (hub_kernels )
2121
2222 # Verify that kernels are disabled
@@ -27,12 +27,13 @@ def test_enable_hub_kernels_default(self):
2727 Test that _kernels_enabled is True when USE_HUB_KERNELS is not provided (default behavior)
2828 """
2929 # Remove USE_HUB_KERNELS from the environment if it exists
30- env_without_hub_kernels = {k : v for k , v in os .environ .items () if k != ' USE_HUB_KERNELS' }
30+ env_without_hub_kernels = {k : v for k , v in os .environ .items () if k != " USE_HUB_KERNELS" }
3131 with patch .dict (os .environ , env_without_hub_kernels , clear = True ):
3232 # Re-import to ensure the environment variable change takes effect
3333 import importlib
3434
3535 from transformers .integrations import hub_kernels
36+
3637 importlib .reload (hub_kernels )
3738
3839 # Verify that kernels are enabled by default
@@ -42,27 +43,29 @@ def test_enable_hub_kernels_on(self):
4243 """
4344 Test that _kernels_enabled is True when USE_HUB_KERNELS=ON
4445 """
45- with patch .dict (os .environ , {' USE_HUB_KERNELS' : 'ON' }):
46+ with patch .dict (os .environ , {" USE_HUB_KERNELS" : "ON" }):
4647 # Re-import to ensure the environment variable takes effect
4748 import importlib
4849
4950 from transformers .integrations import hub_kernels
51+
5052 importlib .reload (hub_kernels )
5153
5254 # Verify that kernels are enabled
5355 self .assertTrue (hub_kernels ._kernels_enabled )
5456
55- @patch (' kernels.use_kernel_forward_from_hub' )
57+ @patch (" kernels.use_kernel_forward_from_hub" )
5658 def test_use_kernel_forward_from_hub_not_called_when_disabled (self , mocked_use_kernel_forward ):
5759 """
5860 Test that kernels.use_kernel_forward_from_hub is not called when USE_HUB_KERNELS is disabled
5961 """
6062 # Set environment variable to disable hub kernels
61- with patch .dict (os .environ , {' USE_HUB_KERNELS' : ' OFF' }):
63+ with patch .dict (os .environ , {" USE_HUB_KERNELS" : " OFF" }):
6264 # Re-import to ensure the environment variable takes effect
6365 import importlib
6466
6567 from transformers .integrations import hub_kernels
68+
6669 importlib .reload (hub_kernels )
6770
6871 # Call the function with a test layer name
@@ -78,18 +81,19 @@ class FooClass:
7881 result = decorator (FooClass )
7982 self .assertIs (result , FooClass )
8083
81- @patch (' kernels.use_kernel_forward_from_hub' )
84+ @patch (" kernels.use_kernel_forward_from_hub" )
8285 def test_use_kernel_forward_from_hub_called_when_enabled_default (self , mocked_use_kernel_forward ):
8386 """
8487 Test that kernels.use_kernel_forward_from_hub is called when USE_HUB_KERNELS is not set (default)
8588 """
8689 # Remove USE_HUB_KERNELS from the environment if it exists
87- env_without_hub_kernels = {k : v for k , v in os .environ .items () if k != ' USE_HUB_KERNELS' }
90+ env_without_hub_kernels = {k : v for k , v in os .environ .items () if k != " USE_HUB_KERNELS" }
8891 with patch .dict (os .environ , env_without_hub_kernels , clear = True ):
8992 # Re-import to ensure the environment variable change takes effect
9093 import importlib
9194
9295 from transformers .integrations import hub_kernels
96+
9397 importlib .reload (hub_kernels )
9498
9599 # Call the function with a test layer name
@@ -98,22 +102,21 @@ def test_use_kernel_forward_from_hub_called_when_enabled_default(self, mocked_us
98102 # Verify that the kernels function was called once with the correct argument
99103 mocked_use_kernel_forward .assert_called_once_with ("FooLayer" )
100104
101- @patch (' kernels.use_kernel_forward_from_hub' )
105+ @patch (" kernels.use_kernel_forward_from_hub" )
102106 def test_use_kernel_forward_from_hub_called_when_enabled_on (self , mocked_use_kernel_forward ):
103107 """
104108 Test that kernels.use_kernel_forward_from_hub is called when USE_HUB_KERNELS=ON
105109 """
106- with patch .dict (os .environ , {' USE_HUB_KERNELS' : 'ON' }):
110+ with patch .dict (os .environ , {" USE_HUB_KERNELS" : "ON" }):
107111 # Re-import to ensure the environment variable change takes effect
108112 import importlib
109113
110114 from transformers .integrations import hub_kernels
115+
111116 importlib .reload (hub_kernels )
112117
113118 # Call the function with a test layer name
114119 hub_kernels .use_kernel_forward_from_hub ("FooLayer" )
115120
116121 # Verify that the kernels function was called once with the correct argument
117122 mocked_use_kernel_forward .assert_called_once_with ("FooLayer" )
118-
119-
0 commit comments