Skip to content

Commit 174d804

Browse files
committed
DNS caching
This commit introduces DNS caching with the -dns-ttl flag. Supersedes #576 Signed-off-by: Tomás Senart <[email protected]>
1 parent 556bf61 commit 174d804

File tree

6 files changed

+203
-44
lines changed

6 files changed

+203
-44
lines changed

attack.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ func attackCmd() command {
5656
fs.Var(&opts.laddr, "laddr", "Local IP address")
5757
fs.BoolVar(&opts.keepalive, "keepalive", true, "Use persistent connections")
5858
fs.StringVar(&opts.unixSocket, "unix-socket", "", "Connect over a unix socket. This overrides the host address in target URLs")
59+
fs.Var(&dnsTTLFlag{&opts.dnsTTL}, "dns-ttl", "Cache DNS lookups for the given duration [-1 = disabled, 0 = forever]")
5960
systemSpecificFlags(fs, opts)
6061

6162
return command{fs, func(args []string) error {
@@ -99,6 +100,7 @@ type attackOpts struct {
99100
keepalive bool
100101
resolvers csl
101102
unixSocket string
103+
dnsTTL time.Duration
102104
}
103105

104106
// attack validates the attack arguments, sets up the
@@ -116,6 +118,8 @@ func attack(opts *attackOpts) (err error) {
116118
net.DefaultResolver = res
117119
}
118120

121+
net.DefaultResolver.PreferGo = true
122+
119123
files := map[string]io.Reader{}
120124
for _, filename := range []string{opts.targetsf, opts.bodyf} {
121125
if filename == "" {
@@ -188,6 +192,7 @@ func attack(opts *attackOpts) (err error) {
188192
vegeta.UnixSocket(opts.unixSocket),
189193
vegeta.ProxyHeader(proxyHdr),
190194
vegeta.ChunkedBody(opts.chunked),
195+
vegeta.DNSCaching(opts.dnsTTL),
191196
)
192197

193198
res := atk.Attack(tr, opts.rate, opts.duration, opts.name)

flags.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,3 +132,24 @@ func (f *maxBodyFlag) String() string {
132132
}
133133
return datasize.ByteSize(*(f.n)).String()
134134
}
135+
136+
type dnsTTLFlag struct{ ttl *time.Duration }
137+
138+
func (f *dnsTTLFlag) Set(v string) (err error) {
139+
if v == "-1" {
140+
*(f.ttl) = -1
141+
return nil
142+
}
143+
144+
*(f.ttl), err = time.ParseDuration(v)
145+
return err
146+
}
147+
148+
func (f *dnsTTLFlag) String() string {
149+
if f.ttl == nil {
150+
return ""
151+
} else if *(f.ttl) == -1 {
152+
return "-1"
153+
}
154+
return f.ttl.String()
155+
}

go.mod

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,11 @@ require (
2121
require (
2222
github.com/iancoleman/orderedmap v0.3.0 // indirect
2323
github.com/josharian/intern v1.0.0 // indirect
24+
github.com/rs/dnscache v0.0.0-20211102005908-e0241e321417 // indirect
25+
github.com/shurcooL/httpfs v0.0.0-20230704072500-f1e31cf0ba5c // indirect
26+
github.com/shurcooL/vfsgen v0.0.0-20230704071429-0000e147ea92 // indirect
2427
golang.org/x/mod v0.8.0 // indirect
28+
golang.org/x/sync v0.1.0 // indirect
2529
golang.org/x/sys v0.10.0 // indirect
2630
golang.org/x/text v0.11.0 // indirect
2731
golang.org/x/tools v0.6.0 // indirect

go.sum

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,12 @@ github.com/miekg/dns v1.1.55 h1:GoQ4hpsj0nFLYe+bWiCToyrBEJXkQfOOIvFGFy0lEgo=
2626
github.com/miekg/dns v1.1.55/go.mod h1:uInx36IzPl7FYnDcMeVWxj9byh7DutNykX4G9Sj60FY=
2727
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
2828
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
29+
github.com/rs/dnscache v0.0.0-20211102005908-e0241e321417 h1:Lt9DzQALzHoDwMBGJ6v8ObDPR0dzr2a6sXTB1Fq7IHs=
30+
github.com/rs/dnscache v0.0.0-20211102005908-e0241e321417/go.mod h1:qe5TWALJ8/a1Lqznoc5BDHpYX/8HU60Hm2AwRmqzxqA=
31+
github.com/shurcooL/httpfs v0.0.0-20230704072500-f1e31cf0ba5c h1:aqg5Vm5dwtvL+YgDpBcK1ITf3o96N/K7/wsRXQnUTEs=
32+
github.com/shurcooL/httpfs v0.0.0-20230704072500-f1e31cf0ba5c/go.mod h1:owqhoLW1qZoYLZzLnBw+QkPP9WZnjlSWihhxAJC1+/M=
33+
github.com/shurcooL/vfsgen v0.0.0-20230704071429-0000e147ea92 h1:OfRzdxCzDhp+rsKWXuOO2I/quKMJ/+TQwVbIP/gltZg=
34+
github.com/shurcooL/vfsgen v0.0.0-20230704071429-0000e147ea92/go.mod h1:7/OT02F6S6I7v6WXb+IjhMuZEYfH/RJ5RwEWnEo5BMg=
2935
github.com/streadway/quantile v0.0.0-20220407130108-4246515d968d h1:X4+kt6zM/OVO6gbJdAfJR60MGPsqCzbtXNnjoGqdfAs=
3036
github.com/streadway/quantile v0.0.0-20220407130108-4246515d968d/go.mod h1:lbP8tGiBjZ5YWIc2fzuRpTaz0b/53vT6PEs3QuAWzuU=
3137
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
@@ -39,7 +45,9 @@ golang.org/x/mod v0.8.0 h1:LUYupSeNrTNCGzR/hVBk2NHZO4hXcVaW1k4Qx7rjPx8=
3945
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
4046
golang.org/x/net v0.12.0 h1:cfawfvKITfUsFCeJIHJrbSxpeu/E81khclypR0GVT50=
4147
golang.org/x/net v0.12.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA=
48+
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
4249
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
50+
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
4351
golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA=
4452
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
4553
golang.org/x/text v0.11.0 h1:LAntKIrcmeSKERyiOh0XMV39LXS8IE9UL2yP7+f5ij4=

lib/attack.go

Lines changed: 147 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,15 @@ import (
77
"io"
88
"io/ioutil"
99
"math"
10+
"math/rand"
1011
"net"
1112
"net/http"
1213
"net/url"
1314
"strconv"
1415
"sync"
1516
"time"
1617

18+
"github.com/rs/dnscache"
1719
"golang.org/x/net/http2"
1820
)
1921

@@ -27,9 +29,6 @@ type Attacker struct {
2729
maxWorkers uint64
2830
maxBody int64
2931
redirects int
30-
seqmu sync.Mutex
31-
seq uint64
32-
began time.Time
3332
chunked bool
3433
}
3534

@@ -73,7 +72,6 @@ func NewAttacker(opts ...func(*Attacker)) *Attacker {
7372
workers: DefaultWorkers,
7473
maxWorkers: DefaultMaxWorkers,
7574
maxBody: DefaultMaxBody,
76-
began: time.Now(),
7775
}
7876

7977
a.dialer = &net.Dialer{
@@ -85,7 +83,7 @@ func NewAttacker(opts ...func(*Attacker)) *Attacker {
8583
Timeout: DefaultTimeout,
8684
Transport: &http.Transport{
8785
Proxy: http.ProxyFromEnvironment,
88-
Dial: a.dialer.Dial,
86+
DialContext: a.dialer.DialContext,
8987
TLSClientConfig: DefaultTLSConfig,
9088
MaxIdleConnsPerHost: DefaultConnections,
9189
MaxConnsPerHost: DefaultMaxConnections,
@@ -177,7 +175,7 @@ func LocalAddr(addr net.IPAddr) func(*Attacker) {
177175
return func(a *Attacker) {
178176
tr := a.client.Transport.(*http.Transport)
179177
a.dialer.LocalAddr = &net.TCPAddr{IP: addr.IP, Zone: addr.Zone}
180-
tr.Dial = a.dialer.Dial
178+
tr.DialContext = a.dialer.DialContext
181179
}
182180
}
183181

@@ -189,7 +187,7 @@ func KeepAlive(keepalive bool) func(*Attacker) {
189187
tr.DisableKeepAlives = !keepalive
190188
if !keepalive {
191189
a.dialer.KeepAlive = 0
192-
tr.Dial = a.dialer.Dial
190+
tr.DialContext = a.dialer.DialContext
193191
}
194192
}
195193
}
@@ -223,8 +221,8 @@ func H2C(enabled bool) func(*Attacker) {
223221
if tr := a.client.Transport.(*http.Transport); enabled {
224222
a.client.Transport = &http2.Transport{
225223
AllowHTTP: true,
226-
DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
227-
return tr.Dial(network, addr)
224+
DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) {
225+
return tr.DialContext(ctx, network, addr)
228226
},
229227
}
230228
}
@@ -263,6 +261,119 @@ func ProxyHeader(h http.Header) func(*Attacker) {
263261
}
264262
}
265263

264+
// DNSCaching returns a functional option that enables DNS caching for
265+
// the given ttl. When ttl is zero cached entries will never expire.
266+
// When ttl is non-zero, this will start a refresh go-routine that updates
267+
// the cache every ttl interval. This go-routine will be stopped when the
268+
// attack is stopped.
269+
// When the ttl is negative, no caching will be performed.
270+
func DNSCaching(ttl time.Duration) func(*Attacker) {
271+
return func(a *Attacker) {
272+
if ttl < 0 {
273+
return
274+
}
275+
276+
if tr, ok := a.client.Transport.(*http.Transport); ok {
277+
dial := tr.DialContext
278+
if dial == nil {
279+
dial = a.dialer.DialContext
280+
}
281+
282+
resolver := &dnscache.Resolver{}
283+
284+
if ttl != 0 {
285+
go func() {
286+
refresh := time.NewTicker(ttl)
287+
defer refresh.Stop()
288+
for {
289+
select {
290+
case <-refresh.C:
291+
resolver.Refresh(true)
292+
case <-a.stopch:
293+
return
294+
}
295+
}
296+
}()
297+
}
298+
299+
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
300+
301+
tr.DialContext = func(ctx context.Context, network, addr string) (conn net.Conn, err error) {
302+
host, port, err := net.SplitHostPort(addr)
303+
if err != nil {
304+
return nil, err
305+
}
306+
307+
ips, err := resolver.LookupHost(ctx, host)
308+
if err != nil {
309+
return nil, err
310+
}
311+
312+
if len(ips) == 0 {
313+
return nil, &net.DNSError{Err: "no such host", Name: addr}
314+
}
315+
316+
// Pick a random IP from each IP family and dial each concurrently.
317+
// The first that succeeds wins, the other gets canceled.
318+
319+
rng.Shuffle(len(ips), func(i, j int) { ips[i], ips[j] = ips[j], ips[i] })
320+
321+
// In place filtering of ips to only include the first IPv4 and IPv6.
322+
j := 0
323+
for i := 0; i < len(ips) && j < 2; i++ {
324+
ip := net.ParseIP(ips[i])
325+
switch {
326+
case len(ip.To4()) == net.IPv4len && j == 0:
327+
fallthrough
328+
case len(ip) == net.IPv6len && j == 1:
329+
ips[j] = ips[i]
330+
j++
331+
}
332+
}
333+
ips = ips[:j]
334+
335+
type result struct {
336+
conn net.Conn
337+
err error
338+
}
339+
340+
ch := make(chan result, len(ips))
341+
ctx, cancel := context.WithCancel(ctx)
342+
defer cancel()
343+
344+
for _, ip := range ips {
345+
go func(ip string) {
346+
conn, err := dial(ctx, network, net.JoinHostPort(ip, port))
347+
ch <- result{conn, err}
348+
}(ip)
349+
}
350+
351+
for i := 0; i < cap(ch); i++ {
352+
select {
353+
case <-ctx.Done():
354+
return nil, ctx.Err()
355+
case r := <-ch:
356+
if err = r.err; err != nil {
357+
continue
358+
}
359+
return r.conn, nil
360+
}
361+
}
362+
363+
return nil, err
364+
}
365+
}
366+
}
367+
}
368+
369+
type attack struct {
370+
name string
371+
began time.Time
372+
373+
seqmu sync.Mutex
374+
seq uint64
375+
}
376+
266377
// Attack reads its Targets from the passed Targeter and attacks them at
267378
// the rate specified by the Pacer. When the duration is zero the attack
268379
// runs until Stop is called. Results are sent to the returned channel as soon
@@ -275,21 +386,29 @@ func (a *Attacker) Attack(tr Targeter, p Pacer, du time.Duration, name string) <
275386
workers = a.maxWorkers
276387
}
277388

389+
atk := &attack{
390+
name: name,
391+
began: time.Now(),
392+
}
393+
278394
results := make(chan *Result)
279395
ticks := make(chan struct{})
280396
for i := uint64(0); i < workers; i++ {
281397
wg.Add(1)
282-
go a.attack(tr, name, &wg, ticks, results)
398+
go a.attack(tr, atk, &wg, ticks, results)
283399
}
284400

285401
go func() {
286-
defer close(results)
287-
defer wg.Wait()
288-
defer close(ticks)
289-
290-
began, count := time.Now(), uint64(0)
402+
defer func() {
403+
close(ticks)
404+
wg.Wait()
405+
close(results)
406+
a.Stop()
407+
}()
408+
409+
count := uint64(0)
291410
for {
292-
elapsed := time.Since(began)
411+
elapsed := time.Since(atk.began)
293412
if du > 0 && elapsed > du {
294413
return
295414
}
@@ -312,7 +431,7 @@ func (a *Attacker) Attack(tr Targeter, p Pacer, du time.Duration, name string) <
312431
// all workers are blocked. start one more and try again
313432
workers++
314433
wg.Add(1)
315-
go a.attack(tr, name, &wg, ticks, results)
434+
go a.attack(tr, atk, &wg, ticks, results)
316435
}
317436
}
318437

@@ -342,25 +461,25 @@ func (a *Attacker) Stop() bool {
342461
}
343462
}
344463

345-
func (a *Attacker) attack(tr Targeter, name string, workers *sync.WaitGroup, ticks <-chan struct{}, results chan<- *Result) {
464+
func (a *Attacker) attack(tr Targeter, atk *attack, workers *sync.WaitGroup, ticks <-chan struct{}, results chan<- *Result) {
346465
defer workers.Done()
347466
for range ticks {
348-
results <- a.hit(tr, name)
467+
results <- a.hit(tr, atk)
349468
}
350469
}
351470

352-
func (a *Attacker) hit(tr Targeter, name string) *Result {
471+
func (a *Attacker) hit(tr Targeter, atk *attack) *Result {
353472
var (
354-
res = Result{Attack: name}
473+
res = Result{Attack: atk.name}
355474
tgt Target
356475
err error
357476
)
358477

359-
a.seqmu.Lock()
360-
res.Timestamp = a.began.Add(time.Since(a.began))
361-
res.Seq = a.seq
362-
a.seq++
363-
a.seqmu.Unlock()
478+
atk.seqmu.Lock()
479+
res.Timestamp = atk.began.Add(time.Since(atk.began))
480+
res.Seq = atk.seq
481+
atk.seq++
482+
atk.seqmu.Unlock()
364483

365484
defer func() {
366485
res.Latency = time.Since(res.Timestamp)
@@ -382,8 +501,8 @@ func (a *Attacker) hit(tr Targeter, name string) *Result {
382501
return &res
383502
}
384503

385-
if name != "" {
386-
req.Header.Set("X-Vegeta-Attack", name)
504+
if atk.name != "" {
505+
req.Header.Set("X-Vegeta-Attack", atk.name)
387506
}
388507

389508
req.Header.Set("X-Vegeta-Seq", strconv.FormatUint(res.Seq, 10))

0 commit comments

Comments
 (0)