Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -581,3 +581,91 @@ func nonce() []byte {
}
return nonce
}

func (c *Client) EndpointPreAuth(ctx context.Context) (*message.PreAuthData, error) {
dest, err := url.JoinPath(c.dnServer, message.PreAuthEndpoint)
if err != nil {
return nil, err
}

req, err := http.NewRequestWithContext(ctx, "POST", dest, nil)
if err != nil {
return nil, err
}

resp, err := c.client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()

reqID := resp.Header.Get("X-Request-ID")
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, &APIError{e: fmt.Errorf("failed to read the response body: %s", err), ReqID: reqID}
}

switch resp.StatusCode {
case http.StatusOK:
r := message.PreAuthResponse{}
if err = json.Unmarshal(respBody, &r); err != nil {
return nil, &APIError{e: fmt.Errorf("error decoding JSON response: %s\nbody: %s", err, respBody), ReqID: reqID}
}

if r.Data.PollToken == "" || r.Data.LoginURL == "" {
return nil, &APIError{e: fmt.Errorf("missing pollToken or loginURL"), ReqID: reqID}
}

return &r.Data, nil
default:
var errors struct {
Errors message.APIErrors
}
if err := json.Unmarshal(respBody, &errors); err != nil {
return nil, fmt.Errorf("bad status code '%d', body: %s", resp.StatusCode, respBody)
}
return nil, &APIError{e: errors.Errors.ToError(), ReqID: reqID}
}
}

func (c *Client) EndpointAuthPoll(ctx context.Context, pollCode string) (*message.EndpointAuthPollData, error) {
pollURL, err := url.JoinPath(c.dnServer, message.EndpointAuthPoll)
if err != nil {
return nil, err
}
pollURL = fmt.Sprintf("%s?pollToken=%s", pollURL, url.QueryEscape(pollCode))

req, err := http.NewRequestWithContext(ctx, "GET", pollURL, nil)
if err != nil {
return nil, err
}

resp, err := c.client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()

reqID := resp.Header.Get("X-Request-ID")
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, &APIError{e: fmt.Errorf("failed to read the response body: %s", err), ReqID: reqID}
}

switch resp.StatusCode {
case http.StatusOK:
r := message.EndpointAuthPollResponse{}
if err = json.Unmarshal(respBody, &r); err != nil {
return nil, &APIError{e: fmt.Errorf("error decoding JSON response: %s\nbody: %s", err, respBody), ReqID: reqID}
}
return &r.Data, nil
default:
var errors struct {
Errors message.APIErrors
}
if err := json.Unmarshal(respBody, &errors); err != nil {
return nil, fmt.Errorf("bad status code '%d', body: %s", resp.StatusCode, respBody)
}
return nil, &APIError{e: errors.Errors.ToError(), ReqID: reqID}
}
}
104 changes: 94 additions & 10 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ func TestDoUpdate(t *testing.T) {
assert.NotEmpty(t, pkey)

// Invalid request signature should return a specific error
ts.ExpectRequest(message.CheckForUpdate, http.StatusOK, func(r message.RequestWrapper) []byte {
ts.ExpectDNClientRequest(message.CheckForUpdate, http.StatusOK, func(r message.RequestWrapper) []byte {
return []byte("")
})

Expand All @@ -265,7 +265,7 @@ func TestDoUpdate(t *testing.T) {
require.Len(t, serverErrs, 1)

// Invalid signature
ts.ExpectRequest(message.DoUpdate, http.StatusOK, func(r message.RequestWrapper) []byte {
ts.ExpectDNClientRequest(message.DoUpdate, http.StatusOK, func(r message.RequestWrapper) []byte {
newConfigResponse := message.DoUpdateResponse{
Config: dnapitest.NebulaCfg(caPEM),
Counter: 2,
Expand Down Expand Up @@ -320,7 +320,7 @@ func TestDoUpdate(t *testing.T) {
require.Nil(t, pkey)

// Invalid counter
ts.ExpectRequest(message.DoUpdate, http.StatusOK, func(r message.RequestWrapper) []byte {
ts.ExpectDNClientRequest(message.DoUpdate, http.StatusOK, func(r message.RequestWrapper) []byte {
newConfigResponse := message.DoUpdateResponse{
Config: dnapitest.NebulaCfg(caPEM),
Counter: 0,
Expand Down Expand Up @@ -379,7 +379,7 @@ func TestDoUpdate(t *testing.T) {
hostIP := "192.168.100.1"

// This time sign the response with the correct CA key.
ts.ExpectRequest(message.DoUpdate, http.StatusOK, func(r message.RequestWrapper) []byte {
ts.ExpectDNClientRequest(message.DoUpdate, http.StatusOK, func(r message.RequestWrapper) []byte {
newConfigResponse := message.DoUpdateResponse{
Config: dnapitest.NebulaCfg(caPEM),
Counter: 3,
Expand Down Expand Up @@ -505,7 +505,7 @@ func TestDoUpdate_P256(t *testing.T) {
assert.NotEmpty(t, pkey)

// Invalid request signature should return a specific error
ts.ExpectRequest(message.CheckForUpdate, http.StatusOK, func(r message.RequestWrapper) []byte {
ts.ExpectDNClientRequest(message.CheckForUpdate, http.StatusOK, func(r message.RequestWrapper) []byte {
return []byte("")
})

Expand All @@ -528,7 +528,7 @@ func TestDoUpdate_P256(t *testing.T) {
require.Len(t, serverErrs, 1)

// Invalid signature
ts.ExpectRequest(message.DoUpdate, http.StatusOK, func(r message.RequestWrapper) []byte {
ts.ExpectDNClientRequest(message.DoUpdate, http.StatusOK, func(r message.RequestWrapper) []byte {
newConfigResponse := message.DoUpdateResponse{
Config: dnapitest.NebulaCfg(caPEM),
Counter: 2,
Expand Down Expand Up @@ -574,7 +574,7 @@ func TestDoUpdate_P256(t *testing.T) {
require.Nil(t, pkey)

// Invalid counter
ts.ExpectRequest(message.DoUpdate, http.StatusOK, func(r message.RequestWrapper) []byte {
ts.ExpectDNClientRequest(message.DoUpdate, http.StatusOK, func(r message.RequestWrapper) []byte {
newConfigResponse := message.DoUpdateResponse{
Config: dnapitest.NebulaCfg(caPEM),
Counter: 0,
Expand Down Expand Up @@ -618,7 +618,7 @@ func TestDoUpdate_P256(t *testing.T) {
require.Nil(t, pkey)

// This time sign the response with the correct CA key.
ts.ExpectRequest(message.DoUpdate, http.StatusOK, func(r message.RequestWrapper) []byte {
ts.ExpectDNClientRequest(message.DoUpdate, http.StatusOK, func(r message.RequestWrapper) []byte {
newConfigResponse := message.DoUpdateResponse{
Config: dnapitest.NebulaCfg(caPEM),
Counter: 3,
Expand Down Expand Up @@ -743,7 +743,7 @@ func TestCommandResponse(t *testing.T) {
// This time sign the response with the correct CA key.
responseToken := "abc123"
res := map[string]any{"msg": "Hello, world!"}
ts.ExpectRequest(message.CommandResponse, http.StatusOK, func(r message.RequestWrapper) []byte {
ts.ExpectDNClientRequest(message.CommandResponse, http.StatusOK, func(r message.RequestWrapper) []byte {
var val map[string]any
err := json.Unmarshal(r.Value, &val)
require.NoError(t, err)
Expand All @@ -759,7 +759,7 @@ func TestCommandResponse(t *testing.T) {

// Test error handling
errorMsg := "sample error"
ts.ExpectRequest(message.CommandResponse, http.StatusBadRequest, func(r message.RequestWrapper) []byte {
ts.ExpectDNClientRequest(message.CommandResponse, http.StatusBadRequest, func(r message.RequestWrapper) []byte {
return jsonMarshal(message.EnrollResponse{
Errors: message.APIErrors{{
Code: "ERR_INVALID_VALUE",
Expand Down Expand Up @@ -954,3 +954,87 @@ func marshalCAPublicKey(curve cert.Curve, pubkey []byte) []byte {
panic("unsupported curve")
}
}

func TestGetOidcPollCode(t *testing.T) {
t.Parallel()

useragent := "dnclientUnitTests/1.0.0 (not a real client)"
ts := dnapitest.NewServer(useragent)
client := NewClient(useragent, ts.URL)
// attempting to defer ts.Close() will trigger early due to parallel testing - use T.Cleanup instead
t.Cleanup(func() { ts.Close() })
const expectedCode = "123456"
ts.ExpectAPIRequest(http.StatusOK, func(req any) []byte {
return jsonMarshal(message.PreAuthResponse{Data: message.PreAuthData{PollToken: expectedCode, LoginURL: "https://example.com"}})
})

ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
resp, err := client.EndpointPreAuth(ctx)
require.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, expectedCode, resp.PollToken)
assert.Equal(t, "https://example.com", resp.LoginURL)
assert.Empty(t, ts.Errors())
assert.Equal(t, 0, ts.RequestsRemaining())

//unhappy path
ts.ExpectAPIRequest(http.StatusBadGateway, func(req any) []byte {
return jsonMarshal(message.PreAuthResponse{Data: message.PreAuthData{PollToken: expectedCode, LoginURL: "https://example.com"}})
})
resp, err = client.EndpointPreAuth(ctx)
require.Error(t, err)
require.Nil(t, resp)
assert.Empty(t, ts.Errors())
assert.Equal(t, 0, ts.RequestsRemaining())
}

func TestDoOidcPoll(t *testing.T) {
t.Parallel()

useragent := "dnclientUnitTests/1.0.0 (not a real client)"
ts := dnapitest.NewServer(useragent)
client := NewClient(useragent, ts.URL)
// attempting to defer ts.Close() will trigger early due to parallel testing - use T.Cleanup instead
t.Cleanup(func() { ts.Close() })
const expectedCode = "123456"
ts.ExpectAPIRequest(http.StatusOK, func(r any) []byte {
return jsonMarshal(message.EndpointAuthPollResponse{Data: message.EndpointAuthPollData{
Status: message.EndpointAuthStarted,
EnrollmentCode: "",
}})
})

ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
resp, err := client.EndpointAuthPoll(ctx, expectedCode)
require.NoError(t, err)
assert.Equal(t, resp.Status, message.EndpointAuthStarted)
assert.Equal(t, resp.EnrollmentCode, "")
assert.Empty(t, ts.Errors())
assert.Equal(t, 0, ts.RequestsRemaining())

//unhappy path
ts.ExpectAPIRequest(http.StatusBadRequest, func(r any) []byte {
return nil
})
resp, err = client.EndpointAuthPoll(ctx, "") //blank code should error!
require.Error(t, err)
assert.Nil(t, resp)
assert.Empty(t, ts.Errors())
assert.Equal(t, 0, ts.RequestsRemaining())

//complete path
ts.ExpectAPIRequest(http.StatusOK, func(r any) []byte {
return jsonMarshal(message.EndpointAuthPollResponse{Data: message.EndpointAuthPollData{
Status: message.EndpointAuthCompleted,
EnrollmentCode: "deadbeef",
}})
})
resp, err = client.EndpointAuthPoll(ctx, expectedCode)
require.NoError(t, err)
assert.Equal(t, resp.Status, message.EndpointAuthCompleted)
assert.Equal(t, resp.EnrollmentCode, "deadbeef")
assert.Empty(t, ts.Errors())
assert.Equal(t, 0, ts.RequestsRemaining())
}
Loading
Loading