@@ -17,7 +17,7 @@
 
 import copy
 import types
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch
 
 from transformers import AutoModelForCausalLM, AutoTokenizer, KernelConfig
 from transformers.integrations.hub_kernels import (
@@ -401,3 +401,74 @@ def spy_kernelize(model, device=None, mode=None):
         self.assertTrue(any(m == Mode.TRAINING for m in last_modes))
         self.model.eval()
         self.assertTrue(any(m == Mode.INFERENCE for m in last_modes))
+
+
+@require_kernels
+class TestKernelMappingDeviceFiltering(TestCasePlus):
+    """Test that kernel mappings correctly filter by the current device."""
+
+    def test_multi_device_mapping_filters_correctly(self):
+        """
+        Test that when a kernel_mapping contains multiple devices (cuda, rocm),
+        only the current device's kernel is registered.
+        Regression test for an issue where the ROCm entry overwrote the CUDA mapping.
+        """
+        kernel_mapping = {
+            "RMSNorm": {
+                "cuda": "kernels-community/layer_norm:LlamaRMSNorm",
+                "rocm": "kernels-community/layer_norm:LlamaRMSNorm",
+            }
+        }
+
+        kernel_config = KernelConfig(kernel_mapping)
+
+        # Create a mock model that lives on a CUDA device
+        mock_model = MagicMock()
+        mock_model.training = False
+
+        # Mock a parameter that reports a CUDA device
+        mock_param = MagicMock()
+        mock_param.device.type = "cuda"
+        mock_model.parameters.return_value = iter([mock_param])
+
+        # Mock named_modules with an RMSNorm layer
+        mock_layer = MagicMock()
+        mock_layer.kernel_layer_name = "RMSNorm"
+        mock_model.named_modules.return_value = [("layers.0", mock_layer)]
+
+        # Trigger the mapping creation
+        kernel_config.create_compatible_mapping(mock_model)
+
+        # Verify the results
+        result_mapping = kernel_config.kernel_mapping
+
+        self.assertIn("RMSNorm", result_mapping, "RMSNorm should be in the mapping")
+        backends = list(result_mapping["RMSNorm"].keys())
+
+        # Assert that only CUDA is present, not ROCm
+        self.assertIn("cuda", backends, "CUDA backend should be registered")
+        self.assertNotIn("rocm", backends, "ROCm backend should NOT be registered on a CUDA device")
+
+    def test_single_device_mapping_still_works(self):
+        """
+        Test that single-device mappings continue to work as expected.
+        """
+        kernel_mapping = {"RMSNorm": "kernels-community/layer_norm:LlamaRMSNorm"}
+
+        kernel_config = KernelConfig(kernel_mapping)
+
+        # Create a mock model
+        mock_model = MagicMock()
+        mock_model.training = False
+
+        mock_param = MagicMock()
+        mock_param.device.type = "cuda"
+        mock_model.parameters.return_value = iter([mock_param])
+
+        mock_layer = MagicMock()
+        mock_layer.kernel_layer_name = "RMSNorm"
+        mock_model.named_modules.return_value = [("layers.0", mock_layer)]
+        kernel_config.create_compatible_mapping(mock_model)
+
+        result_mapping = kernel_config.kernel_mapping
+        self.assertIn("RMSNorm", result_mapping, "RMSNorm should be in the mapping")
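For context on the behavior these tests pin down, here is a minimal, self-contained sketch of the device-filtering rule, assuming a plain dict mapping. This is not the actual `create_compatible_mapping` implementation in `transformers`; `filter_mapping_for_device` is a hypothetical helper used only to illustrate the expected outcome.

```python
# Illustrative sketch only: NOT the transformers implementation, just the
# filtering rule the tests above exercise, applied to a plain dict.
def filter_mapping_for_device(kernel_mapping, device_type):
    """Keep, for each layer name, only the kernel entry matching the current device."""
    filtered = {}
    for layer_name, spec in kernel_mapping.items():
        if isinstance(spec, dict):
            # Multi-device spec: keep only the current device's kernel, if present.
            if device_type in spec:
                filtered[layer_name] = {device_type: spec[device_type]}
        else:
            # Single string spec: assumed to apply to the current device as-is.
            filtered[layer_name] = spec
    return filtered


# On a CUDA machine this keeps the "cuda" entry and drops "rocm",
# mirroring test_multi_device_mapping_filters_correctly above.
mapping = {
    "RMSNorm": {
        "cuda": "kernels-community/layer_norm:LlamaRMSNorm",
        "rocm": "kernels-community/layer_norm:LlamaRMSNorm",
    }
}
print(filter_mapping_for_device(mapping, "cuda"))
# {'RMSNorm': {'cuda': 'kernels-community/layer_norm:LlamaRMSNorm'}}
```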