Skip to content

Commit b93d4f7

Browse files
JackDoanjohnmaguirejasikpark
authored
methods for the OIDC auth flow (#25)
* stuff for the OIDC auth flow * add tests * add statuses * Apply suggestions from code review Co-authored-by: John Maguire <[email protected]> * enduser -> endpoint * Use APIError to associate request id * Remove extraneous log * LoginUrl -> LoginURL * Match client methods to endpoint naming pattern * Construct a URL directly for the pollURL * bump ancient crypto dep * fix test * nit * no logger arg, that was silly * nits * uhp * uhp2 * uhp2 * feedback * clean up tests a bit --------- Co-authored-by: John Maguire <[email protected]> Co-authored-by: Caleb Jasik <[email protected]>
1 parent caa5a20 commit b93d4f7

File tree

6 files changed

+314
-67
lines changed

6 files changed

+314
-67
lines changed

client.go

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -581,3 +581,91 @@ func nonce() []byte {
581581
}
582582
return nonce
583583
}
584+
585+
func (c *Client) EndpointPreAuth(ctx context.Context) (*message.PreAuthData, error) {
586+
dest, err := url.JoinPath(c.dnServer, message.PreAuthEndpoint)
587+
if err != nil {
588+
return nil, err
589+
}
590+
591+
req, err := http.NewRequestWithContext(ctx, "POST", dest, nil)
592+
if err != nil {
593+
return nil, err
594+
}
595+
596+
resp, err := c.client.Do(req)
597+
if err != nil {
598+
return nil, err
599+
}
600+
defer resp.Body.Close()
601+
602+
reqID := resp.Header.Get("X-Request-ID")
603+
respBody, err := io.ReadAll(resp.Body)
604+
if err != nil {
605+
return nil, &APIError{e: fmt.Errorf("failed to read the response body: %s", err), ReqID: reqID}
606+
}
607+
608+
switch resp.StatusCode {
609+
case http.StatusOK:
610+
r := message.PreAuthResponse{}
611+
if err = json.Unmarshal(respBody, &r); err != nil {
612+
return nil, &APIError{e: fmt.Errorf("error decoding JSON response: %s\nbody: %s", err, respBody), ReqID: reqID}
613+
}
614+
615+
if r.Data.PollToken == "" || r.Data.LoginURL == "" {
616+
return nil, &APIError{e: fmt.Errorf("missing pollToken or loginURL"), ReqID: reqID}
617+
}
618+
619+
return &r.Data, nil
620+
default:
621+
var errors struct {
622+
Errors message.APIErrors
623+
}
624+
if err := json.Unmarshal(respBody, &errors); err != nil {
625+
return nil, fmt.Errorf("bad status code '%d', body: %s", resp.StatusCode, respBody)
626+
}
627+
return nil, &APIError{e: errors.Errors.ToError(), ReqID: reqID}
628+
}
629+
}
630+
631+
func (c *Client) EndpointAuthPoll(ctx context.Context, pollCode string) (*message.EndpointAuthPollData, error) {
632+
pollURL, err := url.JoinPath(c.dnServer, message.EndpointAuthPoll)
633+
if err != nil {
634+
return nil, err
635+
}
636+
pollURL = fmt.Sprintf("%s?pollToken=%s", pollURL, url.QueryEscape(pollCode))
637+
638+
req, err := http.NewRequestWithContext(ctx, "GET", pollURL, nil)
639+
if err != nil {
640+
return nil, err
641+
}
642+
643+
resp, err := c.client.Do(req)
644+
if err != nil {
645+
return nil, err
646+
}
647+
defer resp.Body.Close()
648+
649+
reqID := resp.Header.Get("X-Request-ID")
650+
respBody, err := io.ReadAll(resp.Body)
651+
if err != nil {
652+
return nil, &APIError{e: fmt.Errorf("failed to read the response body: %s", err), ReqID: reqID}
653+
}
654+
655+
switch resp.StatusCode {
656+
case http.StatusOK:
657+
r := message.EndpointAuthPollResponse{}
658+
if err = json.Unmarshal(respBody, &r); err != nil {
659+
return nil, &APIError{e: fmt.Errorf("error decoding JSON response: %s\nbody: %s", err, respBody), ReqID: reqID}
660+
}
661+
return &r.Data, nil
662+
default:
663+
var errors struct {
664+
Errors message.APIErrors
665+
}
666+
if err := json.Unmarshal(respBody, &errors); err != nil {
667+
return nil, fmt.Errorf("bad status code '%d', body: %s", resp.StatusCode, respBody)
668+
}
669+
return nil, &APIError{e: errors.Errors.ToError(), ReqID: reqID}
670+
}
671+
}

client_test.go

Lines changed: 94 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ func TestDoUpdate(t *testing.T) {
241241
assert.NotEmpty(t, pkey)
242242

243243
// Invalid request signature should return a specific error
244-
ts.ExpectRequest(message.CheckForUpdate, http.StatusOK, func(r message.RequestWrapper) []byte {
244+
ts.ExpectDNClientRequest(message.CheckForUpdate, http.StatusOK, func(r message.RequestWrapper) []byte {
245245
return []byte("")
246246
})
247247

@@ -265,7 +265,7 @@ func TestDoUpdate(t *testing.T) {
265265
require.Len(t, serverErrs, 1)
266266

267267
// Invalid signature
268-
ts.ExpectRequest(message.DoUpdate, http.StatusOK, func(r message.RequestWrapper) []byte {
268+
ts.ExpectDNClientRequest(message.DoUpdate, http.StatusOK, func(r message.RequestWrapper) []byte {
269269
newConfigResponse := message.DoUpdateResponse{
270270
Config: dnapitest.NebulaCfg(caPEM),
271271
Counter: 2,
@@ -320,7 +320,7 @@ func TestDoUpdate(t *testing.T) {
320320
require.Nil(t, pkey)
321321

322322
// Invalid counter
323-
ts.ExpectRequest(message.DoUpdate, http.StatusOK, func(r message.RequestWrapper) []byte {
323+
ts.ExpectDNClientRequest(message.DoUpdate, http.StatusOK, func(r message.RequestWrapper) []byte {
324324
newConfigResponse := message.DoUpdateResponse{
325325
Config: dnapitest.NebulaCfg(caPEM),
326326
Counter: 0,
@@ -379,7 +379,7 @@ func TestDoUpdate(t *testing.T) {
379379
hostIP := "192.168.100.1"
380380

381381
// This time sign the response with the correct CA key.
382-
ts.ExpectRequest(message.DoUpdate, http.StatusOK, func(r message.RequestWrapper) []byte {
382+
ts.ExpectDNClientRequest(message.DoUpdate, http.StatusOK, func(r message.RequestWrapper) []byte {
383383
newConfigResponse := message.DoUpdateResponse{
384384
Config: dnapitest.NebulaCfg(caPEM),
385385
Counter: 3,
@@ -505,7 +505,7 @@ func TestDoUpdate_P256(t *testing.T) {
505505
assert.NotEmpty(t, pkey)
506506

507507
// Invalid request signature should return a specific error
508-
ts.ExpectRequest(message.CheckForUpdate, http.StatusOK, func(r message.RequestWrapper) []byte {
508+
ts.ExpectDNClientRequest(message.CheckForUpdate, http.StatusOK, func(r message.RequestWrapper) []byte {
509509
return []byte("")
510510
})
511511

@@ -528,7 +528,7 @@ func TestDoUpdate_P256(t *testing.T) {
528528
require.Len(t, serverErrs, 1)
529529

530530
// Invalid signature
531-
ts.ExpectRequest(message.DoUpdate, http.StatusOK, func(r message.RequestWrapper) []byte {
531+
ts.ExpectDNClientRequest(message.DoUpdate, http.StatusOK, func(r message.RequestWrapper) []byte {
532532
newConfigResponse := message.DoUpdateResponse{
533533
Config: dnapitest.NebulaCfg(caPEM),
534534
Counter: 2,
@@ -574,7 +574,7 @@ func TestDoUpdate_P256(t *testing.T) {
574574
require.Nil(t, pkey)
575575

576576
// Invalid counter
577-
ts.ExpectRequest(message.DoUpdate, http.StatusOK, func(r message.RequestWrapper) []byte {
577+
ts.ExpectDNClientRequest(message.DoUpdate, http.StatusOK, func(r message.RequestWrapper) []byte {
578578
newConfigResponse := message.DoUpdateResponse{
579579
Config: dnapitest.NebulaCfg(caPEM),
580580
Counter: 0,
@@ -618,7 +618,7 @@ func TestDoUpdate_P256(t *testing.T) {
618618
require.Nil(t, pkey)
619619

620620
// This time sign the response with the correct CA key.
621-
ts.ExpectRequest(message.DoUpdate, http.StatusOK, func(r message.RequestWrapper) []byte {
621+
ts.ExpectDNClientRequest(message.DoUpdate, http.StatusOK, func(r message.RequestWrapper) []byte {
622622
newConfigResponse := message.DoUpdateResponse{
623623
Config: dnapitest.NebulaCfg(caPEM),
624624
Counter: 3,
@@ -743,7 +743,7 @@ func TestCommandResponse(t *testing.T) {
743743
// This time sign the response with the correct CA key.
744744
responseToken := "abc123"
745745
res := map[string]any{"msg": "Hello, world!"}
746-
ts.ExpectRequest(message.CommandResponse, http.StatusOK, func(r message.RequestWrapper) []byte {
746+
ts.ExpectDNClientRequest(message.CommandResponse, http.StatusOK, func(r message.RequestWrapper) []byte {
747747
var val map[string]any
748748
err := json.Unmarshal(r.Value, &val)
749749
require.NoError(t, err)
@@ -759,7 +759,7 @@ func TestCommandResponse(t *testing.T) {
759759

760760
// Test error handling
761761
errorMsg := "sample error"
762-
ts.ExpectRequest(message.CommandResponse, http.StatusBadRequest, func(r message.RequestWrapper) []byte {
762+
ts.ExpectDNClientRequest(message.CommandResponse, http.StatusBadRequest, func(r message.RequestWrapper) []byte {
763763
return jsonMarshal(message.EnrollResponse{
764764
Errors: message.APIErrors{{
765765
Code: "ERR_INVALID_VALUE",
@@ -954,3 +954,87 @@ func marshalCAPublicKey(curve cert.Curve, pubkey []byte) []byte {
954954
panic("unsupported curve")
955955
}
956956
}
957+
958+
func TestGetOidcPollCode(t *testing.T) {
959+
t.Parallel()
960+
961+
useragent := "dnclientUnitTests/1.0.0 (not a real client)"
962+
ts := dnapitest.NewServer(useragent)
963+
client := NewClient(useragent, ts.URL)
964+
// attempting to defer ts.Close() will trigger early due to parallel testing - use T.Cleanup instead
965+
t.Cleanup(func() { ts.Close() })
966+
const expectedCode = "123456"
967+
ts.ExpectAPIRequest(http.StatusOK, func(req any) []byte {
968+
return jsonMarshal(message.PreAuthResponse{Data: message.PreAuthData{PollToken: expectedCode, LoginURL: "https://example.com"}})
969+
})
970+
971+
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
972+
defer cancel()
973+
resp, err := client.EndpointPreAuth(ctx)
974+
require.NoError(t, err)
975+
assert.NotNil(t, resp)
976+
assert.Equal(t, expectedCode, resp.PollToken)
977+
assert.Equal(t, "https://example.com", resp.LoginURL)
978+
assert.Empty(t, ts.Errors())
979+
assert.Equal(t, 0, ts.RequestsRemaining())
980+
981+
//unhappy path
982+
ts.ExpectAPIRequest(http.StatusBadGateway, func(req any) []byte {
983+
return jsonMarshal(message.PreAuthResponse{Data: message.PreAuthData{PollToken: expectedCode, LoginURL: "https://example.com"}})
984+
})
985+
resp, err = client.EndpointPreAuth(ctx)
986+
require.Error(t, err)
987+
require.Nil(t, resp)
988+
assert.Empty(t, ts.Errors())
989+
assert.Equal(t, 0, ts.RequestsRemaining())
990+
}
991+
992+
func TestDoOidcPoll(t *testing.T) {
993+
t.Parallel()
994+
995+
useragent := "dnclientUnitTests/1.0.0 (not a real client)"
996+
ts := dnapitest.NewServer(useragent)
997+
client := NewClient(useragent, ts.URL)
998+
// attempting to defer ts.Close() will trigger early due to parallel testing - use T.Cleanup instead
999+
t.Cleanup(func() { ts.Close() })
1000+
const expectedCode = "123456"
1001+
ts.ExpectAPIRequest(http.StatusOK, func(r any) []byte {
1002+
return jsonMarshal(message.EndpointAuthPollResponse{Data: message.EndpointAuthPollData{
1003+
Status: message.EndpointAuthStarted,
1004+
EnrollmentCode: "",
1005+
}})
1006+
})
1007+
1008+
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
1009+
defer cancel()
1010+
resp, err := client.EndpointAuthPoll(ctx, expectedCode)
1011+
require.NoError(t, err)
1012+
assert.Equal(t, resp.Status, message.EndpointAuthStarted)
1013+
assert.Equal(t, resp.EnrollmentCode, "")
1014+
assert.Empty(t, ts.Errors())
1015+
assert.Equal(t, 0, ts.RequestsRemaining())
1016+
1017+
//unhappy path
1018+
ts.ExpectAPIRequest(http.StatusBadRequest, func(r any) []byte {
1019+
return nil
1020+
})
1021+
resp, err = client.EndpointAuthPoll(ctx, "") //blank code should error!
1022+
require.Error(t, err)
1023+
assert.Nil(t, resp)
1024+
assert.Empty(t, ts.Errors())
1025+
assert.Equal(t, 0, ts.RequestsRemaining())
1026+
1027+
//complete path
1028+
ts.ExpectAPIRequest(http.StatusOK, func(r any) []byte {
1029+
return jsonMarshal(message.EndpointAuthPollResponse{Data: message.EndpointAuthPollData{
1030+
Status: message.EndpointAuthCompleted,
1031+
EnrollmentCode: "deadbeef",
1032+
}})
1033+
})
1034+
resp, err = client.EndpointAuthPoll(ctx, expectedCode)
1035+
require.NoError(t, err)
1036+
assert.Equal(t, resp.Status, message.EndpointAuthCompleted)
1037+
assert.Equal(t, resp.EnrollmentCode, "deadbeef")
1038+
assert.Empty(t, ts.Errors())
1039+
assert.Equal(t, 0, ts.RequestsRemaining())
1040+
}

0 commit comments

Comments
 (0)