Skip to content

Commit b77012e

Browse files
committed
refactor: async Handshake network support
Signed-off-by: Aurora Gaffney <[email protected]>
1 parent 764f347 commit b77012e

File tree

1 file changed

+176
-75
lines changed

1 file changed

+176
-75
lines changed

internal/handshake/protocol/peer.go

Lines changed: 176 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"fmt"
1414
"io"
1515
"net"
16+
"sync"
1617
"time"
1718

1819
"github.com/blinklabs-io/cdnsd/internal/handshake"
@@ -31,6 +32,16 @@ type Peer struct {
3132
address string
3233
conn net.Conn
3334
networkMagic uint32
35+
mu sync.Mutex
36+
sendMu sync.Mutex
37+
hasConnected bool
38+
doneCh chan struct{}
39+
errorCh chan error
40+
handshakeCh chan Message
41+
headersCh chan Message
42+
blockCh chan Message
43+
addrCh chan Message
44+
proofCh chan Message
3445
}
3546

3647
// NewPeer returns a new Peer using an existing connection (if provided) and the specified network magic. If a connection is provided,
@@ -39,10 +50,15 @@ func NewPeer(conn net.Conn, networkMagic uint32) (*Peer, error) {
3950
p := &Peer{
4051
conn: conn,
4152
networkMagic: networkMagic,
53+
doneCh: make(chan struct{}),
54+
errorCh: make(chan error, 5),
4255
}
4356
if conn != nil {
57+
p.conn = conn
4458
p.address = conn.RemoteAddr().String()
45-
if err := p.handshake(); err != nil {
59+
p.hasConnected = true
60+
if err := p.setupConnection(); err != nil {
61+
_ = p.conn.Close()
4662
return nil, err
4763
}
4864
}
@@ -51,34 +67,69 @@ func NewPeer(conn net.Conn, networkMagic uint32) (*Peer, error) {
5167

5268
// Connect establishes a connection with a peer and performs the handshake process
5369
func (p *Peer) Connect(address string) error {
70+
p.mu.Lock()
71+
defer p.mu.Unlock()
5472
if p.conn != nil {
5573
return errors.New("connection already established")
5674
}
75+
if p.hasConnected {
76+
return errors.New("peer cannot be reused after disconnect")
77+
}
5778
var err error
5879
p.conn, err = net.DialTimeout("tcp", address, dialTimeout)
5980
if err != nil {
6081
return err
6182
}
6283
p.address = address
63-
if err := p.handshake(); err != nil {
84+
p.hasConnected = true
85+
if err := p.setupConnection(); err != nil {
86+
_ = p.conn.Close()
6487
return err
6588
}
6689
return nil
6790
}
6891

6992
// Close closes an active connection with a network peer
7093
func (p *Peer) Close() error {
94+
p.mu.Lock()
95+
defer p.mu.Unlock()
7196
if p.conn == nil {
7297
return errors.New("connection is not established")
7398
}
7499
if err := p.conn.Close(); err != nil {
75100
return err
76101
}
102+
p.conn = nil
103+
close(p.doneCh)
104+
return nil
105+
}
106+
107+
// ErrorChan returns the async error channel
108+
func (p *Peer) ErrorChan() <-chan error {
109+
return p.errorCh
110+
}
111+
112+
// setupConnection runs the initial handshake and starts the receive loop
113+
func (p *Peer) setupConnection() error {
114+
// Init channels for async messages
115+
p.handshakeCh = make(chan Message, 10)
116+
p.headersCh = make(chan Message, 10)
117+
p.blockCh = make(chan Message, 10)
118+
p.addrCh = make(chan Message, 10)
119+
p.proofCh = make(chan Message, 10)
120+
// Start receive loop
121+
go p.recvLoop()
122+
// Start handshake
123+
if err := p.handshake(); err != nil {
124+
return err
125+
}
77126
return nil
78127
}
79128

80129
// sendMessage encodes and sends a message with the given type and payload
81130
func (p *Peer) sendMessage(msgType uint8, msgPayload Message) error {
131+
p.sendMu.Lock()
132+
defer p.sendMu.Unlock()
82133
if p.conn == nil {
83134
return errors.New("connection is not established")
84135
}
@@ -97,37 +148,67 @@ func (p *Peer) sendMessage(msgType uint8, msgPayload Message) error {
97148
return nil
98149
}
99150

100-
// receiveMessage receives and decodes messages from the active connection
101-
func (p *Peer) receiveMessage() (Message, error) {
102-
headerBuf := make([]byte, messageHeaderLength)
103-
if _, err := io.ReadFull(p.conn, headerBuf); err != nil {
104-
return nil, fmt.Errorf("read header: %w", err)
105-
}
106-
header := new(msgHeader)
107-
if err := header.Decode(headerBuf); err != nil {
108-
return nil, fmt.Errorf("header decode: %w", err)
109-
}
110-
if header.NetworkMagic != p.networkMagic {
111-
return nil, fmt.Errorf("invalid network magic: %d", header.NetworkMagic)
112-
}
113-
if header.PayloadLength > messageMaxPayloadLength {
114-
return nil, errors.New("payload is too large")
115-
}
116-
payload := make([]byte, header.PayloadLength)
117-
if _, err := io.ReadFull(p.conn, payload); err != nil {
118-
return nil, fmt.Errorf("read payload: %w", err)
119-
}
120-
msg, err := decodeMessage(header, payload)
121-
if err != nil {
122-
// Discard unsupported messages and try to get another message
123-
// This is a bit of a hack
124-
var unsupportedErr UnsupportedMessageTypeError
125-
if errors.As(err, &unsupportedErr) {
126-
return p.receiveMessage()
151+
// recvLoop receives and decodes messages from the active connection
152+
func (p *Peer) recvLoop() {
153+
err := func() error {
154+
// Assign to local var to avoid nil deref panic on shutdown
155+
conn := p.conn
156+
for {
157+
headerBuf := make([]byte, messageHeaderLength)
158+
if _, err := io.ReadFull(conn, headerBuf); err != nil {
159+
return fmt.Errorf("read header: %w", err)
160+
}
161+
header := new(msgHeader)
162+
if err := header.Decode(headerBuf); err != nil {
163+
return fmt.Errorf("header decode: %w", err)
164+
}
165+
if header.NetworkMagic != p.networkMagic {
166+
return fmt.Errorf("invalid network magic: %d", header.NetworkMagic)
167+
}
168+
if header.PayloadLength > messageMaxPayloadLength {
169+
return errors.New("payload is too large")
170+
}
171+
payload := make([]byte, header.PayloadLength)
172+
if _, err := io.ReadFull(conn, payload); err != nil {
173+
return fmt.Errorf("read payload: %w", err)
174+
}
175+
msg, err := decodeMessage(header, payload)
176+
if err != nil {
177+
// Discard unsupported messages and try to get another message
178+
// This is a bit of a hack
179+
var unsupportedErr UnsupportedMessageTypeError
180+
if errors.As(err, &unsupportedErr) {
181+
continue
182+
}
183+
return fmt.Errorf("decode message: %w", err)
184+
}
185+
if err := p.handleMessage(msg); err != nil {
186+
return fmt.Errorf("handle message: %w", err)
187+
}
127188
}
128-
return nil, err
189+
}()
190+
if err != nil {
191+
p.errorCh <- err
192+
_ = p.Close()
129193
}
130-
return msg, nil
194+
}
195+
196+
func (p *Peer) handleMessage(msg Message) error {
197+
switch msg.(type) {
198+
case *MsgVersion, *MsgVerack:
199+
p.handshakeCh <- msg
200+
case *MsgAddr:
201+
p.addrCh <- msg
202+
case *MsgHeaders:
203+
p.headersCh <- msg
204+
case *MsgBlock:
205+
p.blockCh <- msg
206+
case *MsgProof:
207+
p.proofCh <- msg
208+
default:
209+
return fmt.Errorf("unknown message type: %T", msg)
210+
}
211+
return nil
131212
}
132213

133214
// handshake performs the handshake process, which involves exchanging Version messages with the network peer
@@ -157,20 +238,28 @@ func (p *Peer) handshake() error {
157238
return err
158239
}
159240
// Wait for Verack response
160-
msg, err := p.receiveMessage()
161-
if err != nil {
162-
return err
163-
}
164-
if _, ok := msg.(*MsgVerack); !ok {
165-
return fmt.Errorf("unexpected message: %T", msg)
241+
select {
242+
case msg := <-p.handshakeCh:
243+
if _, ok := msg.(*MsgVerack); !ok {
244+
return fmt.Errorf("unexpected message: %T", msg)
245+
}
246+
case err := <-p.errorCh:
247+
return fmt.Errorf("handshake failed: %w", err)
248+
case <-time.After(1 * time.Second):
249+
return errors.New("handshake timed out")
166250
}
167251
// Wait for Version from peer
168-
msg, err = p.receiveMessage()
169-
if err != nil {
170-
return err
171-
}
172-
if _, ok := msg.(*MsgVersion); !ok {
173-
return fmt.Errorf("unexpected message: %T", msg)
252+
select {
253+
case msg := <-p.handshakeCh:
254+
if _, ok := msg.(*MsgVersion); !ok {
255+
return fmt.Errorf("unexpected message: %T", msg)
256+
}
257+
case err := <-p.errorCh:
258+
return fmt.Errorf("handshake failed: %w", err)
259+
case <-p.doneCh:
260+
return errors.New("connection has shut down")
261+
case <-time.After(1 * time.Second):
262+
return errors.New("handshake timed out")
174263
}
175264
// Send Verack
176265
if err := p.sendMessage(MessageVerack, nil); err != nil {
@@ -185,15 +274,18 @@ func (p *Peer) GetPeers() ([]NetAddress, error) {
185274
return nil, err
186275
}
187276
// Wait for Addr response
188-
msg, err := p.receiveMessage()
189-
if err != nil {
190-
return nil, err
191-
}
192-
msgAddr, ok := msg.(*MsgAddr)
193-
if !ok {
194-
return nil, fmt.Errorf("unexpected message: %T", msg)
277+
select {
278+
case msg := <-p.addrCh:
279+
msgAddr, ok := msg.(*MsgAddr)
280+
if !ok {
281+
return nil, fmt.Errorf("unexpected message: %T", msg)
282+
}
283+
return msgAddr.Peers, nil
284+
case <-p.doneCh:
285+
return nil, errors.New("connection has shut down")
286+
case <-time.After(5 * time.Second):
287+
return nil, errors.New("timed out")
195288
}
196-
return msgAddr.Peers, nil
197289
}
198290

199291
// GetHeaders requests a list of headers from the network peer
@@ -206,15 +298,18 @@ func (p *Peer) GetHeaders(locator [][32]byte, stopHash [32]byte) ([]*handshake.B
206298
return nil, err
207299
}
208300
// Wait for Headers response
209-
msg, err := p.receiveMessage()
210-
if err != nil {
211-
return nil, err
212-
}
213-
msgHeaders, ok := msg.(*MsgHeaders)
214-
if !ok {
215-
return nil, fmt.Errorf("unexpected message: %T", msg)
301+
select {
302+
case msg := <-p.headersCh:
303+
msgHeaders, ok := msg.(*MsgHeaders)
304+
if !ok {
305+
return nil, fmt.Errorf("unexpected message: %T", msg)
306+
}
307+
return msgHeaders.Headers, nil
308+
case <-p.doneCh:
309+
return nil, errors.New("connection has shut down")
310+
case <-time.After(5 * time.Second):
311+
return nil, errors.New("timed out")
216312
}
217-
return msgHeaders.Headers, nil
218313
}
219314

220315
// GetProof requests a proof for a domain name from the network peer
@@ -228,15 +323,18 @@ func (p *Peer) GetProof(name string, rootHash [32]byte) (*handshake.Proof, error
228323
return nil, err
229324
}
230325
// Wait for Proof response
231-
msg, err := p.receiveMessage()
232-
if err != nil {
233-
return nil, err
234-
}
235-
msgProof, ok := msg.(*MsgProof)
236-
if !ok {
237-
return nil, fmt.Errorf("unexpected message: %T", msg)
326+
select {
327+
case msg := <-p.proofCh:
328+
msgProof, ok := msg.(*MsgProof)
329+
if !ok {
330+
return nil, fmt.Errorf("unexpected message: %T", msg)
331+
}
332+
return msgProof.Proof, nil
333+
case <-p.doneCh:
334+
return nil, errors.New("connection has shut down")
335+
case <-time.After(5 * time.Second):
336+
return nil, errors.New("timed out")
238337
}
239-
return msgProof.Proof, nil
240338
}
241339

242340
// GetBlock requests the specified block from the network peer
@@ -253,13 +351,16 @@ func (p *Peer) GetBlock(hash [32]byte) (*handshake.Block, error) {
253351
return nil, err
254352
}
255353
// Wait for Block response
256-
msg, err := p.receiveMessage()
257-
if err != nil {
258-
return nil, err
259-
}
260-
msgBlock, ok := msg.(*MsgBlock)
261-
if !ok {
262-
return nil, fmt.Errorf("unexpected message: %T", msg)
354+
select {
355+
case msg := <-p.blockCh:
356+
msgBlock, ok := msg.(*MsgBlock)
357+
if !ok {
358+
return nil, fmt.Errorf("unexpected message: %T", msg)
359+
}
360+
return msgBlock.Block, nil
361+
case <-p.doneCh:
362+
return nil, errors.New("connection has shut down")
363+
case <-time.After(5 * time.Second):
364+
return nil, errors.New("timed out")
263365
}
264-
return msgBlock.Block, nil
265366
}

0 commit comments

Comments
 (0)