Skip to content

Commit 3458491

Browse files
committed
update test case
1 parent a23e594 commit 3458491

File tree

2 files changed

+23
-23
lines changed

2 files changed

+23
-23
lines changed

internal/integration/client_side_encryption_prose_test.go

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3177,11 +3177,7 @@ func TestClientSideEncryptionProse(t *testing.T) {
31773177
return cred, nil
31783178
},
31793179
})
3180-
clientEncryption, err := mongo.NewClientEncryption(keyVaultClient, ceo)
3181-
assert.NoErrorf(mt, err, "error on NewClientEncryption: %v", err)
3182-
3183-
dkOpts := options.DataKey()
3184-
_, err = clientEncryption.CreateDataKey(context.Background(), "aws", dkOpts)
3180+
_, err = mongo.NewClientEncryption(keyVaultClient, ceo)
31853181
assert.Error(mt, err, "expected an error")
31863182
})
31873183
mt.Run("Case 2: ClientEncryption with credentialProviders works", func(mt *mtest.T) {
@@ -3209,7 +3205,10 @@ func TestClientSideEncryptionProse(t *testing.T) {
32093205
clientEncryption, err := mongo.NewClientEncryption(keyVaultClient, ceo)
32103206
assert.NoErrorf(mt, err, "error on NewClientEncryption: %v", err)
32113207

3212-
dkOpts := options.DataKey()
3208+
dkOpts := options.DataKey().SetMasterKey(bson.D{
3209+
{"region", "us-east-1"},
3210+
{"key", "arn:aws:kms:us-east-1:579766882180:key/89fcc2c4-08b0-4bd9-9f25-e30687b580d0"},
3211+
})
32133212
_, err = clientEncryption.CreateDataKey(context.Background(), "aws", dkOpts)
32143213
assert.NoErrorf(mt, err, "unexpected error %v", err)
32153214
assert.Equal(mt, 1, calledCount, "expected credential provider to be called once")
@@ -3254,35 +3253,32 @@ func TestClientSideEncryptionProse(t *testing.T) {
32543253
keyVaultClient, err := mongo.Connect(opts)
32553254
assert.NoErrorf(mt, err, "error on Connect: %v", err)
32563255

3256+
var calledCount int
32573257
ceo := options.ClientEncryption().
32583258
SetKeyVaultNamespace("keyvault.datakeys").
32593259
SetKmsProviders(map[string]map[string]any{
3260-
"aws": {
3261-
"accessKeyId": awsAccessKeyID,
3262-
"secretAccessKey": awsSecretAccessKey,
3263-
},
3260+
"aws": map[string]any{},
32643261
}).
32653262
SetCredentialProviders(map[string]options.CredentialsProvider{
32663263
"aws": func(ctx context.Context) (options.Credentials, error) {
3267-
var cred options.Credentials
3268-
provider := credproviders.NewEnvProvider()
3269-
c, err := provider.Retrieve(ctx)
3270-
if err != nil {
3271-
return cred, err
3272-
}
3273-
cred.AccessKeyID = c.AccessKeyID
3274-
cred.SecretAccessKey = c.SecretAccessKey
3275-
cred.SessionToken = c.SessionToken
3276-
cred.ExpirationCallback = provider.IsExpired
3277-
return cred, nil
3264+
calledCount++
3265+
return options.Credentials{
3266+
AccessKeyID: awsAccessKeyID,
3267+
SecretAccessKey: awsSecretAccessKey,
3268+
ExpirationCallback: func() bool { return false },
3269+
}, nil
32783270
},
32793271
})
32803272
clientEncryption, err := mongo.NewClientEncryption(keyVaultClient, ceo)
32813273
assert.NoErrorf(mt, err, "error on NewClientEncryption: %v", err)
32823274

3283-
dkOpts := options.DataKey()
3275+
dkOpts := options.DataKey().SetMasterKey(bson.D{
3276+
{"region", "us-east-1"},
3277+
{"key", "arn:aws:kms:us-east-1:579766882180:key/89fcc2c4-08b0-4bd9-9f25-e30687b580d0"},
3278+
})
32843279
_, err = clientEncryption.CreateDataKey(context.Background(), "aws", dkOpts)
32853280
assert.NoErrorf(mt, err, "unexpected error %v", err)
3281+
assert.Equal(mt, 1, calledCount, "expected credential provider to be called once")
32863282
})
32873283
})
32883284
}

x/mongo/driver/mongocrypt/mongocrypt.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,12 +65,16 @@ func NewMongoCrypt(opts *options.MongoCryptOptions) (*MongoCrypt, error) {
6565
if needsKmsProvider(opts.KmsProviders, "gcp") {
6666
kmsProviders["gcp"] = creds.NewGCPCredentialProvider(httpClient)
6767
}
68+
provider, ok := opts.CredentialProviders["aws"]
6869
if needsKmsProvider(opts.KmsProviders, "aws") {
6970
var providers []credentials.Provider
70-
if provider, ok := opts.CredentialProviders["aws"]; ok {
71+
if ok {
7172
providers = append(providers, provider)
7273
}
7374
kmsProviders["aws"] = creds.NewAWSCredentialProvider(httpClient, providers...)
75+
} else if ok {
76+
return nil, fmt.Error("can only provide a custom AWS credential provider " +
77+
"when the state machine is configured for automatic AWS credential fetching")
7478
}
7579
if needsKmsProvider(opts.KmsProviders, "azure") {
7680
kmsProviders["azure"] = creds.NewAzureCredentialProvider(httpClient)

0 commit comments

Comments
 (0)