Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 100 additions & 19 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,28 @@ func (c *Client) streamingPostDNClient(ctx context.Context, reqType string, valu
return sc, nil
}

func (c *Client) handleBody(resp *http.Response) ([]byte, error) {
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read the response body: %s", err)
}

switch resp.StatusCode {
case http.StatusOK:
return respBody, nil
case http.StatusUnauthorized:
return nil, ErrInvalidCredentials
default:
var errors struct {
Errors message.APIErrors
}
if err := json.Unmarshal(respBody, &errors); err != nil {
return nil, fmt.Errorf("dnclient endpoint returned bad status code '%d', body: %s", resp.StatusCode, respBody)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking closer, and I think this isn't quite right. This code comes from postDNClient which is used to contact the /v1/dnclient endpoint (hence the "dnclient endpoint" in the error.) This endpoint differs from our primary API:

  • It uses a different authentication method, for dnclient
  • It doesn't use our standard response format (Data vs. Errors)
  • It has a couple envelopes, which are base64-encoded and signed

Take a look at the Enroll call for how we want to handle this - primarily, we attempt to deconstruct the Errors array.

I think that ultimately we want to create a callAPI function, similar to postDNClient, using Enroll as a basis for how to process the message , that is used for enroll, poll, and preauth.

Probably we can remove Errors from the message structs for these calls if we process them in one place like this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fair enough. I really would like to land this, so I'm choosing to punt on DRY-ing handling bodies for now if that's cool with you

}
return nil, errors.Errors.ToError()
}
}

// postDNClient wraps and signs the given dnclientRequestWrapper message, and makes the API call.
// On success, it returns the response message body. On error, the error is returned.
func (c *Client) postDNClient(ctx context.Context, reqType string, value []byte, hostID string, counter uint, privkey keys.PrivateKey) ([]byte, error) {
Expand All @@ -489,25 +511,7 @@ func (c *Client) postDNClient(ctx context.Context, reqType string, value []byte,
}
defer resp.Body.Close()

respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read the response body: %s", err)
}

switch resp.StatusCode {
case http.StatusOK:
return respBody, nil
case http.StatusUnauthorized:
return nil, ErrInvalidCredentials
default:
var errors struct {
Errors message.APIErrors
}
if err := json.Unmarshal(respBody, &errors); err != nil {
return nil, fmt.Errorf("dnclient endpoint returned bad status code '%d', body: %s", resp.StatusCode, respBody)
}
return nil, errors.Errors.ToError()
}
return c.handleBody(resp)
}

// StreamController is used for interacting with streaming requests to the API.
Expand Down Expand Up @@ -581,3 +585,80 @@ func nonce() []byte {
}
return nonce
}

func (c *Client) GetOidcPollCode(ctx context.Context, logger logrus.FieldLogger) (string, error) {
logger.WithFields(logrus.Fields{"server": c.dnServer}).Debug("Making GetOidcPollCode request to API")

enrollURL, err := url.JoinPath(c.dnServer, message.PreAuthEndpoint)
if err != nil {
return "", err
}

req, err := http.NewRequestWithContext(ctx, "POST", enrollURL, nil)
if err != nil {
return "", err
}

resp, err := c.client.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()

// Log the request ID returned from the server
reqID := resp.Header.Get("X-Request-ID")
l := logger.WithFields(logrus.Fields{"statusCode": resp.StatusCode, "reqID": reqID})
b, err := c.handleBody(resp)
if err != nil {
l.Error(err) //todo I don't like erroring and also logging?
return "", err
}

// Decode the response
r := message.PreAuthResponse{}
if err = json.Unmarshal(b, &r); err != nil {
return "", &APIError{e: fmt.Errorf("error decoding JSON response: %s\nbody: %s", err, b), ReqID: reqID}
}

return r.PollToken, nil
}

func (c *Client) DoOidcPoll(ctx context.Context, logger logrus.FieldLogger, pollCode string) (*message.EnduserAuthPollResponse, error) {
logger.WithFields(logrus.Fields{"server": c.dnServer}).Debug("Making DoOidcPoll request to API")

enrollURL, err := url.JoinPath(c.dnServer, message.EnduserAuthPoll)
if err != nil {
return nil, err
}

req, err := http.NewRequestWithContext(ctx, "GET", enrollURL, nil)
if err != nil {
return nil, err
}
q := req.URL.Query()
q.Add("token", pollCode)
req.URL.RawQuery = q.Encode()

resp, err := c.client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()

// Log the request ID returned from the server
reqID := resp.Header.Get("X-Request-ID")
l := logger.WithFields(logrus.Fields{"statusCode": resp.StatusCode, "reqID": reqID})
b, err := c.handleBody(resp)
if err != nil {
l.Error(err) //todo I don't like erroring and also logging?
return nil, err
}

// Decode the response
r := message.EnduserAuthPollResponse{}
if err = json.Unmarshal(b, &r); err != nil {
return nil, &APIError{e: fmt.Errorf("error decoding JSON response: %s\nbody: %s", err, b), ReqID: reqID}
}

return &r, nil
}
70 changes: 70 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -954,3 +954,73 @@ func marshalCAPublicKey(curve cert.Curve, pubkey []byte) []byte {
panic("unsupported curve")
}
}

func TestGetOidcPollCode(t *testing.T) {
t.Parallel()

useragent := "dnclientUnitTests/1.0.0 (not a real client)"
ts := dnapitest.NewServer(useragent)
client := NewClient(useragent, ts.URL)
// attempting to defer ts.Close() will trigger early due to parallel testing - use T.Cleanup instead
t.Cleanup(func() { ts.Close() })
const expectedCode = "123456"
ts.ExpectRequest(message.PreAuthEndpoint, http.StatusOK, func(req message.RequestWrapper) []byte {
return jsonMarshal(message.PreAuthResponse{PollToken: expectedCode})
})

ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
code, err := client.GetOidcPollCode(ctx, testutil.NewTestLogger())
require.NoError(t, err)
assert.Equal(t, expectedCode, code)
assert.Empty(t, ts.Errors())
assert.Equal(t, 0, ts.RequestsRemaining())

//unhappy path
ts.ExpectRequest(message.PreAuthEndpoint, http.StatusBadGateway, func(req message.RequestWrapper) []byte {
return jsonMarshal(message.PreAuthResponse{PollToken: expectedCode})
})
code, err = client.GetOidcPollCode(ctx, testutil.NewTestLogger())
require.Error(t, err)
assert.Equal(t, "", code)
assert.Empty(t, ts.Errors())
assert.Equal(t, 0, ts.RequestsRemaining())
}

func TestDoOidcPoll(t *testing.T) {
t.Parallel()

useragent := "dnclientUnitTests/1.0.0 (not a real client)"
ts := dnapitest.NewServer(useragent)
client := NewClient(useragent, ts.URL)
// attempting to defer ts.Close() will trigger early due to parallel testing - use T.Cleanup instead
t.Cleanup(func() { ts.Close() })
const expectedCode = "123456"
ts.ExpectRequest(message.EnduserAuthPoll, http.StatusOK, func(req message.RequestWrapper) []byte {
return jsonMarshal(message.EnduserAuthPollResponse{
Status: "something",
LoginUrl: "https://login.example.com",
EnrollmentCode: "",
})
})

ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
resp, err := client.DoOidcPoll(ctx, testutil.NewTestLogger(), expectedCode)
require.NoError(t, err)
assert.Equal(t, resp.Status, "something")
assert.Equal(t, resp.LoginUrl, "https://login.example.com")
assert.Equal(t, resp.EnrollmentCode, "")
assert.Empty(t, ts.Errors())
assert.Equal(t, 0, ts.RequestsRemaining())

//unhappy path
ts.ExpectRequest(message.EnduserAuthPoll, http.StatusBadRequest, func(req message.RequestWrapper) []byte {
return nil
})
resp, err = client.DoOidcPoll(ctx, testutil.NewTestLogger(), "") //blank code should error!
require.Error(t, err)
assert.Nil(t, resp)
assert.Empty(t, ts.Errors())
assert.Equal(t, 0, ts.RequestsRemaining())
}
31 changes: 31 additions & 0 deletions dnapitest/dnapitest.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,14 @@ func (s *Server) handler(w http.ResponseWriter, r *http.Request) {
s.handlerEnroll(w, r)
case message.EndpointV1:
s.handlerDNClient(w, r)
case message.PreAuthEndpoint:
expected := s.expectedRequests[0]
s.expectedRequests = s.expectedRequests[1:]
res := expected.dncRequestResponse
w.WriteHeader(res.statusCode)
_, _ = w.Write(res.response(message.RequestWrapper{}))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This does not look correct to me. This is for a /v1/dnclient endpoint message.

Doesn't preauth need to return a pollToken in order for the poll endpoint to behave correctly? I think we should add a test that the handler receives the poll token sent by the client.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It does! The test injects the "desired response" from the server, and confirms the client gets it. There's no "real handler" because PreAuth doesn't take arguments from the client.

case message.EnduserAuthPoll:
s.handlerDoOidcPoll(w, r)
default:
s.errors = append(s.errors, fmt.Errorf("invalid request path %s", r.URL.Path))
http.NotFound(w, r)
Expand Down Expand Up @@ -152,6 +160,29 @@ func (s *Server) SetP256Pubkey(p256PubkeyPEM []byte) error {
return nil
}

func (s *Server) handlerDoOidcPoll(w http.ResponseWriter, r *http.Request) {
// Get the test case to validate
expected := s.expectedRequests[0]
s.expectedRequests = s.expectedRequests[1:]
if !expected.dnclientAPI {
s.errors = append(s.errors, fmt.Errorf("unexpected dnclient API request - expected enrollment request"))
http.Error(w, "unexpected dnclient API request", http.StatusInternalServerError)
return
}
res := expected.dncRequestResponse

token := r.URL.Query()["token"]
if len(token) == 0 {
s.errors = append(s.errors, fmt.Errorf("missing token"))
http.Error(w, "missing token", http.StatusBadRequest)
return
}

// return the associated response
w.WriteHeader(res.statusCode)
w.Write(res.response(message.RequestWrapper{}))
}

func (s *Server) handlerDNClient(w http.ResponseWriter, r *http.Request) {
// Get the test case to validate
expected := s.expectedRequests[0]
Expand Down
18 changes: 18 additions & 0 deletions message/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,3 +218,21 @@ func (nc *NetworkCurve) UnmarshalJSON(b []byte) error {

return nil
}

const PreAuthEndpoint = "/v1/enduser-auth/preauth"

type PreAuthResponse struct {
PollToken string `json:"pollToken"`
}

const EnduserAuthPoll = "/v1/enduser-auth/poll"

const EnduserAuthPollStatusNotStarted = "NOT_STARTED"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we not need a failure state here? Or would the connection simply close in that case?

const EnduserAuthPollStatusStarted = "OIDC_STARTED"
const EnduserAuthPollStatusSuccess = "SUCCESS"

type EnduserAuthPollResponse struct {
Status string `json:"status"`
LoginUrl string `json:"loginUrl"`
EnrollmentCode string `json:"enrollmentCode"`
}
Loading