@@ -7,13 +7,15 @@ import (
77 "io"
88 "io/ioutil"
99 "math"
10+ "math/rand"
1011 "net"
1112 "net/http"
1213 "net/url"
1314 "strconv"
1415 "sync"
1516 "time"
1617
18+ "github.com/rs/dnscache"
1719 "golang.org/x/net/http2"
1820)
1921
@@ -27,9 +29,6 @@ type Attacker struct {
2729 maxWorkers uint64
2830 maxBody int64
2931 redirects int
30- seqmu sync.Mutex
31- seq uint64
32- began time.Time
3332 chunked bool
3433}
3534
@@ -73,7 +72,6 @@ func NewAttacker(opts ...func(*Attacker)) *Attacker {
7372 workers : DefaultWorkers ,
7473 maxWorkers : DefaultMaxWorkers ,
7574 maxBody : DefaultMaxBody ,
76- began : time .Now (),
7775 }
7876
7977 a .dialer = & net.Dialer {
@@ -85,7 +83,7 @@ func NewAttacker(opts ...func(*Attacker)) *Attacker {
8583 Timeout : DefaultTimeout ,
8684 Transport : & http.Transport {
8785 Proxy : http .ProxyFromEnvironment ,
88- Dial : a .dialer .Dial ,
86+ DialContext : a .dialer .DialContext ,
8987 TLSClientConfig : DefaultTLSConfig ,
9088 MaxIdleConnsPerHost : DefaultConnections ,
9189 MaxConnsPerHost : DefaultMaxConnections ,
@@ -177,7 +175,7 @@ func LocalAddr(addr net.IPAddr) func(*Attacker) {
177175 return func (a * Attacker ) {
178176 tr := a .client .Transport .(* http.Transport )
179177 a .dialer .LocalAddr = & net.TCPAddr {IP : addr .IP , Zone : addr .Zone }
180- tr .Dial = a .dialer .Dial
178+ tr .DialContext = a .dialer .DialContext
181179 }
182180}
183181
@@ -189,7 +187,7 @@ func KeepAlive(keepalive bool) func(*Attacker) {
189187 tr .DisableKeepAlives = ! keepalive
190188 if ! keepalive {
191189 a .dialer .KeepAlive = 0
192- tr .Dial = a .dialer .Dial
190+ tr .DialContext = a .dialer .DialContext
193191 }
194192 }
195193}
@@ -223,8 +221,8 @@ func H2C(enabled bool) func(*Attacker) {
223221 if tr := a .client .Transport .(* http.Transport ); enabled {
224222 a .client .Transport = & http2.Transport {
225223 AllowHTTP : true ,
226- DialTLS : func (network , addr string , cfg * tls.Config ) (net.Conn , error ) {
227- return tr .Dial ( network , addr )
224+ DialTLSContext : func (ctx context. Context , network , addr string , cfg * tls.Config ) (net.Conn , error ) {
225+ return tr .DialContext ( ctx , network , addr )
228226 },
229227 }
230228 }
@@ -263,6 +261,119 @@ func ProxyHeader(h http.Header) func(*Attacker) {
263261 }
264262}
265263
264+ // DNSCaching returns a functional option that enables DNS caching for
265+ // the given ttl. When ttl is zero cached entries will never expire.
266+ // When ttl is non-zero, this will start a refresh go-routine that updates
267+ // the cache every ttl interval. This go-routine will be stopped when the
268+ // attack is stopped.
269+ // When the ttl is negative, no caching will be performed.
270+ func DNSCaching (ttl time.Duration ) func (* Attacker ) {
271+ return func (a * Attacker ) {
272+ if ttl < 0 {
273+ return
274+ }
275+
276+ if tr , ok := a .client .Transport .(* http.Transport ); ok {
277+ dial := tr .DialContext
278+ if dial == nil {
279+ dial = a .dialer .DialContext
280+ }
281+
282+ resolver := & dnscache.Resolver {}
283+
284+ if ttl != 0 {
285+ go func () {
286+ refresh := time .NewTicker (ttl )
287+ defer refresh .Stop ()
288+ for {
289+ select {
290+ case <- refresh .C :
291+ resolver .Refresh (true )
292+ case <- a .stopch :
293+ return
294+ }
295+ }
296+ }()
297+ }
298+
299+ rng := rand .New (rand .NewSource (time .Now ().UnixNano ()))
300+
301+ tr .DialContext = func (ctx context.Context , network , addr string ) (conn net.Conn , err error ) {
302+ host , port , err := net .SplitHostPort (addr )
303+ if err != nil {
304+ return nil , err
305+ }
306+
307+ ips , err := resolver .LookupHost (ctx , host )
308+ if err != nil {
309+ return nil , err
310+ }
311+
312+ if len (ips ) == 0 {
313+ return nil , & net.DNSError {Err : "no such host" , Name : addr }
314+ }
315+
316+ // Pick a random IP from each IP family and dial each concurrently.
317+ // The first that succeeds wins, the other gets canceled.
318+
319+ rng .Shuffle (len (ips ), func (i , j int ) { ips [i ], ips [j ] = ips [j ], ips [i ] })
320+
321+ // In place filtering of ips to only include the first IPv4 and IPv6.
322+ j := 0
323+ for i := 0 ; i < len (ips ) && j < 2 ; i ++ {
324+ ip := net .ParseIP (ips [i ])
325+ switch {
326+ case len (ip .To4 ()) == net .IPv4len && j == 0 :
327+ fallthrough
328+ case len (ip ) == net .IPv6len && j == 1 :
329+ ips [j ] = ips [i ]
330+ j ++
331+ }
332+ }
333+ ips = ips [:j ]
334+
335+ type result struct {
336+ conn net.Conn
337+ err error
338+ }
339+
340+ ch := make (chan result , len (ips ))
341+ ctx , cancel := context .WithCancel (ctx )
342+ defer cancel ()
343+
344+ for _ , ip := range ips {
345+ go func (ip string ) {
346+ conn , err := dial (ctx , network , net .JoinHostPort (ip , port ))
347+ ch <- result {conn , err }
348+ }(ip )
349+ }
350+
351+ for i := 0 ; i < cap (ch ); i ++ {
352+ select {
353+ case <- ctx .Done ():
354+ return nil , ctx .Err ()
355+ case r := <- ch :
356+ if err = r .err ; err != nil {
357+ continue
358+ }
359+ return r .conn , nil
360+ }
361+ }
362+
363+ return nil , err
364+ }
365+ }
366+ }
367+ }
368+
369+ type attack struct {
370+ name string
371+ began time.Time
372+
373+ seqmu sync.Mutex
374+ seq uint64
375+ }
376+
266377// Attack reads its Targets from the passed Targeter and attacks them at
267378// the rate specified by the Pacer. When the duration is zero the attack
268379// runs until Stop is called. Results are sent to the returned channel as soon
@@ -275,21 +386,29 @@ func (a *Attacker) Attack(tr Targeter, p Pacer, du time.Duration, name string) <
275386 workers = a .maxWorkers
276387 }
277388
389+ atk := & attack {
390+ name : name ,
391+ began : time .Now (),
392+ }
393+
278394 results := make (chan * Result )
279395 ticks := make (chan struct {})
280396 for i := uint64 (0 ); i < workers ; i ++ {
281397 wg .Add (1 )
282- go a .attack (tr , name , & wg , ticks , results )
398+ go a .attack (tr , atk , & wg , ticks , results )
283399 }
284400
285401 go func () {
286- defer close (results )
287- defer wg .Wait ()
288- defer close (ticks )
289-
290- began , count := time .Now (), uint64 (0 )
402+ defer func () {
403+ close (ticks )
404+ wg .Wait ()
405+ close (results )
406+ a .Stop ()
407+ }()
408+
409+ count := uint64 (0 )
291410 for {
292- elapsed := time .Since (began )
411+ elapsed := time .Since (atk . began )
293412 if du > 0 && elapsed > du {
294413 return
295414 }
@@ -312,7 +431,7 @@ func (a *Attacker) Attack(tr Targeter, p Pacer, du time.Duration, name string) <
312431 // all workers are blocked. start one more and try again
313432 workers ++
314433 wg .Add (1 )
315- go a .attack (tr , name , & wg , ticks , results )
434+ go a .attack (tr , atk , & wg , ticks , results )
316435 }
317436 }
318437
@@ -342,25 +461,25 @@ func (a *Attacker) Stop() bool {
342461 }
343462}
344463
345- func (a * Attacker ) attack (tr Targeter , name string , workers * sync.WaitGroup , ticks <- chan struct {}, results chan <- * Result ) {
464+ func (a * Attacker ) attack (tr Targeter , atk * attack , workers * sync.WaitGroup , ticks <- chan struct {}, results chan <- * Result ) {
346465 defer workers .Done ()
347466 for range ticks {
348- results <- a .hit (tr , name )
467+ results <- a .hit (tr , atk )
349468 }
350469}
351470
352- func (a * Attacker ) hit (tr Targeter , name string ) * Result {
471+ func (a * Attacker ) hit (tr Targeter , atk * attack ) * Result {
353472 var (
354- res = Result {Attack : name }
473+ res = Result {Attack : atk . name }
355474 tgt Target
356475 err error
357476 )
358477
359- a .seqmu .Lock ()
360- res .Timestamp = a .began .Add (time .Since (a .began ))
361- res .Seq = a .seq
362- a .seq ++
363- a .seqmu .Unlock ()
478+ atk .seqmu .Lock ()
479+ res .Timestamp = atk .began .Add (time .Since (atk .began ))
480+ res .Seq = atk .seq
481+ atk .seq ++
482+ atk .seqmu .Unlock ()
364483
365484 defer func () {
366485 res .Latency = time .Since (res .Timestamp )
@@ -382,8 +501,8 @@ func (a *Attacker) hit(tr Targeter, name string) *Result {
382501 return & res
383502 }
384503
385- if name != "" {
386- req .Header .Set ("X-Vegeta-Attack" , name )
504+ if atk . name != "" {
505+ req .Header .Set ("X-Vegeta-Attack" , atk . name )
387506 }
388507
389508 req .Header .Set ("X-Vegeta-Seq" , strconv .FormatUint (res .Seq , 10 ))
0 commit comments