@@ -13,6 +13,7 @@ import (
1313 "fmt"
1414 "io"
1515 "net"
16+ "sync"
1617 "time"
1718
1819 "github.com/blinklabs-io/cdnsd/internal/handshake"
@@ -31,6 +32,16 @@ type Peer struct {
3132 address string
3233 conn net.Conn
3334 networkMagic uint32
35+ mu sync.Mutex
36+ sendMu sync.Mutex
37+ hasConnected bool
38+ doneCh chan struct {}
39+ errorCh chan error
40+ handshakeCh chan Message
41+ headersCh chan Message
42+ blockCh chan Message
43+ addrCh chan Message
44+ proofCh chan Message
3445}
3546
3647// NewPeer returns a new Peer using an existing connection (if provided) and the specified network magic. If a connection is provided,
@@ -39,10 +50,15 @@ func NewPeer(conn net.Conn, networkMagic uint32) (*Peer, error) {
3950 p := & Peer {
4051 conn : conn ,
4152 networkMagic : networkMagic ,
53+ doneCh : make (chan struct {}),
54+ errorCh : make (chan error , 5 ),
4255 }
4356 if conn != nil {
57+ p .conn = conn
4458 p .address = conn .RemoteAddr ().String ()
45- if err := p .handshake (); err != nil {
59+ p .hasConnected = true
60+ if err := p .setupConnection (); err != nil {
61+ _ = p .conn .Close ()
4662 return nil , err
4763 }
4864 }
@@ -51,34 +67,69 @@ func NewPeer(conn net.Conn, networkMagic uint32) (*Peer, error) {
5167
5268// Connect establishes a connection with a peer and performs the handshake process
5369func (p * Peer ) Connect (address string ) error {
70+ p .mu .Lock ()
71+ defer p .mu .Unlock ()
5472 if p .conn != nil {
5573 return errors .New ("connection already established" )
5674 }
75+ if p .hasConnected {
76+ return errors .New ("peer cannot be reused after disconnect" )
77+ }
5778 var err error
5879 p .conn , err = net .DialTimeout ("tcp" , address , dialTimeout )
5980 if err != nil {
6081 return err
6182 }
6283 p .address = address
63- if err := p .handshake (); err != nil {
84+ p .hasConnected = true
85+ if err := p .setupConnection (); err != nil {
86+ _ = p .conn .Close ()
6487 return err
6588 }
6689 return nil
6790}
6891
6992// Close closes an active connection with a network peer
7093func (p * Peer ) Close () error {
94+ p .mu .Lock ()
95+ defer p .mu .Unlock ()
7196 if p .conn == nil {
7297 return errors .New ("connection is not established" )
7398 }
7499 if err := p .conn .Close (); err != nil {
75100 return err
76101 }
102+ p .conn = nil
103+ close (p .doneCh )
104+ return nil
105+ }
106+
107+ // ErrorChan returns the async error channel
108+ func (p * Peer ) ErrorChan () <- chan error {
109+ return p .errorCh
110+ }
111+
112+ // setupConnection runs the initial handshake and starts the receive loop
113+ func (p * Peer ) setupConnection () error {
114+ // Init channels for async messages
115+ p .handshakeCh = make (chan Message , 10 )
116+ p .headersCh = make (chan Message , 10 )
117+ p .blockCh = make (chan Message , 10 )
118+ p .addrCh = make (chan Message , 10 )
119+ p .proofCh = make (chan Message , 10 )
120+ // Start receive loop
121+ go p .recvLoop ()
122+ // Start handshake
123+ if err := p .handshake (); err != nil {
124+ return err
125+ }
77126 return nil
78127}
79128
80129// sendMessage encodes and sends a message with the given type and payload
81130func (p * Peer ) sendMessage (msgType uint8 , msgPayload Message ) error {
131+ p .sendMu .Lock ()
132+ defer p .sendMu .Unlock ()
82133 if p .conn == nil {
83134 return errors .New ("connection is not established" )
84135 }
@@ -97,37 +148,67 @@ func (p *Peer) sendMessage(msgType uint8, msgPayload Message) error {
97148 return nil
98149}
99150
100- // receiveMessage receives and decodes messages from the active connection
101- func (p * Peer ) receiveMessage () (Message , error ) {
102- headerBuf := make ([]byte , messageHeaderLength )
103- if _ , err := io .ReadFull (p .conn , headerBuf ); err != nil {
104- return nil , fmt .Errorf ("read header: %w" , err )
105- }
106- header := new (msgHeader )
107- if err := header .Decode (headerBuf ); err != nil {
108- return nil , fmt .Errorf ("header decode: %w" , err )
109- }
110- if header .NetworkMagic != p .networkMagic {
111- return nil , fmt .Errorf ("invalid network magic: %d" , header .NetworkMagic )
112- }
113- if header .PayloadLength > messageMaxPayloadLength {
114- return nil , errors .New ("payload is too large" )
115- }
116- payload := make ([]byte , header .PayloadLength )
117- if _ , err := io .ReadFull (p .conn , payload ); err != nil {
118- return nil , fmt .Errorf ("read payload: %w" , err )
119- }
120- msg , err := decodeMessage (header , payload )
121- if err != nil {
122- // Discard unsupported messages and try to get another message
123- // This is a bit of a hack
124- var unsupportedErr UnsupportedMessageTypeError
125- if errors .As (err , & unsupportedErr ) {
126- return p .receiveMessage ()
151+ // recvLoop receives and decodes messages from the active connection
152+ func (p * Peer ) recvLoop () {
153+ err := func () error {
154+ // Assign to local var to avoid nil deref panic on shutdown
155+ conn := p .conn
156+ for {
157+ headerBuf := make ([]byte , messageHeaderLength )
158+ if _ , err := io .ReadFull (conn , headerBuf ); err != nil {
159+ return fmt .Errorf ("read header: %w" , err )
160+ }
161+ header := new (msgHeader )
162+ if err := header .Decode (headerBuf ); err != nil {
163+ return fmt .Errorf ("header decode: %w" , err )
164+ }
165+ if header .NetworkMagic != p .networkMagic {
166+ return fmt .Errorf ("invalid network magic: %d" , header .NetworkMagic )
167+ }
168+ if header .PayloadLength > messageMaxPayloadLength {
169+ return errors .New ("payload is too large" )
170+ }
171+ payload := make ([]byte , header .PayloadLength )
172+ if _ , err := io .ReadFull (conn , payload ); err != nil {
173+ return fmt .Errorf ("read payload: %w" , err )
174+ }
175+ msg , err := decodeMessage (header , payload )
176+ if err != nil {
177+ // Discard unsupported messages and try to get another message
178+ // This is a bit of a hack
179+ var unsupportedErr UnsupportedMessageTypeError
180+ if errors .As (err , & unsupportedErr ) {
181+ continue
182+ }
183+ return fmt .Errorf ("decode message: %w" , err )
184+ }
185+ if err := p .handleMessage (msg ); err != nil {
186+ return fmt .Errorf ("handle message: %w" , err )
187+ }
127188 }
128- return nil , err
189+ }()
190+ if err != nil {
191+ p .errorCh <- err
192+ _ = p .Close ()
129193 }
130- return msg , nil
194+ }
195+
196+ func (p * Peer ) handleMessage (msg Message ) error {
197+ switch msg .(type ) {
198+ case * MsgVersion , * MsgVerack :
199+ p .handshakeCh <- msg
200+ case * MsgAddr :
201+ p .addrCh <- msg
202+ case * MsgHeaders :
203+ p .headersCh <- msg
204+ case * MsgBlock :
205+ p .blockCh <- msg
206+ case * MsgProof :
207+ p .proofCh <- msg
208+ default :
209+ return fmt .Errorf ("unknown message type: %T" , msg )
210+ }
211+ return nil
131212}
132213
133214// handshake performs the handshake process, which involves exchanging Version messages with the network peer
@@ -157,20 +238,28 @@ func (p *Peer) handshake() error {
157238 return err
158239 }
159240 // Wait for Verack response
160- msg , err := p .receiveMessage ()
161- if err != nil {
162- return err
163- }
164- if _ , ok := msg .(* MsgVerack ); ! ok {
165- return fmt .Errorf ("unexpected message: %T" , msg )
241+ select {
242+ case msg := <- p .handshakeCh :
243+ if _ , ok := msg .(* MsgVerack ); ! ok {
244+ return fmt .Errorf ("unexpected message: %T" , msg )
245+ }
246+ case err := <- p .errorCh :
247+ return fmt .Errorf ("handshake failed: %w" , err )
248+ case <- time .After (1 * time .Second ):
249+ return errors .New ("handshake timed out" )
166250 }
167251 // Wait for Version from peer
168- msg , err = p .receiveMessage ()
169- if err != nil {
170- return err
171- }
172- if _ , ok := msg .(* MsgVersion ); ! ok {
173- return fmt .Errorf ("unexpected message: %T" , msg )
252+ select {
253+ case msg := <- p .handshakeCh :
254+ if _ , ok := msg .(* MsgVersion ); ! ok {
255+ return fmt .Errorf ("unexpected message: %T" , msg )
256+ }
257+ case err := <- p .errorCh :
258+ return fmt .Errorf ("handshake failed: %w" , err )
259+ case <- p .doneCh :
260+ return errors .New ("connection has shut down" )
261+ case <- time .After (1 * time .Second ):
262+ return errors .New ("handshake timed out" )
174263 }
175264 // Send Verack
176265 if err := p .sendMessage (MessageVerack , nil ); err != nil {
@@ -185,15 +274,18 @@ func (p *Peer) GetPeers() ([]NetAddress, error) {
185274 return nil , err
186275 }
187276 // Wait for Addr response
188- msg , err := p .receiveMessage ()
189- if err != nil {
190- return nil , err
191- }
192- msgAddr , ok := msg .(* MsgAddr )
193- if ! ok {
194- return nil , fmt .Errorf ("unexpected message: %T" , msg )
277+ select {
278+ case msg := <- p .addrCh :
279+ msgAddr , ok := msg .(* MsgAddr )
280+ if ! ok {
281+ return nil , fmt .Errorf ("unexpected message: %T" , msg )
282+ }
283+ return msgAddr .Peers , nil
284+ case <- p .doneCh :
285+ return nil , errors .New ("connection has shut down" )
286+ case <- time .After (5 * time .Second ):
287+ return nil , errors .New ("timed out" )
195288 }
196- return msgAddr .Peers , nil
197289}
198290
199291// GetHeaders requests a list of headers from the network peer
@@ -206,15 +298,18 @@ func (p *Peer) GetHeaders(locator [][32]byte, stopHash [32]byte) ([]*handshake.B
206298 return nil , err
207299 }
208300 // Wait for Headers response
209- msg , err := p .receiveMessage ()
210- if err != nil {
211- return nil , err
212- }
213- msgHeaders , ok := msg .(* MsgHeaders )
214- if ! ok {
215- return nil , fmt .Errorf ("unexpected message: %T" , msg )
301+ select {
302+ case msg := <- p .headersCh :
303+ msgHeaders , ok := msg .(* MsgHeaders )
304+ if ! ok {
305+ return nil , fmt .Errorf ("unexpected message: %T" , msg )
306+ }
307+ return msgHeaders .Headers , nil
308+ case <- p .doneCh :
309+ return nil , errors .New ("connection has shut down" )
310+ case <- time .After (5 * time .Second ):
311+ return nil , errors .New ("timed out" )
216312 }
217- return msgHeaders .Headers , nil
218313}
219314
220315// GetProof requests a proof for a domain name from the network peer
@@ -228,15 +323,18 @@ func (p *Peer) GetProof(name string, rootHash [32]byte) (*handshake.Proof, error
228323 return nil , err
229324 }
230325 // Wait for Proof response
231- msg , err := p .receiveMessage ()
232- if err != nil {
233- return nil , err
234- }
235- msgProof , ok := msg .(* MsgProof )
236- if ! ok {
237- return nil , fmt .Errorf ("unexpected message: %T" , msg )
326+ select {
327+ case msg := <- p .proofCh :
328+ msgProof , ok := msg .(* MsgProof )
329+ if ! ok {
330+ return nil , fmt .Errorf ("unexpected message: %T" , msg )
331+ }
332+ return msgProof .Proof , nil
333+ case <- p .doneCh :
334+ return nil , errors .New ("connection has shut down" )
335+ case <- time .After (5 * time .Second ):
336+ return nil , errors .New ("timed out" )
238337 }
239- return msgProof .Proof , nil
240338}
241339
242340// GetBlock requests the specified block from the network peer
@@ -253,13 +351,16 @@ func (p *Peer) GetBlock(hash [32]byte) (*handshake.Block, error) {
253351 return nil , err
254352 }
255353 // Wait for Block response
256- msg , err := p .receiveMessage ()
257- if err != nil {
258- return nil , err
259- }
260- msgBlock , ok := msg .(* MsgBlock )
261- if ! ok {
262- return nil , fmt .Errorf ("unexpected message: %T" , msg )
354+ select {
355+ case msg := <- p .blockCh :
356+ msgBlock , ok := msg .(* MsgBlock )
357+ if ! ok {
358+ return nil , fmt .Errorf ("unexpected message: %T" , msg )
359+ }
360+ return msgBlock .Block , nil
361+ case <- p .doneCh :
362+ return nil , errors .New ("connection has shut down" )
363+ case <- time .After (5 * time .Second ):
364+ return nil , errors .New ("timed out" )
263365 }
264- return msgBlock .Block , nil
265366}
0 commit comments