diff --git a/pkg/blob/blob.go b/pkg/blob/blob.go index ec7d97bc0..698ed75f8 100644 --- a/pkg/blob/blob.go +++ b/pkg/blob/blob.go @@ -590,27 +590,45 @@ func (d *Driver) GetAuthEnv(ctx context.Context, volumeID, protocol string, attr tenantID = d.cloud.TenantID } - if clientID != "" { - if mountWithWIToken { - klog.V(2).Infof("clientID(%s) is specified, use workload identity for blobfuse auth", clientID) - - workloadIdentityToken, err := parseServiceAccountToken(serviceAccountToken) - if err != nil { - return rgName, accountName, accountKey, containerName, authEnv, err + if mountWithWIToken { + if clientID == "" { + clientID = d.cloud.Config.AzureAuthConfig.UserAssignedIdentityID + if clientID == "" { + return rgName, accountName, accountKey, containerName, authEnv, fmt.Errorf("mountWithWorkloadIdentityToken is true but clientID is not specified") } - azureOAuthTokenFile := filepath.Join(defaultAzureOAuthTokenDir, clientID+accountName) + } + klog.V(2).Infof("mountWithWorkloadIdentityToken is specified, use workload identity auth for mount, clientID: %s, tenantID: %s", clientID, tenantID) + + workloadIdentityToken, err := parseServiceAccountToken(serviceAccountToken) + if err != nil { + return rgName, accountName, accountKey, containerName, authEnv, err + } + tokenFileName := clientID + "-" + accountName + if !isValidTokenFileName(tokenFileName) { + return rgName, accountName, accountKey, containerName, authEnv, fmt.Errorf("the generated token file name %s is invalid", tokenFileName) + } + azureOAuthTokenFile := filepath.Join(defaultAzureOAuthTokenDir, tokenFileName) + // check whether token value is the same as the one in the token file + existingToken, readErr := os.ReadFile(azureOAuthTokenFile) + if readErr == nil && string(existingToken) == workloadIdentityToken { + klog.V(6).Infof("the existing workload identity token file %s is up-to-date, no need to rewrite", azureOAuthTokenFile) + } else { + // write the token to a file if err := os.WriteFile(azureOAuthTokenFile, []byte(workloadIdentityToken), 0600); err != nil { return rgName, accountName, accountKey, containerName, authEnv, fmt.Errorf("failed to write workload identity token file %s: %v", azureOAuthTokenFile, err) } + } - authEnv = append(authEnv, "AZURE_STORAGE_SPN_CLIENT_ID="+clientID) - if tenantID != "" { - authEnv = append(authEnv, "AZURE_STORAGE_SPN_TENANT_ID="+tenantID) - } - authEnv = append(authEnv, "AZURE_OAUTH_TOKEN_FILE="+azureOAuthTokenFile) - klog.V(2).Infof("workload identity auth: %v", authEnv) - return rgName, accountName, accountKey, containerName, authEnv, err + authEnv = append(authEnv, "AZURE_STORAGE_SPN_CLIENT_ID="+clientID) + if tenantID != "" { + authEnv = append(authEnv, "AZURE_STORAGE_SPN_TENANT_ID="+tenantID) } + authEnv = append(authEnv, "AZURE_OAUTH_TOKEN_FILE="+azureOAuthTokenFile) + klog.V(2).Infof("workload identity auth: %v", authEnv) + return rgName, accountName, accountKey, containerName, authEnv, err + } + + if clientID != "" { klog.V(2).Infof("clientID(%s) is specified, use service account token to get account key", clientID) if subsID == "" { subsID = d.cloud.SubscriptionID @@ -1244,3 +1262,20 @@ func parseServiceAccountToken(tokenStr string) (string, error) { } return token.APIAzureADTokenExchange.Token, nil } + +// isValidTokenFileName checks if the token file name is valid +// fileName should only contain alphanumeric characters, hyphens +func isValidTokenFileName(fileName string) bool { + if fileName == "" { + return false + } + for _, c := range fileName { + if !(('a' <= c && c <= 'z') || + ('A' <= c && c <= 'Z') || + ('0' <= c && c <= '9') || + (c == '-')) { + return false + } + } + return true +} diff --git a/pkg/blob/blob_test.go b/pkg/blob/blob_test.go index 64ed50be9..ce8d26eea 100644 --- a/pkg/blob/blob_test.go +++ b/pkg/blob/blob_test.go @@ -2132,3 +2132,65 @@ func TestIsSupportedPublicNetworkAccess(t *testing.T) { } } } + +func TestIsValidTokenFileName(t *testing.T) { + testCases := []struct { + name string + fileName string + expected bool + }{ + { + name: "valid lowercase", + fileName: "token", + expected: true, + }, + { + name: "valid uppercase", + fileName: "TOKEN", + expected: true, + }, + { + name: "valid mixed alphanumeric with hyphen", + fileName: "Token-123", + expected: true, + }, + { + name: "valid mixed alphanumeric with hyphen#2", + fileName: "0ab48765-efce-4799-8a9c-c3e1de2ee42eg", + expected: true, + }, + { + name: "empty string", + fileName: "", + expected: false, + }, + { + name: "contains underscore", + fileName: "token_file", + expected: false, + }, + { + name: "contains dot", + fileName: "token.file", + expected: false, + }, + { + name: "contains space", + fileName: "token file", + expected: false, + }, + { + name: "contains slash", + fileName: "token/file", + expected: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if got := isValidTokenFileName(tc.fileName); got != tc.expected { + t.Fatalf("isValidTokenFileName(%q) = %t, want %t", tc.fileName, got, tc.expected) + } + }) + } +} diff --git a/pkg/blob/nodeserver.go b/pkg/blob/nodeserver.go index 9106b8cd8..e9a7b82ef 100644 --- a/pkg/blob/nodeserver.go +++ b/pkg/blob/nodeserver.go @@ -80,7 +80,7 @@ func (d *Driver) NodePublishVolume(ctx context.Context, req *csi.NodePublishVolu context := req.GetVolumeContext() if context != nil { // token request - if context[serviceAccountTokenField] != "" && getValueInMap(context, clientIDField) != "" { + if context[serviceAccountTokenField] != "" && useWorkloadIdentity(context) { klog.V(2).Infof("NodePublishVolume: volume(%s) mount on %s with service account token, clientID: %s", volumeID, target, getValueInMap(context, clientIDField)) _, err := d.NodeStageVolume(ctx, &csi.NodeStageVolumeRequest{ StagingTargetPath: target, @@ -261,7 +261,7 @@ func (d *Driver) NodeStageVolume(ctx context.Context, req *csi.NodeStageVolumeRe attrib := req.GetVolumeContext() secrets := req.GetSecrets() - if getValueInMap(attrib, clientIDField) != "" && attrib[serviceAccountTokenField] == "" { + if useWorkloadIdentity(attrib) && attrib[serviceAccountTokenField] == "" { klog.V(2).Infof("Skip NodeStageVolume for volume(%s) since clientID %s is provided but service account token is empty", volumeID, getValueInMap(attrib, clientIDField)) return &csi.NodeStageVolumeResponse{}, nil } @@ -733,3 +733,11 @@ func checkGidPresentInMountFlags(mountFlags []string) bool { } return false } + +// useWorkloadIdentity checks whether workload identity is used based on the presence of clientID or mountWithWIToken in volume attributes +func useWorkloadIdentity(attrib map[string]string) bool { + if getValueInMap(attrib, clientIDField) != "" || getValueInMap(attrib, mountWithWITokenField) == trueValue { + return true + } + return false +} diff --git a/pkg/blob/nodeserver_test.go b/pkg/blob/nodeserver_test.go index f9487660a..c9974cf11 100644 --- a/pkg/blob/nodeserver_test.go +++ b/pkg/blob/nodeserver_test.go @@ -1192,3 +1192,46 @@ func TestCheckGidPresentInMountFlags(t *testing.T) { } } } + +func TestUseWorkloadIdentity(t *testing.T) { + tests := []struct { + name string + attrib map[string]string + want bool + }{ + { + name: "clientID present", + attrib: map[string]string{ + clientIDField: "client-id", + }, + want: true, + }, + { + name: "mountWithWIToken true", + attrib: map[string]string{ + mountWithWITokenField: trueValue, + }, + want: true, + }, + { + name: "mountWithWIToken false", + attrib: map[string]string{ + mountWithWITokenField: "false", + }, + want: false, + }, + { + name: "no workload identity fields", + attrib: map[string]string{}, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := useWorkloadIdentity(tt.attrib); got != tt.want { + t.Errorf("useWorkloadIdentity() = %v, want %v", got, tt.want) + } + }) + } +}