Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 34 additions & 19 deletions internal/modifier/cdi.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,38 +166,24 @@ func filterAutomaticDevices(devices []string) []string {
func newAutomaticCDISpecModifier(logger logger.Interface, cfg *config.Config, devices []string) (oci.SpecModifier, error) {
logger.Debugf("Generating in-memory CDI specs for devices %v", devices)

perModeIdentifiers := make(map[string][]string)
perModeDeviceClass := map[string]string{"auto": automaticDeviceClass}
uniqueModes := []string{"auto"}
seen := make(map[string]bool)
for _, device := range devices {
mode, id := getModeIdentifier(device)
logger.Debugf("Mapped %v to %v: %v", device, mode, id)
if !seen[mode] {
uniqueModes = append(uniqueModes, mode)
seen[mode] = true
}
if id != "" {
perModeIdentifiers[mode] = append(perModeIdentifiers[mode], id)
}
}
cdiModeIdentifiers := cdiModeIdentfiersFromDevices(devices...)

logger.Debugf("Per-mode identifiers: %v", perModeIdentifiers)
logger.Debugf("Per-mode identifiers: %v", cdiModeIdentifiers)
var modifiers oci.SpecModifiers
for _, mode := range uniqueModes {
for _, mode := range cdiModeIdentifiers.modes {
cdilib, err := nvcdi.New(
nvcdi.WithLogger(logger),
nvcdi.WithNVIDIACDIHookPath(cfg.NVIDIACTKConfig.Path),
nvcdi.WithDriverRoot(cfg.NVIDIAContainerCLIConfig.Root),
nvcdi.WithVendor(automaticDeviceVendor),
nvcdi.WithClass(perModeDeviceClass[mode]),
nvcdi.WithClass(cdiModeIdentifiers.deviceClassByMode[mode]),
nvcdi.WithMode(mode),
)
if err != nil {
return nil, fmt.Errorf("failed to construct CDI library for mode %q: %w", mode, err)
}

spec, err := cdilib.GetSpec(perModeIdentifiers[mode]...)
spec, err := cdilib.GetSpec(cdiModeIdentifiers.idsByMode[mode]...)
if err != nil {
return nil, fmt.Errorf("failed to generate CDI spec for mode %q: %w", mode, err)
}
Expand All @@ -216,6 +202,35 @@ func newAutomaticCDISpecModifier(logger logger.Interface, cfg *config.Config, de
return modifiers, nil
}

// cdiModeIdentifiers holds the result of grouping requested devices by CDI
// mode. It is built by cdiModeIdentfiersFromDevices and consumed when the
// per-mode CDI libraries are constructed.
type cdiModeIdentifiers struct {
// modes lists each distinct mode once, in the order it was first seen in
// the device list. It is nil when no devices were supplied.
modes []string
// idsByMode maps a mode to the non-empty device identifiers requested for
// it. A mode that was only requested without an identifier has no entry.
idsByMode map[string][]string
// deviceClassByMode maps a mode to the CDI device class used for it. Modes
// without an entry yield the empty string on lookup.
deviceClassByMode map[string]string
}

// cdiModeIdentfiersFromDevices groups the supplied device requests by mode.
//
// Each device string is split into a (mode, id) pair by getModeIdentifier.
// The returned value records:
//   - modes: every distinct mode, in first-seen order (nil for no devices);
//   - idsByMode: the non-empty ids collected for each mode, in input order;
//   - deviceClassByMode: a fixed mapping of "auto" to automaticDeviceClass.
//
// A device that yields an empty id still registers its mode, but contributes
// no entry to idsByMode.
func cdiModeIdentfiersFromDevices(devices ...string) *cdiModeIdentifiers {
	grouped := &cdiModeIdentifiers{
		idsByMode:         make(map[string][]string),
		deviceClassByMode: map[string]string{"auto": automaticDeviceClass},
	}

	// known tracks which modes have already been appended to grouped.modes.
	known := make(map[string]struct{})
	for _, requested := range devices {
		mode, id := getModeIdentifier(requested)

		if _, alreadyListed := known[mode]; !alreadyListed {
			known[mode] = struct{}{}
			grouped.modes = append(grouped.modes, mode)
		}

		// A mode-only request carries no identifier to record.
		if id == "" {
			continue
		}
		grouped.idsByMode[mode] = append(grouped.idsByMode[mode], id)
	}

	return grouped
}

func getModeIdentifier(device string) (string, string) {
if !strings.HasPrefix(device, "mode=") {
return "auto", strings.TrimPrefix(device, automaticDevicePrefix)
Expand Down
83 changes: 83 additions & 0 deletions internal/modifier/cdi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,86 @@ func TestDeviceRequests(t *testing.T) {
})
}
}

// Test_cdiModeIdentfiersFromDevices verifies how device requests are grouped
// by mode: ordering and de-duplication of modes, collection of ids per mode,
// and the constant device-class mapping that is returned in every case.
func Test_cdiModeIdentfiersFromDevices(t *testing.T) {
	// defaultClasses builds a fresh copy of the class mapping that every
	// scenario expects, so no two expectations share a map.
	defaultClasses := func() map[string]string {
		return map[string]string{"auto": "gpu"}
	}

	scenarios := []struct {
		description string
		devices     []string
		expected    *cdiModeIdentifiers
	}{
		{
			description: "empty device list",
			devices:     []string{},
			expected: &cdiModeIdentifiers{
				// No devices means the modes slice is never appended to.
				modes:             nil,
				idsByMode:         map[string][]string{},
				deviceClassByMode: defaultClasses(),
			},
		},
		{
			description: "single automatic device",
			devices:     []string{"0"},
			expected: &cdiModeIdentifiers{
				modes:             []string{"auto"},
				idsByMode:         map[string][]string{"auto": {"0"}},
				deviceClassByMode: defaultClasses(),
			},
		},
		{
			description: "multiple automatic devices",
			devices:     []string{"0", "1"},
			expected: &cdiModeIdentifiers{
				modes:             []string{"auto"},
				idsByMode:         map[string][]string{"auto": {"0", "1"}},
				deviceClassByMode: defaultClasses(),
			},
		},
		{
			description: "device with explicit mode",
			devices:     []string{"mode=gds,id=foo"},
			expected: &cdiModeIdentifiers{
				modes:             []string{"gds"},
				idsByMode:         map[string][]string{"gds": {"foo"}},
				deviceClassByMode: defaultClasses(),
			},
		},
		{
			description: "mixed auto and explicit",
			devices:     []string{"0", "mode=gds,id=foo", "mode=gdrcopy,id=bar"},
			expected: &cdiModeIdentifiers{
				modes: []string{"auto", "gds", "gdrcopy"},
				idsByMode: map[string][]string{
					"auto":    {"0"},
					"gds":     {"foo"},
					"gdrcopy": {"bar"},
				},
				deviceClassByMode: defaultClasses(),
			},
		},
		{
			description: "device with only mode, no id",
			devices:     []string{"mode=nvswitch"},
			expected: &cdiModeIdentifiers{
				// The mode is registered even though no id is recorded for it.
				modes:             []string{"nvswitch"},
				idsByMode:         map[string][]string{},
				deviceClassByMode: defaultClasses(),
			},
		},
		{
			description: "duplicate modes",
			devices:     []string{"mode=gds,id=x", "mode=gds,id=y", "mode=gds"},
			expected: &cdiModeIdentifiers{
				modes:             []string{"gds"},
				idsByMode:         map[string][]string{"gds": {"x", "y"}},
				deviceClassByMode: defaultClasses(),
			},
		},
	}

	for i := range scenarios {
		scenario := scenarios[i]
		t.Run(scenario.description, func(t *testing.T) {
			actual := cdiModeIdentfiersFromDevices(scenario.devices...)
			require.EqualValues(t, scenario.expected, actual)
		})
	}
}
10 changes: 10 additions & 0 deletions tests/e2e/nvidia-container-toolkit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,16 @@ var _ = Describe("docker", Ordered, ContinueOnFailure, func() {
Expect(err).ToNot(HaveOccurred())
Expect(ldconfigOut).To(ContainSubstring("/usr/local/cuda-12.9/compat/"))
})

// Runs a container through the nvidia runtime and lists the files matching
// /etc/ld.so.conf.d/00-compat-*.conf inside it, expecting exactly one match.
It("should create a single ld.so.conf.d config file", func(ctx context.Context) {
lsout, _, err := runner.Run("docker run --rm -i -e NVIDIA_DISABLE_REQUIRE=true --runtime=nvidia -e NVIDIA_VISIBLE_DEVICES=runtime.nvidia.com/gpu=all nvcr.io/nvidia/cuda:12.9.0-base-ubi8 bash -c \"ls -l /etc/ld.so.conf.d/00-compat-*.conf\"")
Expect(err).ToNot(HaveOccurred())
// Trim surrounding whitespace and split the output into lines so the
// number of listed entries can be asserted with HaveLen.
Expect(lsout).To(WithTransform(
func(s string) []string {
return strings.Split(strings.TrimSpace(s), "\n")
}, HaveLen(1),
))
})
})

When("Disabling device node creation", Ordered, func() {
Expand Down
2 changes: 1 addition & 1 deletion third_party/libnvidia-container