Skip to content

Commit de4f99b

Browse files
committed
add partial messages to gossipsub router
1 parent fefdb20 commit de4f99b

File tree

5 files changed

+502
-5
lines changed

5 files changed

+502
-5
lines changed

extensions.go

Lines changed: 83 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
11
package pubsub
22

33
import (
4+
"errors"
5+
"iter"
6+
7+
"github.com/libp2p/go-libp2p-pubsub/partialmessages"
48
pubsub_pb "github.com/libp2p/go-libp2p-pubsub/pb"
59
"github.com/libp2p/go-libp2p/core/peer"
610
)
711

812
type PeerExtensions struct {
9-
TestExtension bool
13+
TestExtension bool
14+
PartialMessages bool
1015
}
1116

1217
type TestExtensionConfig struct {
@@ -37,6 +42,7 @@ func peerExtensionsFromRPC(rpc *RPC) PeerExtensions {
3742
out := PeerExtensions{}
3843
if hasPeerExtensions(rpc) {
3944
out.TestExtension = rpc.Control.Extensions.GetTestExtension()
45+
out.PartialMessages = rpc.Control.Extensions.GetPartialMessages()
4046
}
4147
return out
4248
}
@@ -46,9 +52,19 @@ func (pe *PeerExtensions) ExtendRPC(rpc *RPC) *RPC {
4652
if rpc.Control == nil {
4753
rpc.Control = &pubsub_pb.ControlMessage{}
4854
}
49-
rpc.Control.Extensions = &pubsub_pb.ControlExtensions{
50-
TestExtension: &pe.TestExtension,
55+
if rpc.Control.Extensions == nil {
56+
rpc.Control.Extensions = &pubsub_pb.ControlExtensions{}
57+
}
58+
rpc.Control.Extensions.TestExtension = &pe.TestExtension
59+
}
60+
if pe.PartialMessages {
61+
if rpc.Control == nil {
62+
rpc.Control = &pubsub_pb.ControlMessage{}
63+
}
64+
if rpc.Control.Extensions == nil {
65+
rpc.Control.Extensions = &pubsub_pb.ControlExtensions{}
5166
}
67+
rpc.Control.Extensions.PartialMessages = &pe.PartialMessages
5268
}
5369
return rpc
5470
}
@@ -59,8 +75,9 @@ type extensionsState struct {
5975
sentExtensions map[peer.ID]struct{}
6076
reportMisbehavior func(peer.ID)
6177
sendRPC func(p peer.ID, r *RPC, urgent bool)
78+
testExtension *testExtension
6279

63-
testExtension *testExtension
80+
partialMessagesExtension *partialmessages.PartialMessageExtension
6481
}
6582

6683
func newExtensionsState(myExtensions PeerExtensions, reportMisbehavior func(peer.ID), sendRPC func(peer.ID, *RPC, bool)) *extensionsState {
@@ -132,14 +149,76 @@ func (es *extensionsState) extensionsAddPeer(id peer.ID) {
132149
if es.myExtensions.TestExtension && es.peerExtensions[id].TestExtension {
133150
es.testExtension.AddPeer(id)
134151
}
152+
153+
if es.myExtensions.PartialMessages && es.peerExtensions[id].PartialMessages {
154+
es.partialMessagesExtension.AddPeer(id)
155+
}
135156
}
136157

137158
// extensionsRemovePeer is always called after extensionsAddPeer.
138159
func (es *extensionsState) extensionsRemovePeer(id peer.ID) {
160+
if es.myExtensions.PartialMessages && es.peerExtensions[id].PartialMessages {
161+
es.partialMessagesExtension.RemovePeer(id)
162+
}
139163
}
140164

141165
func (es *extensionsState) extensionsHandleRPC(rpc *RPC) {
142166
if es.myExtensions.TestExtension && es.peerExtensions[rpc.from].TestExtension {
143167
es.testExtension.HandleRPC(rpc.from, rpc.TestExtension)
144168
}
169+
170+
if es.myExtensions.PartialMessages && es.peerExtensions[rpc.from].PartialMessages && rpc.Partial != nil {
171+
es.partialMessagesExtension.HandleRPC(rpc.from, rpc.Partial)
172+
}
173+
}
174+
175+
func (es *extensionsState) Heartbeat() {
176+
if es.myExtensions.PartialMessages {
177+
es.partialMessagesExtension.Heartbeat()
178+
}
179+
}
180+
181+
func WithPartialMessagesExtension(pm *partialmessages.PartialMessageExtension) Option {
182+
return func(ps *PubSub) error {
183+
gs, ok := ps.rt.(*GossipSubRouter)
184+
if !ok {
185+
return errors.New("pubsub router is not gossipsub")
186+
}
187+
err := pm.Init(routerForPartialMessage{gs})
188+
if err != nil {
189+
return err
190+
}
191+
192+
gs.extensions.myExtensions.PartialMessages = true
193+
gs.extensions.partialMessagesExtension = pm
194+
return nil
195+
}
196+
}
197+
198+
type routerForPartialMessage struct {
199+
gs *GossipSubRouter
200+
}
201+
202+
// MeshPeers implements partialmessages.Router.
203+
func (r routerForPartialMessage) MeshPeers(topic string) iter.Seq[peer.ID] {
204+
return func(yield func(peer.ID) bool) {
205+
for peer := range r.gs.mesh[topic] {
206+
if exts := r.gs.extensions.peerExtensions[peer]; exts.PartialMessages {
207+
if !yield(peer) {
208+
return
209+
}
210+
}
211+
}
212+
}
213+
}
214+
215+
// SendRPC implements partialmessages.Router.
216+
func (r routerForPartialMessage) SendRPC(p peer.ID, rpc *pubsub_pb.PartialMessagesExtension, urgent bool) {
217+
r.gs.sendRPC(p, &RPC{
218+
RPC: pubsub_pb.RPC{
219+
Partial: rpc,
220+
},
221+
}, urgent)
145222
}
223+
224+
var _ partialmessages.Router = routerForPartialMessage{}

gossipsub.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1833,6 +1833,8 @@ func (gs *GossipSubRouter) heartbeat() {
18331833

18341834
// advance the message history window
18351835
gs.mcache.Shift()
1836+
1837+
gs.extensions.Heartbeat()
18361838
}
18371839

18381840
func (gs *GossipSubRouter) clearIHaveCounters() {

0 commit comments

Comments
 (0)