Skip to content

Commit ab5a8c2

Browse files
Add localAddress support to stream.connect (#362)
* Add `localAddress` support to `stream.connect` * fix windows * TransportAddress() instead of AnyAddress * tweak flags * Better flags * try to workaround nim 1.2 issue * Handle ReusePort in createStreamServer and improve tests * Rename ClientFlags to SocketFlags --------- Co-authored-by: Diego <[email protected]>
1 parent 229de5f commit ab5a8c2

File tree

3 files changed

+131
-6
lines changed

3 files changed

+131
-6
lines changed

chronos/transports/stream.nim

Lines changed: 87 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,12 @@ type
5050
# get stuck on transport `close()`.
5151
# Please use this flag only if you are making both client and server in
5252
# the same thread.
53-
TcpNoDelay
53+
TcpNoDelay # deprecated: Use SocketFlags.TcpNoDelay
54+
55+
SocketFlags* {.pure.} = enum
56+
TcpNoDelay,
57+
ReuseAddr,
58+
ReusePort
5459

5560

5661
StreamTransportTracker* = ref object of TrackerBase
@@ -699,7 +704,9 @@ when defined(windows):
699704
proc connect*(address: TransportAddress,
700705
bufferSize = DefaultStreamBufferSize,
701706
child: StreamTransport = nil,
702-
flags: set[TransportFlags] = {}): Future[StreamTransport] =
707+
localAddress = TransportAddress(),
708+
flags: set[SocketFlags] = {},
709+
): Future[StreamTransport] =
703710
## Open new connection to remote peer with address ``address`` and create
704711
## new transport object ``StreamTransport`` for established connection.
705712
## ``bufferSize`` is size of internal buffer for transport.
@@ -724,7 +731,35 @@ when defined(windows):
724731
retFuture.fail(getTransportOsError(osLastError()))
725732
return retFuture
726733

727-
if not(bindToDomain(sock, raddress.getDomain())):
734+
if SocketFlags.ReuseAddr in flags:
735+
if not(setSockOpt(sock, SOL_SOCKET, SO_REUSEADDR, 1)):
736+
let err = osLastError()
737+
sock.closeSocket()
738+
retFuture.fail(getTransportOsError(err))
739+
return retFuture
740+
if SocketFlags.ReusePort in flags:
741+
if not(setSockOpt(sock, SOL_SOCKET, SO_REUSEPORT, 1)):
742+
let err = osLastError()
743+
sock.closeSocket()
744+
retFuture.fail(getTransportOsError(err))
745+
return retFuture
746+
747+
if localAddress != TransportAddress():
748+
if localAddress.family != address.family:
749+
sock.closeSocket()
750+
retFuture.fail(newException(TransportOsError,
751+
"connect local address domain is not equal to target address domain"))
752+
return retFuture
753+
var
754+
localAddr: Sockaddr_storage
755+
localAddrLen: SockLen
756+
localAddress.toSAddr(localAddr, localAddrLen)
757+
if bindSocket(SocketHandle(sock),
758+
cast[ptr SockAddr](addr localAddr), localAddrLen) != 0:
759+
sock.closeSocket()
760+
retFuture.fail(getTransportOsError(osLastError()))
761+
return retFuture
762+
elif not(bindToDomain(sock, raddress.getDomain())):
728763
let err = wsaGetLastError()
729764
sock.closeSocket()
730765
retFuture.fail(getTransportOsError(err))
@@ -1496,7 +1531,9 @@ else:
14961531
proc connect*(address: TransportAddress,
14971532
bufferSize = DefaultStreamBufferSize,
14981533
child: StreamTransport = nil,
1499-
flags: set[TransportFlags] = {}): Future[StreamTransport] =
1534+
localAddress = TransportAddress(),
1535+
flags: set[SocketFlags] = {},
1536+
): Future[StreamTransport] =
15001537
## Open new connection to remote peer with address ``address`` and create
15011538
## new transport object ``StreamTransport`` for established connection.
15021539
## ``bufferSize`` - size of internal buffer for transport.
@@ -1523,12 +1560,40 @@ else:
15231560
return retFuture
15241561

15251562
if address.family in {AddressFamily.IPv4, AddressFamily.IPv6}:
1526-
if TransportFlags.TcpNoDelay in flags:
1563+
if SocketFlags.TcpNoDelay in flags:
15271564
if not(setSockOpt(sock, osdefs.IPPROTO_TCP, osdefs.TCP_NODELAY, 1)):
15281565
let err = osLastError()
15291566
sock.closeSocket()
15301567
retFuture.fail(getTransportOsError(err))
15311568
return retFuture
1569+
if SocketFlags.ReuseAddr in flags:
1570+
if not(setSockOpt(sock, SOL_SOCKET, SO_REUSEADDR, 1)):
1571+
let err = osLastError()
1572+
sock.closeSocket()
1573+
retFuture.fail(getTransportOsError(err))
1574+
return retFuture
1575+
if SocketFlags.ReusePort in flags:
1576+
if not(setSockOpt(sock, SOL_SOCKET, SO_REUSEPORT, 1)):
1577+
let err = osLastError()
1578+
sock.closeSocket()
1579+
retFuture.fail(getTransportOsError(err))
1580+
return retFuture
1581+
1582+
if localAddress != TransportAddress():
1583+
if localAddress.family != address.family:
1584+
sock.closeSocket()
1585+
retFuture.fail(newException(TransportOsError,
1586+
"connect local address domain is not equal to target address domain"))
1587+
return retFuture
1588+
var
1589+
localAddr: Sockaddr_storage
1590+
localAddrLen: SockLen
1591+
localAddress.toSAddr(localAddr, localAddrLen)
1592+
if bindSocket(SocketHandle(sock),
1593+
cast[ptr SockAddr](addr localAddr), localAddrLen) != 0:
1594+
sock.closeSocket()
1595+
retFuture.fail(getTransportOsError(osLastError()))
1596+
return retFuture
15321597

15331598
proc continuation(udata: pointer) =
15341599
if not(retFuture.finished()):
@@ -1776,6 +1841,16 @@ proc join*(server: StreamServer): Future[void] =
17761841
retFuture.complete()
17771842
return retFuture
17781843

1844+
proc connect*(address: TransportAddress,
1845+
bufferSize = DefaultStreamBufferSize,
1846+
child: StreamTransport = nil,
1847+
flags: set[TransportFlags],
1848+
localAddress = TransportAddress()): Future[StreamTransport] =
1849+
# Retro compatibility with TransportFlags
1850+
var mappedFlags: set[SocketFlags]
1851+
if TcpNoDelay in flags: mappedFlags.incl(SocketFlags.TcpNoDelay)
1852+
address.connect(bufferSize, child, localAddress, mappedFlags)
1853+
17791854
proc close*(server: StreamServer) =
17801855
## Release ``server`` resources.
17811856
##
@@ -1864,6 +1939,13 @@ proc createStreamServer*(host: TransportAddress,
18641939
if sock == asyncInvalidSocket:
18651940
discard closeFd(SocketHandle(serverSocket))
18661941
raiseTransportOsError(err)
1942+
if ServerFlags.ReusePort in flags:
1943+
if not(setSockOpt(serverSocket, osdefs.SOL_SOCKET,
1944+
osdefs.SO_REUSEPORT, 1)):
1945+
let err = osLastError()
1946+
if sock == asyncInvalidSocket:
1947+
discard closeFd(SocketHandle(serverSocket))
1948+
raiseTransportOsError(err)
18671949
# TCP flags are not useful for Unix domain sockets.
18681950
if ServerFlags.TcpNoDelay in flags:
18691951
if not(setSockOpt(serverSocket, osdefs.IPPROTO_TCP,

tests/testasyncstream.nim

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -958,7 +958,7 @@ suite "TLSStream test suite":
958958
key = TLSPrivateKey.init(pemkey)
959959
cert = TLSCertificate.init(pemcert)
960960

961-
var server = createStreamServer(address, serveClient, {ReuseAddr})
961+
var server = createStreamServer(address, serveClient, {ServerFlags.ReuseAddr})
962962
server.start()
963963
var conn = await connect(address)
964964
var creader = newAsyncStreamReader(conn)

tests/teststream.nim

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1259,6 +1259,47 @@ suite "Stream Transport test suite":
12591259
await allFutures(rtransp.closeWait(), wtransp.closeWait())
12601260
return buffer == message
12611261

1262+
proc testConnectBindLocalAddress() {.async.} =
1263+
let dst1 = initTAddress("127.0.0.1:33335")
1264+
let dst2 = initTAddress("127.0.0.1:33336")
1265+
let dst3 = initTAddress("127.0.0.1:33337")
1266+
1267+
proc client(server: StreamServer, transp: StreamTransport) {.async.} =
1268+
await transp.closeWait()
1269+
1270+
# We use ReuseAddr here only to be able to reuse the same IP/Port when there's a TIME_WAIT socket. It's useful when
1271+
# running the test multiple times or if a test ran previously used the same port.
1272+
let servers =
1273+
[createStreamServer(dst1, client, {ReuseAddr}),
1274+
createStreamServer(dst2, client, {ReuseAddr}),
1275+
createStreamServer(dst3, client, {ReusePort})]
1276+
1277+
for server in servers:
1278+
server.start()
1279+
1280+
let ta = initTAddress("0.0.0.0:35000")
1281+
1282+
# It works cause there's no active listening socket bound to ta and we are using ReuseAddr
1283+
var transp1 = await connect(dst1, localAddress = ta, flags={SocketFlags.ReuseAddr})
1284+
var transp2 = await connect(dst2, localAddress = ta, flags={SocketFlags.ReuseAddr})
1285+
1286+
# It works cause even thought there's an active listening socket bound to dst3, we are using ReusePort
1287+
var transp3 = await connect(dst2, localAddress = dst3, flags={SocketFlags.ReusePort})
1288+
1289+
expect(TransportOsError):
1290+
var transp2 = await connect(dst3, localAddress = ta)
1291+
1292+
expect(TransportOsError):
1293+
var transp3 = await connect(dst3, localAddress = initTAddress(":::35000"))
1294+
1295+
await transp1.closeWait()
1296+
await transp2.closeWait()
1297+
await transp3.closeWait()
1298+
1299+
for server in servers:
1300+
server.stop()
1301+
await server.closeWait()
1302+
12621303
markFD = getCurrentFD()
12631304

12641305
for i in 0..<len(addresses):
@@ -1346,6 +1387,8 @@ suite "Stream Transport test suite":
13461387
check waitFor(testReadOnClose(addresses[i])) == true
13471388
test "[PIPE] readExactly()/write() test":
13481389
check waitFor(testPipe()) == true
1390+
test "[IP] bind connect to local address":
1391+
waitFor(testConnectBindLocalAddress())
13491392
test "Servers leak test":
13501393
check getTracker("stream.server").isLeaked() == false
13511394
test "Transports leak test":

0 commit comments

Comments
 (0)