Skip to content

Commit 6c8891d

Browse files
Merge remote-tracking branch 'origin/main' into add-vitess
2 parents 24f7f72 + 8a445fa commit 6c8891d

File tree

8 files changed

+407
-365
lines changed

8 files changed

+407
-365
lines changed

cmd/bad-key-revoker/main_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ func TestFindUnrevoked(t *testing.T) {
247247
test.AssertNotError(t, err, "findUnrevoked failed")
248248
test.AssertEquals(t, len(rows), 1)
249249
test.AssertEquals(t, rows[0].Serial, "ff")
250-
test.AssertEquals(t, rows[0].RegistrationID, int64(1))
250+
test.AssertEquals(t, rows[0].RegistrationID, regID)
251251
test.AssertByteEquals(t, rows[0].DER, []byte{1, 2, 3})
252252

253253
bkr.maxRevocations = 0

ctpolicy/ctpolicy.go

Lines changed: 68 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,12 @@ func New(pub pubpb.PublisherClient, sctLogs loglist.List, infoLogs loglist.List,
6363
}
6464
}
6565

66+
// Stagger must be positive for time.Ticker.
67+
// Default to the relatively safe value of 1 second.
68+
if stagger <= 0 {
69+
stagger = time.Second
70+
}
71+
6672
return &CTPolicy{
6773
pub: pub,
6874
sctLogs: sctLogs,
@@ -81,6 +87,22 @@ type result struct {
8187
err error
8288
}
8389

90+
// getOne obtains an SCT (or error), and returns it in resChan
91+
func (ctp *CTPolicy) getOne(ctx context.Context, cert core.CertDER, l loglist.Log, resChan chan result) {
92+
sct, err := ctp.pub.SubmitToSingleCTWithResult(ctx, &pubpb.Request{
93+
LogURL: l.Url,
94+
LogPublicKey: base64.StdEncoding.EncodeToString(l.Key),
95+
Der: cert,
96+
Kind: pubpb.SubmissionType_sct,
97+
})
98+
if err != nil {
99+
resChan <- result{log: l, err: fmt.Errorf("ct submission to %q (%q) failed: %w", l.Name, l.Url, err)}
100+
return
101+
}
102+
103+
resChan <- result{log: l, sct: sct.Sct}
104+
}
105+
84106
// GetSCTs retrieves exactly two SCTs from the total collection of configured
85107
// log groups, with at most one SCT coming from each group. It expects that all
86108
// logs run by a single operator (e.g. Google) are in the same group, to
@@ -93,69 +115,67 @@ func (ctp *CTPolicy) GetSCTs(ctx context.Context, cert core.CertDER, expiration
93115
subCtx, cancel := context.WithCancel(ctx)
94116
defer cancel()
95117

96-
// This closure will be called in parallel once for each log.
97-
getOne := func(i int, l loglist.Log) ([]byte, error) {
98-
// Sleep a little bit to stagger our requests to the later logs. Use `i-1`
99-
// to compute the stagger duration so that the first two logs (indices 0
100-
// and 1) get negative or zero (i.e. instant) sleep durations. If the
101-
// context gets cancelled (most likely because we got enough SCTs from other
102-
// logs already) before the sleep is complete, quit instead.
103-
select {
104-
case <-subCtx.Done():
105-
return nil, subCtx.Err()
106-
case <-time.After(time.Duration(i-1) * ctp.stagger):
107-
}
108-
109-
sct, err := ctp.pub.SubmitToSingleCTWithResult(ctx, &pubpb.Request{
110-
LogURL: l.Url,
111-
LogPublicKey: base64.StdEncoding.EncodeToString(l.Key),
112-
Der: cert,
113-
Kind: pubpb.SubmissionType_sct,
114-
})
115-
if err != nil {
116-
return nil, fmt.Errorf("ct submission to %q (%q) failed: %w", l.Name, l.Url, err)
117-
}
118-
119-
return sct.Sct, nil
120-
}
121-
122118
// Identify the set of candidate logs whose temporal interval includes this
123119
// cert's expiry. Randomize the order of the logs so that we're not always
124120
// trying to submit to the same two.
125121
logs := ctp.sctLogs.ForTime(expiration).Permute()
122+
if len(logs) < 2 {
123+
return nil, berrors.MissingSCTsError("Insufficient CT logs available (%d)", len(logs))
124+
}
126125

127-
// Kick off a collection of goroutines to try to submit the precert to each
128-
// log. Ensure that the results channel has a buffer equal to the number of
126+
// Ensure that the results channel has a buffer equal to the number of
129127
// goroutines we're kicking off, so that they're all guaranteed to be able to
130128
// write to it and exit without blocking and leaking.
131129
resChan := make(chan result, len(logs))
132-
for i, log := range logs {
133-
go func(i int, l loglist.Log) {
134-
sctDER, err := getOne(i, l)
135-
resChan <- result{log: l, sct: sctDER, err: err}
136-
}(i, log)
130+
131+
// Kick off first two submissions
132+
nextLog := 0
133+
for ; nextLog < 2; nextLog++ {
134+
go ctp.getOne(subCtx, cert, logs[nextLog], resChan)
137135
}
138136

139137
go ctp.submitPrecertInformational(cert, expiration)
140138

141-
// Finally, collect SCTs and/or errors from our results channel. We know that
142-
// we can collect len(logs) results from the channel because every goroutine
143-
// is guaranteed to write one result (either sct or error) to the channel.
139+
// staggerTicker will be used to start a new submission each stagger interval
140+
staggerTicker := time.NewTicker(ctp.stagger)
141+
defer staggerTicker.Stop()
142+
143+
// Collect SCTs and errors out of the results channels into these slices.
144144
results := make([]result, 0)
145145
errs := make([]string, 0)
146-
for range len(logs) {
147-
res := <-resChan
148-
if res.err != nil {
149-
errs = append(errs, res.err.Error())
150-
ctp.winnerCounter.WithLabelValues(res.log.Url, failed).Inc()
151-
continue
152-
}
153-
results = append(results, res)
154-
ctp.winnerCounter.WithLabelValues(res.log.Url, succeeded).Inc()
155146

156-
scts := compliantSet(results)
157-
if scts != nil {
158-
return scts, nil
147+
loop:
148+
for {
149+
select {
150+
case <-staggerTicker.C:
151+
// Each tick from the staggerTicker, we start submitting to another log
152+
if nextLog >= len(logs) {
153+
// Unless we have run out of logs to submit to, so don't need to tick anymore
154+
staggerTicker.Stop()
155+
continue
156+
}
157+
go ctp.getOne(subCtx, cert, logs[nextLog], resChan)
158+
nextLog++
159+
case res := <-resChan:
160+
if res.err != nil {
161+
errs = append(errs, res.err.Error())
162+
ctp.winnerCounter.WithLabelValues(res.log.Url, failed).Inc()
163+
} else {
164+
results = append(results, res)
165+
ctp.winnerCounter.WithLabelValues(res.log.Url, succeeded).Inc()
166+
167+
scts := compliantSet(results)
168+
if scts != nil {
169+
return scts, nil
170+
}
171+
}
172+
173+
// We can collect len(logs) results from the channel as every goroutine is
174+
// guaranteed to write one result (either sct or error) to the channel.
175+
if len(results)+len(errs) >= len(logs) {
176+
// We have an error or result from every log, but didn't find a compliant set
177+
break loop
178+
}
159179
}
160180
}
161181

ctpolicy/ctpolicy_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ func TestGetSCTsFailMetrics(t *testing.T) {
138138
// Ensure the proper metrics are incremented when GetSCTs fails.
139139
ctp := New(&mockFailOnePub{badURL: "UrlA1"}, loglist.List{
140140
{Name: "LogA1", Operator: "OperA", Url: "UrlA1", Key: []byte("KeyA1")},
141+
{Name: "LogA2", Operator: "OperA", Url: "UrlA2", Key: []byte("KeyA2")},
141142
}, nil, nil, 0, blog.NewMock(), metrics.NoopRegisterer)
142143
_, err := ctp.GetSCTs(context.Background(), []byte{0}, time.Time{})
143144
test.AssertError(t, err, "GetSCTs should have failed")
@@ -150,6 +151,7 @@ func TestGetSCTsFailMetrics(t *testing.T) {
150151

151152
ctp = New(&mockSlowPub{}, loglist.List{
152153
{Name: "LogA1", Operator: "OperA", Url: "UrlA1", Key: []byte("KeyA1")},
154+
{Name: "LogA2", Operator: "OperA", Url: "UrlA2", Key: []byte("KeyA2")},
153155
}, nil, nil, 0, blog.NewMock(), metrics.NoopRegisterer)
154156
_, err = ctp.GetSCTs(ctx, []byte{0}, time.Time{})
155157
test.AssertError(t, err, "GetSCTs should have timed out")

0 commit comments

Comments
 (0)