@@ -241,7 +241,7 @@ func TestDoUpdate(t *testing.T) {
241241 assert .NotEmpty (t , pkey )
242242
243243 // Invalid request signature should return a specific error
244- ts .ExpectRequest (message .CheckForUpdate , http .StatusOK , func (r message.RequestWrapper ) []byte {
244+ ts .ExpectDNClientRequest (message .CheckForUpdate , http .StatusOK , func (r message.RequestWrapper ) []byte {
245245 return []byte ("" )
246246 })
247247
@@ -265,7 +265,7 @@ func TestDoUpdate(t *testing.T) {
265265 require .Len (t , serverErrs , 1 )
266266
267267 // Invalid signature
268- ts .ExpectRequest (message .DoUpdate , http .StatusOK , func (r message.RequestWrapper ) []byte {
268+ ts .ExpectDNClientRequest (message .DoUpdate , http .StatusOK , func (r message.RequestWrapper ) []byte {
269269 newConfigResponse := message.DoUpdateResponse {
270270 Config : dnapitest .NebulaCfg (caPEM ),
271271 Counter : 2 ,
@@ -320,7 +320,7 @@ func TestDoUpdate(t *testing.T) {
320320 require .Nil (t , pkey )
321321
322322 // Invalid counter
323- ts .ExpectRequest (message .DoUpdate , http .StatusOK , func (r message.RequestWrapper ) []byte {
323+ ts .ExpectDNClientRequest (message .DoUpdate , http .StatusOK , func (r message.RequestWrapper ) []byte {
324324 newConfigResponse := message.DoUpdateResponse {
325325 Config : dnapitest .NebulaCfg (caPEM ),
326326 Counter : 0 ,
@@ -379,7 +379,7 @@ func TestDoUpdate(t *testing.T) {
379379 hostIP := "192.168.100.1"
380380
381381 // This time sign the response with the correct CA key.
382- ts .ExpectRequest (message .DoUpdate , http .StatusOK , func (r message.RequestWrapper ) []byte {
382+ ts .ExpectDNClientRequest (message .DoUpdate , http .StatusOK , func (r message.RequestWrapper ) []byte {
383383 newConfigResponse := message.DoUpdateResponse {
384384 Config : dnapitest .NebulaCfg (caPEM ),
385385 Counter : 3 ,
@@ -505,7 +505,7 @@ func TestDoUpdate_P256(t *testing.T) {
505505 assert .NotEmpty (t , pkey )
506506
507507 // Invalid request signature should return a specific error
508- ts .ExpectRequest (message .CheckForUpdate , http .StatusOK , func (r message.RequestWrapper ) []byte {
508+ ts .ExpectDNClientRequest (message .CheckForUpdate , http .StatusOK , func (r message.RequestWrapper ) []byte {
509509 return []byte ("" )
510510 })
511511
@@ -528,7 +528,7 @@ func TestDoUpdate_P256(t *testing.T) {
528528 require .Len (t , serverErrs , 1 )
529529
530530 // Invalid signature
531- ts .ExpectRequest (message .DoUpdate , http .StatusOK , func (r message.RequestWrapper ) []byte {
531+ ts .ExpectDNClientRequest (message .DoUpdate , http .StatusOK , func (r message.RequestWrapper ) []byte {
532532 newConfigResponse := message.DoUpdateResponse {
533533 Config : dnapitest .NebulaCfg (caPEM ),
534534 Counter : 2 ,
@@ -574,7 +574,7 @@ func TestDoUpdate_P256(t *testing.T) {
574574 require .Nil (t , pkey )
575575
576576 // Invalid counter
577- ts .ExpectRequest (message .DoUpdate , http .StatusOK , func (r message.RequestWrapper ) []byte {
577+ ts .ExpectDNClientRequest (message .DoUpdate , http .StatusOK , func (r message.RequestWrapper ) []byte {
578578 newConfigResponse := message.DoUpdateResponse {
579579 Config : dnapitest .NebulaCfg (caPEM ),
580580 Counter : 0 ,
@@ -618,7 +618,7 @@ func TestDoUpdate_P256(t *testing.T) {
618618 require .Nil (t , pkey )
619619
620620 // This time sign the response with the correct CA key.
621- ts .ExpectRequest (message .DoUpdate , http .StatusOK , func (r message.RequestWrapper ) []byte {
621+ ts .ExpectDNClientRequest (message .DoUpdate , http .StatusOK , func (r message.RequestWrapper ) []byte {
622622 newConfigResponse := message.DoUpdateResponse {
623623 Config : dnapitest .NebulaCfg (caPEM ),
624624 Counter : 3 ,
@@ -743,7 +743,7 @@ func TestCommandResponse(t *testing.T) {
743743 // This time sign the response with the correct CA key.
744744 responseToken := "abc123"
745745 res := map [string ]any {"msg" : "Hello, world!" }
746- ts .ExpectRequest (message .CommandResponse , http .StatusOK , func (r message.RequestWrapper ) []byte {
746+ ts .ExpectDNClientRequest (message .CommandResponse , http .StatusOK , func (r message.RequestWrapper ) []byte {
747747 var val map [string ]any
748748 err := json .Unmarshal (r .Value , & val )
749749 require .NoError (t , err )
@@ -759,7 +759,7 @@ func TestCommandResponse(t *testing.T) {
759759
760760 // Test error handling
761761 errorMsg := "sample error"
762- ts .ExpectRequest (message .CommandResponse , http .StatusBadRequest , func (r message.RequestWrapper ) []byte {
762+ ts .ExpectDNClientRequest (message .CommandResponse , http .StatusBadRequest , func (r message.RequestWrapper ) []byte {
763763 return jsonMarshal (message.EnrollResponse {
764764 Errors : message.APIErrors {{
765765 Code : "ERR_INVALID_VALUE" ,
@@ -954,3 +954,87 @@ func marshalCAPublicKey(curve cert.Curve, pubkey []byte) []byte {
954954 panic ("unsupported curve" )
955955 }
956956}
957+
958+ func TestGetOidcPollCode (t * testing.T ) {
959+ t .Parallel ()
960+
961+ useragent := "dnclientUnitTests/1.0.0 (not a real client)"
962+ ts := dnapitest .NewServer (useragent )
963+ client := NewClient (useragent , ts .URL )
964+ // attempting to defer ts.Close() will trigger early due to parallel testing - use T.Cleanup instead
965+ t .Cleanup (func () { ts .Close () })
966+ const expectedCode = "123456"
967+ ts .ExpectAPIRequest (http .StatusOK , func (req any ) []byte {
968+ return jsonMarshal (message.PreAuthResponse {Data : message.PreAuthData {PollToken : expectedCode , LoginURL : "https://example.com" }})
969+ })
970+
971+ ctx , cancel := context .WithTimeout (context .Background (), 1 * time .Second )
972+ defer cancel ()
973+ resp , err := client .EndpointPreAuth (ctx )
974+ require .NoError (t , err )
975+ assert .NotNil (t , resp )
976+ assert .Equal (t , expectedCode , resp .PollToken )
977+ assert .Equal (t , "https://example.com" , resp .LoginURL )
978+ assert .Empty (t , ts .Errors ())
979+ assert .Equal (t , 0 , ts .RequestsRemaining ())
980+
981+ //unhappy path
982+ ts .ExpectAPIRequest (http .StatusBadGateway , func (req any ) []byte {
983+ return jsonMarshal (message.PreAuthResponse {Data : message.PreAuthData {PollToken : expectedCode , LoginURL : "https://example.com" }})
984+ })
985+ resp , err = client .EndpointPreAuth (ctx )
986+ require .Error (t , err )
987+ require .Nil (t , resp )
988+ assert .Empty (t , ts .Errors ())
989+ assert .Equal (t , 0 , ts .RequestsRemaining ())
990+ }
991+
992+ func TestDoOidcPoll (t * testing.T ) {
993+ t .Parallel ()
994+
995+ useragent := "dnclientUnitTests/1.0.0 (not a real client)"
996+ ts := dnapitest .NewServer (useragent )
997+ client := NewClient (useragent , ts .URL )
998+ // attempting to defer ts.Close() will trigger early due to parallel testing - use T.Cleanup instead
999+ t .Cleanup (func () { ts .Close () })
1000+ const expectedCode = "123456"
1001+ ts .ExpectAPIRequest (http .StatusOK , func (r any ) []byte {
1002+ return jsonMarshal (message.EndpointAuthPollResponse {Data : message.EndpointAuthPollData {
1003+ Status : message .EndpointAuthStarted ,
1004+ EnrollmentCode : "" ,
1005+ }})
1006+ })
1007+
1008+ ctx , cancel := context .WithTimeout (context .Background (), 1 * time .Second )
1009+ defer cancel ()
1010+ resp , err := client .EndpointAuthPoll (ctx , expectedCode )
1011+ require .NoError (t , err )
1012+ assert .Equal (t , resp .Status , message .EndpointAuthStarted )
1013+ assert .Equal (t , resp .EnrollmentCode , "" )
1014+ assert .Empty (t , ts .Errors ())
1015+ assert .Equal (t , 0 , ts .RequestsRemaining ())
1016+
1017+ //unhappy path
1018+ ts .ExpectAPIRequest (http .StatusBadRequest , func (r any ) []byte {
1019+ return nil
1020+ })
1021+ resp , err = client .EndpointAuthPoll (ctx , "" ) //blank code should error!
1022+ require .Error (t , err )
1023+ assert .Nil (t , resp )
1024+ assert .Empty (t , ts .Errors ())
1025+ assert .Equal (t , 0 , ts .RequestsRemaining ())
1026+
1027+ //complete path
1028+ ts .ExpectAPIRequest (http .StatusOK , func (r any ) []byte {
1029+ return jsonMarshal (message.EndpointAuthPollResponse {Data : message.EndpointAuthPollData {
1030+ Status : message .EndpointAuthCompleted ,
1031+ EnrollmentCode : "deadbeef" ,
1032+ }})
1033+ })
1034+ resp , err = client .EndpointAuthPoll (ctx , expectedCode )
1035+ require .NoError (t , err )
1036+ assert .Equal (t , resp .Status , message .EndpointAuthCompleted )
1037+ assert .Equal (t , resp .EnrollmentCode , "deadbeef" )
1038+ assert .Empty (t , ts .Errors ())
1039+ assert .Equal (t , 0 , ts .RequestsRemaining ())
1040+ }
0 commit comments