Skip to content

Commit 3f17410

Browse files
authored
Kernel mapping error resolve (#42466)
* mapping error resolved with test check
* Fix undefined variable 'device' in kernel_config
* added test in test_kernels
* added test with proper format
* added test with proper format once again
* Removed mapping_test.py file
* reformatted with ruff
* removed the test
1 parent 53d2bf6 commit 3f17410

File tree

2 files changed

+76
-3
lines changed

2 files changed

+76
-3
lines changed

src/transformers/utils/kernel_config.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -208,6 +208,7 @@ def create_compatible_mapping(self, model, compile=False):
208208
from kernels import Mode
209209

210210
compatible_mapping = {}
211+
current_device = infer_device(model)
211212
for layer_name, kernel in self.kernel_mapping.items():
212213
# Infer Mode: use Mode.TRAINING if model is training, else use Mode.INFERENCE
213214
mode = Mode.TRAINING if model.training else Mode.INFERENCE
@@ -216,10 +217,11 @@ def create_compatible_mapping(self, model, compile=False):
216217

217218
if isinstance(kernel, str):
218219
repo_name = kernel
219-
device = infer_device(model)
220-
add_to_mapping(layer_name, device, repo_name, mode, compatible_mapping)
220+
add_to_mapping(layer_name, current_device, repo_name, mode, compatible_mapping)
221221
elif isinstance(kernel, dict):
222222
for device, repo_name in kernel.items():
223+
if device != current_device:
224+
continue
223225
add_to_mapping(layer_name, device, repo_name, mode, compatible_mapping)
224226

225227
self.kernel_mapping = compatible_mapping

tests/kernels/test_kernels.py

Lines changed: 72 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@
1717

1818
import copy
1919
import types
20-
from unittest.mock import patch
20+
from unittest.mock import MagicMock, patch
2121

2222
from transformers import AutoModelForCausalLM, AutoTokenizer, KernelConfig
2323
from transformers.integrations.hub_kernels import (
@@ -401,3 +401,74 @@ def spy_kernelize(model, device=None, mode=None):
401401
self.assertTrue(any(m == Mode.TRAINING for m in last_modes))
402402
self.model.eval()
403403
self.assertTrue(any(m == Mode.INFERENCE for m in last_modes))
404+
405+
406+
@require_kernels
407+
class TestKernelMappingDeviceFiltering(TestCasePlus):
408+
"""Test that kernel mappings correctly filter by current device."""
409+
410+
def test_multi_device_mapping_filters_correctly(self):
411+
"""
412+
Test that when a kernel_mapping contains multiple devices (cuda, rocm),
413+
only the current device's kernel is registered.
414+
Regression test for issue where ROCm overwrote CUDA mapping.
415+
"""
416+
kernel_mapping = {
417+
"RMSNorm": {
418+
"cuda": "kernels-community/layer_norm:LlamaRMSNorm",
419+
"rocm": "kernels-community/layer_norm:LlamaRMSNorm",
420+
}
421+
}
422+
423+
kernel_config = KernelConfig(kernel_mapping)
424+
425+
# Create a mock model on CUDA device
426+
mock_model = MagicMock()
427+
mock_model.training = False
428+
429+
# Mock parameter with CUDA device
430+
mock_param = MagicMock()
431+
mock_param.device.type = "cuda"
432+
mock_model.parameters.return_value = iter([mock_param])
433+
434+
# Mock named_modules with RMSNorm layer
435+
mock_layer = MagicMock()
436+
mock_layer.kernel_layer_name = "RMSNorm"
437+
mock_model.named_modules.return_value = [("layers.0", mock_layer)]
438+
439+
# Trigger the mapping creation
440+
kernel_config.create_compatible_mapping(mock_model)
441+
442+
# Verify results
443+
result_mapping = kernel_config.kernel_mapping
444+
445+
self.assertIn("RMSNorm", result_mapping, "RMSNorm should be in mapping")
446+
backends = list(result_mapping["RMSNorm"].keys())
447+
448+
# Assert only CUDA is present, not ROCm
449+
self.assertIn("cuda", backends, "CUDA backend should be registered")
450+
self.assertNotIn("rocm", backends, "ROCm backend should NOT be registered on CUDA device")
451+
452+
def test_single_device_mapping_still_works(self):
453+
"""
454+
Test that single-device mappings continue to work as expected.
455+
"""
456+
kernel_mapping = {"RMSNorm": "kernels-community/layer_norm:LlamaRMSNorm"}
457+
458+
kernel_config = KernelConfig(kernel_mapping)
459+
460+
# Create a mock model
461+
mock_model = MagicMock()
462+
mock_model.training = False
463+
464+
mock_param = MagicMock()
465+
mock_param.device.type = "cuda"
466+
mock_model.parameters.return_value = iter([mock_param])
467+
468+
mock_layer = MagicMock()
469+
mock_layer.kernel_layer_name = "RMSNorm"
470+
mock_model.named_modules.return_value = [("layers.0", mock_layer)]
471+
kernel_config.create_compatible_mapping(mock_model)
472+
473+
result_mapping = kernel_config.kernel_mapping
474+
self.assertIn("RMSNorm", result_mapping, "RMSNorm should be in mapping")

0 commit comments

Comments
 (0)