77 "google.golang.org/grpc/balancer"
88 "google.golang.org/grpc/codes"
99 "google.golang.org/grpc/connectivity"
10+ "google.golang.org/grpc/experimental/stats"
1011 "google.golang.org/grpc/resolver"
1112 "google.golang.org/grpc/status"
1213 gc "gopkg.in/check.v1"
@@ -43,8 +44,8 @@ func (s *DispatcherSuite) TestContextAdapters(c *gc.C) {
4344}
4445
4546func (s * DispatcherSuite ) TestDispatchCases (c * gc.C ) {
46- var cc mockClientConn
47- var disp = dispatcherBuilder {zone : "local" }.Build (& cc , balancer.BuildOptions {}).(* dispatcher )
47+ var cc = newMockClientConn ()
48+ var disp = dispatcherBuilder {zone : "local" }.Build (cc , balancer.BuildOptions {}).(* dispatcher )
4849 cc .disp = disp
4950 close (disp .sweepDoneCh ) // Disable async sweeping.
5051
@@ -68,7 +69,7 @@ func (s *DispatcherSuite) TestDispatchCases(c *gc.C) {
6869 result , err := disp .Pick (balancer.PickInfo {Ctx : ctx })
6970 c .Check (err , gc .IsNil )
7071 c .Check (result .Done , gc .IsNil )
71- c .Check (result .SubConn , gc .Equals , mockSubConn { Name : "default.addr:80" , disp : disp } )
72+ c .Check (result .SubConn .( * testSubConnWrapper ). name , gc .Equals , "default.addr:80" )
7273
7374 // Case: Specific remote peer is dispatched to.
7475 ctx = WithDispatchRoute (context .Background (),
@@ -84,7 +85,7 @@ func (s *DispatcherSuite) TestDispatchCases(c *gc.C) {
8485 result , err = disp .Pick (balancer.PickInfo {Ctx : ctx })
8586 c .Check (err , gc .IsNil )
8687 c .Check (result .Done , gc .IsNil )
87- c .Check (result .SubConn , gc .Equals , mockSubConn { Name : "remote.addr:80" , disp : disp } )
88+ c .Check (result .SubConn .( * testSubConnWrapper ). name , gc .Equals , "remote.addr:80" )
8889
8990 // Case: Route allows for multiple members. A local one is now dialed.
9091 ctx = WithDispatchRoute (context .Background (), buildRouteFixture (), ProcessSpec_ID {})
@@ -99,7 +100,7 @@ func (s *DispatcherSuite) TestDispatchCases(c *gc.C) {
99100 result , err = disp .Pick (balancer.PickInfo {Ctx : ctx })
100101 c .Check (err , gc .IsNil )
101102 c .Check (result .Done , gc .IsNil )
102- c .Check (result .SubConn , gc .Equals , mockSubConn { Name : "local.addr:80" , disp : disp } )
103+ c .Check (result .SubConn .( * testSubConnWrapper ). name , gc .Equals , "local.addr:80" )
103104
104105 // Case: One local addr is marked as failed. Another is dialed.
105106 mockSubConn {Name : "local.addr:80" , disp : disp }.UpdateState (balancer.SubConnState {ConnectivityState : connectivity .TransientFailure })
@@ -114,7 +115,7 @@ func (s *DispatcherSuite) TestDispatchCases(c *gc.C) {
114115 result , err = disp .Pick (balancer.PickInfo {Ctx : ctx })
115116 c .Check (err , gc .IsNil )
116117 c .Check (result .Done , gc .IsNil )
117- c .Check (result .SubConn , gc .Equals , mockSubConn { Name : "local.otherAddr:80" , disp : disp } )
118+ c .Check (result .SubConn .( * testSubConnWrapper ). name , gc .Equals , "local.otherAddr:80" )
118119
119120 // Case: otherAddr is also failed. Expect that an error is returned,
120121 // rather than dispatch to remote addr. (Eg we prefer to wait for a
@@ -151,7 +152,7 @@ func (s *DispatcherSuite) TestDispatchCases(c *gc.C) {
151152 result , err = disp .Pick (balancer.PickInfo {Ctx : ctx })
152153 c .Check (err , gc .IsNil )
153154 c .Check (result .Done , gc .NotNil )
154- c .Check (result .SubConn , gc .Equals , mockSubConn { Name : "local.addr:80" , disp : disp } )
155+ c .Check (result .SubConn .( * testSubConnWrapper ). name , gc .Equals , "local.addr:80" )
155156
156157 // Closure callback with an Unavailable error (only) will trigger an invalidation.
157158 result .Done (balancer.DoneInfo {Err : nil })
@@ -163,8 +164,8 @@ func (s *DispatcherSuite) TestDispatchCases(c *gc.C) {
163164}
164165
165166func (s * DispatcherSuite ) TestDispatchMarkAndSweep (c * gc.C ) {
166- var cc mockClientConn
167- var disp = dispatcherBuilder {zone : "local" }.Build (& cc , balancer.BuildOptions {}).(* dispatcher )
167+ var cc = newMockClientConn ()
168+ var disp = dispatcherBuilder {zone : "local" }.Build (cc , balancer.BuildOptions {}).(* dispatcher )
168169 cc .disp = disp
169170 defer disp .Close ()
170171
@@ -233,45 +234,103 @@ func (s *DispatcherSuite) TestDispatchMarkAndSweep(c *gc.C) {
233234 c .Check (err , gc .IsNil )
234235}
235236
236- type mockClientConn struct {
237- err error
238- created [] mockSubConn
239- removed [] mockSubConn
240- disp * dispatcher
237+ // testSubConnWrapper wraps a test SubConn to track operations
238+ type testSubConnWrapper struct {
239+ balancer. SubConn
240+ name string
241+ disp * dispatcher
241242}
242243
244+ // mockSubConn represents a test SubConn for comparisons and state updates
243245type mockSubConn struct {
244246 Name string
245247 disp * dispatcher
246248}
247249
248- func (s1 mockSubConn ) Equal (s2 mockSubConn ) bool {
249- return s1 .Name == s2 .Name
250+ // mockClientConn implements balancer.ClientConn for testing
251+ type mockClientConn struct {
252+ balancer.ClientConn
253+ err error
254+ created []mockSubConn
255+ removed []mockSubConn
256+ disp * dispatcher
257+ subConns map [string ]* testSubConnWrapper
258+ target string
250259}
251260
252- func (s mockSubConn ) UpdateAddresses ([]resolver.Address ) { panic ("deprecated" ) }
253- func (s mockSubConn ) UpdateState (state balancer.SubConnState ) { s .disp .updateSubConnState (s , state ) }
254- func (s mockSubConn ) Connect () {}
255- func (s mockSubConn ) GetOrBuildProducer (balancer.ProducerBuilder ) (balancer.Producer , func ()) {
256- return nil , func () {}
257- }
258- func (s mockSubConn ) Shutdown () {
259- var c = s .disp .cc .(* mockClientConn )
260- c .removed = append (c .removed , s )
261+ func newMockClientConn () * mockClientConn {
262+ return & mockClientConn {
263+ subConns : make (map [string ]* testSubConnWrapper ),
264+ target : "default.addr:80" , // Default target for tests
265+ }
261266}
262267
263- func (c * mockClientConn ) NewSubConn (a []resolver.Address , _ balancer.NewSubConnOptions ) (balancer.SubConn , error ) {
264- var sc = mockSubConn {Name : a [0 ].Addr , disp : c .disp }
265- c .created = append (c .created , sc )
266- return sc , c .err
268+ func (c * mockClientConn ) NewSubConn (a []resolver.Address , opts balancer.NewSubConnOptions ) (balancer.SubConn , error ) {
269+ if c .err != nil {
270+ return nil , c .err
271+ }
272+
273+ name := a [0 ].Addr
274+ sc := & testSubConnWrapper {
275+ name : name ,
276+ disp : c .disp ,
277+ }
278+
279+ c .subConns [name ] = sc
280+ c .created = append (c .created , mockSubConn {Name : name , disp : c .disp })
281+
282+ // StateListener is handled by the gRPC framework
283+
284+ return sc , nil
267285}
268286
269- func (c * mockClientConn ) UpdateAddresses (balancer.SubConn , []resolver.Address ) { panic ("deprecated" ) }
270- func (c * mockClientConn ) UpdateState (balancer.State ) {}
271- func (c * mockClientConn ) ResolveNow (resolver.ResolveNowOptions ) {}
272- func (c * mockClientConn ) Target () string { return "default.addr:80" }
287+ func (c * mockClientConn ) UpdateState (state balancer.State ) {}
288+
289+ func (c * mockClientConn ) ResolveNow (resolver.ResolveNowOptions ) {}
290+
291+ func (c * mockClientConn ) Target () string { return c .target }
292+
273293func (c * mockClientConn ) RemoveSubConn (sc balancer.SubConn ) {
274- sc .Shutdown ()
294+ if tsc , ok := sc .(* testSubConnWrapper ); ok {
295+ c .removed = append (c .removed , mockSubConn {Name : tsc .name , disp : tsc .disp })
296+ delete (c .subConns , tsc .name )
297+ }
298+ }
299+
300+ func (c * mockClientConn ) MetricsRecorder () stats.MetricsRecorder { return nil }
301+
302+ // Additional fields for testSubConnWrapper
303+ var _ balancer.SubConn = (* testSubConnWrapper )(nil )
304+
305+ func (s * testSubConnWrapper ) UpdateAddresses ([]resolver.Address ) { panic ("deprecated" ) }
306+
307+ func (s * testSubConnWrapper ) UpdateState (state balancer.SubConnState ) {
308+ if s .disp != nil {
309+ s .disp .updateSubConnState (s , state )
310+ }
311+ }
312+
313+ func (s * testSubConnWrapper ) Connect () {}
314+
315+ func (s * testSubConnWrapper ) GetOrBuildProducer (balancer.ProducerBuilder ) (balancer.Producer , func ()) {
316+ return nil , func () {}
317+ }
318+
319+ func (s * testSubConnWrapper ) Shutdown () {
320+ if cc , ok := s .disp .cc .(* mockClientConn ); ok {
321+ cc .removed = append (cc .removed , mockSubConn {Name : s .name , disp : s .disp })
322+ }
323+ }
324+
325+ func (s * testSubConnWrapper ) RegisterHealthListener (func (balancer.SubConnState )) {}
326+
327+ // Helper to create mockSubConn for UpdateState calls
328+ func (m mockSubConn ) UpdateState (state balancer.SubConnState ) {
329+ if cc , ok := m .disp .cc .(* mockClientConn ); ok {
330+ if sc , found := cc .subConns [m .Name ]; found {
331+ sc .UpdateState (state )
332+ }
333+ }
275334}
276335
277336type mockRouter struct { invalidated string }
0 commit comments