11use crate :: p2:: bindings:: sockets:: network:: { ErrorCode , IpAddressFamily , IpSocketAddress , Network } ;
22use crate :: p2:: bindings:: sockets:: udp;
3- use crate :: p2:: udp:: { IncomingDatagramStream , OutgoingDatagramStream , SendState , UdpState } ;
3+ use crate :: p2:: udp:: { IncomingDatagramStream , OutgoingDatagramStream , SendState } ;
44use crate :: p2:: { Pollable , SocketError , SocketResult } ;
5- use crate :: sockets:: util:: {
6- get_ip_ttl, get_ipv6_unicast_hops, is_valid_address_family, is_valid_remote_address,
7- receive_buffer_size, send_buffer_size, set_receive_buffer_size, set_send_buffer_size,
8- set_unicast_hop_limit, udp_bind, udp_disconnect,
9- } ;
5+ use crate :: sockets:: util:: { is_valid_address_family, is_valid_remote_address} ;
106use crate :: sockets:: {
11- MAX_UDP_DATAGRAM_SIZE , SocketAddrUse , SocketAddressFamily , WasiSocketsCtxView ,
7+ MAX_UDP_DATAGRAM_SIZE , SocketAddrUse , SocketAddressFamily , UdpSocket , WasiSocketsCtxView ,
128} ;
139use anyhow:: anyhow;
1410use async_trait:: async_trait;
15- use io_lifetimes:: AsSocketlike ;
16- use rustix:: io:: Errno ;
1711use std:: net:: SocketAddr ;
1812use tokio:: io:: Interest ;
1913use wasmtime:: component:: Resource ;
@@ -28,51 +22,20 @@ impl udp::HostUdpSocket for WasiSocketsCtxView<'_> {
2822 network : Resource < Network > ,
2923 local_address : IpSocketAddress ,
3024 ) -> SocketResult < ( ) > {
31- self . ctx . allowed_network_uses . check_allowed_udp ( ) ?;
32-
33- match self . table . get ( & this) ?. udp_state {
34- UdpState :: Default => { }
35- UdpState :: BindStarted => return Err ( ErrorCode :: ConcurrencyConflict . into ( ) ) ,
36- UdpState :: Bound | UdpState :: Connected => return Err ( ErrorCode :: InvalidState . into ( ) ) ,
37- }
38-
39- // Set the socket addr check on the socket so later functions have access to it through the socket handle
25+ let local_address = SocketAddr :: from ( local_address) ;
4026 let check = self . table . get ( & network) ?. socket_addr_check . clone ( ) ;
41- self . table
42- . get_mut ( & this) ?
43- . socket_addr_check
44- . replace ( check. clone ( ) ) ;
45-
46- let socket = self . table . get ( & this) ?;
47- let local_address: SocketAddr = local_address. into ( ) ;
48-
49- if !is_valid_address_family ( local_address. ip ( ) , socket. family ) {
50- return Err ( ErrorCode :: InvalidArgument . into ( ) ) ;
51- }
52-
53- {
54- check. check ( local_address, SocketAddrUse :: UdpBind ) . await ?;
55-
56- // Perform the OS bind call.
57- udp_bind ( socket. udp_socket ( ) , local_address) ?;
58- }
27+ check. check ( local_address, SocketAddrUse :: UdpBind ) . await ?;
5928
6029 let socket = self . table . get_mut ( & this) ?;
61- socket. udp_state = UdpState :: BindStarted ;
30+ socket. bind ( local_address) ?;
31+ socket. set_socket_addr_check ( Some ( check) ) ;
6232
6333 Ok ( ( ) )
6434 }
6535
6636 fn finish_bind ( & mut self , this : Resource < udp:: UdpSocket > ) -> SocketResult < ( ) > {
67- let socket = self . table . get_mut ( & this) ?;
68-
69- match socket. udp_state {
70- UdpState :: BindStarted => {
71- socket. udp_state = UdpState :: Bound ;
72- Ok ( ( ) )
73- }
74- _ => Err ( ErrorCode :: NotInProgress . into ( ) ) ,
75- }
37+ self . table . get_mut ( & this) ?. finish_bind ( ) ?;
38+ Ok ( ( ) )
7639 }
7740
7841 async fn stream (
@@ -95,9 +58,8 @@ impl udp::HostUdpSocket for WasiSocketsCtxView<'_> {
9558 let socket = self . table . get_mut ( & this) ?;
9659 let remote_address = remote_address. map ( SocketAddr :: from) ;
9760
98- match socket. udp_state {
99- UdpState :: Bound | UdpState :: Connected => { }
100- _ => return Err ( ErrorCode :: InvalidState . into ( ) ) ,
61+ if !socket. is_bound ( ) {
62+ return Err ( ErrorCode :: InvalidState . into ( ) ) ;
10163 }
10264
10365 // We disconnect & (re)connect in two distinct steps for two reasons:
@@ -107,48 +69,29 @@ impl udp::HostUdpSocket for WasiSocketsCtxView<'_> {
10769 // if there isn't a disconnect in between.
10870
10971 // Step #1: Disconnect
110- if let UdpState :: Connected = socket. udp_state {
111- udp_disconnect ( socket. udp_socket ( ) ) ?;
112- socket. udp_state = UdpState :: Bound ;
72+ if socket. is_connected ( ) {
73+ socket. disconnect ( ) ?;
11374 }
11475
11576 // Step #2: (Re)connect
11677 if let Some ( connect_addr) = remote_address {
117- let Some ( check) = socket. socket_addr_check . as_ref ( ) else {
78+ let Some ( check) = socket. socket_addr_check ( ) else {
11879 return Err ( ErrorCode :: InvalidState . into ( ) ) ;
11980 } ;
120- if !is_valid_remote_address ( connect_addr)
121- || !is_valid_address_family ( connect_addr. ip ( ) , socket. family )
122- {
123- return Err ( ErrorCode :: InvalidArgument . into ( ) ) ;
124- }
12581 check. check ( connect_addr, SocketAddrUse :: UdpConnect ) . await ?;
126-
127- rustix:: net:: connect ( socket. udp_socket ( ) , & connect_addr) . map_err (
128- |error| match error {
129- Errno :: AFNOSUPPORT => ErrorCode :: InvalidArgument , // See `bind` implementation.
130- Errno :: INPROGRESS => {
131- tracing:: debug!(
132- "UDP connect returned EINPROGRESS, which should never happen"
133- ) ;
134- ErrorCode :: Unknown
135- }
136- _ => ErrorCode :: from ( error) ,
137- } ,
138- ) ?;
139- socket. udp_state = UdpState :: Connected ;
82+ socket. connect ( connect_addr) ?;
14083 }
14184
14285 let incoming_stream = IncomingDatagramStream {
143- inner : socket. inner . clone ( ) ,
86+ inner : socket. socket ( ) . clone ( ) ,
14487 remote_address,
14588 } ;
14689 let outgoing_stream = OutgoingDatagramStream {
147- inner : socket. inner . clone ( ) ,
90+ inner : socket. socket ( ) . clone ( ) ,
14891 remote_address,
149- family : socket. family ,
92+ family : socket. address_family ( ) ,
15093 send_state : SendState :: Idle ,
151- socket_addr_check : socket. socket_addr_check . clone ( ) ,
94+ socket_addr_check : socket. socket_addr_check ( ) . cloned ( ) ,
15295 } ;
15396
15497 Ok ( (
@@ -159,56 +102,25 @@ impl udp::HostUdpSocket for WasiSocketsCtxView<'_> {
159102
160103 fn local_address ( & mut self , this : Resource < udp:: UdpSocket > ) -> SocketResult < IpSocketAddress > {
161104 let socket = self . table . get ( & this) ?;
162-
163- match socket. udp_state {
164- UdpState :: Default => return Err ( ErrorCode :: InvalidState . into ( ) ) ,
165- UdpState :: BindStarted => return Err ( ErrorCode :: ConcurrencyConflict . into ( ) ) ,
166- _ => { }
167- }
168-
169- let addr = socket
170- . udp_socket ( )
171- . as_socketlike_view :: < std:: net:: UdpSocket > ( )
172- . local_addr ( ) ?;
173- Ok ( addr. into ( ) )
105+ Ok ( socket. local_address ( ) ?. into ( ) )
174106 }
175107
176108 fn remote_address ( & mut self , this : Resource < udp:: UdpSocket > ) -> SocketResult < IpSocketAddress > {
177109 let socket = self . table . get ( & this) ?;
178-
179- match socket. udp_state {
180- UdpState :: Connected => { }
181- _ => return Err ( ErrorCode :: InvalidState . into ( ) ) ,
182- }
183-
184- let addr = socket
185- . udp_socket ( )
186- . as_socketlike_view :: < std:: net:: UdpSocket > ( )
187- . peer_addr ( ) ?;
188- Ok ( addr. into ( ) )
110+ Ok ( socket. remote_address ( ) ?. into ( ) )
189111 }
190112
191113 fn address_family (
192114 & mut self ,
193115 this : Resource < udp:: UdpSocket > ,
194116 ) -> Result < IpAddressFamily , anyhow:: Error > {
195117 let socket = self . table . get ( & this) ?;
196-
197- match socket. family {
198- SocketAddressFamily :: Ipv4 => Ok ( IpAddressFamily :: Ipv4 ) ,
199- SocketAddressFamily :: Ipv6 => Ok ( IpAddressFamily :: Ipv6 ) ,
200- }
118+ Ok ( socket. address_family ( ) . into ( ) )
201119 }
202120
203121 fn unicast_hop_limit ( & mut self , this : Resource < udp:: UdpSocket > ) -> SocketResult < u8 > {
204122 let socket = self . table . get ( & this) ?;
205-
206- let ttl = match socket. family {
207- SocketAddressFamily :: Ipv4 => get_ip_ttl ( socket. udp_socket ( ) ) ?,
208- SocketAddressFamily :: Ipv6 => get_ipv6_unicast_hops ( socket. udp_socket ( ) ) ?,
209- } ;
210-
211- Ok ( ttl)
123+ Ok ( socket. unicast_hop_limit ( ) ?)
212124 }
213125
214126 fn set_unicast_hop_limit (
@@ -217,17 +129,13 @@ impl udp::HostUdpSocket for WasiSocketsCtxView<'_> {
217129 value : u8 ,
218130 ) -> SocketResult < ( ) > {
219131 let socket = self . table . get ( & this) ?;
220-
221- set_unicast_hop_limit ( socket. udp_socket ( ) , socket. family , value) ?;
222-
132+ socket. set_unicast_hop_limit ( value) ?;
223133 Ok ( ( ) )
224134 }
225135
226136 fn receive_buffer_size ( & mut self , this : Resource < udp:: UdpSocket > ) -> SocketResult < u64 > {
227137 let socket = self . table . get ( & this) ?;
228-
229- let value = receive_buffer_size ( socket. udp_socket ( ) ) ?;
230- Ok ( value)
138+ Ok ( socket. receive_buffer_size ( ) ?)
231139 }
232140
233141 fn set_receive_buffer_size (
@@ -236,33 +144,22 @@ impl udp::HostUdpSocket for WasiSocketsCtxView<'_> {
236144 value : u64 ,
237145 ) -> SocketResult < ( ) > {
238146 let socket = self . table . get ( & this) ?;
239-
240- set_receive_buffer_size ( socket. udp_socket ( ) , value) ?;
147+ socket. set_receive_buffer_size ( value) ?;
241148 Ok ( ( ) )
242149 }
243150
244151 fn send_buffer_size ( & mut self , this : Resource < udp:: UdpSocket > ) -> SocketResult < u64 > {
245152 let socket = self . table . get ( & this) ?;
246-
247- let value = send_buffer_size ( socket. udp_socket ( ) ) ?;
248- Ok ( value)
153+ Ok ( socket. send_buffer_size ( ) ?)
249154 }
250155
251- fn set_send_buffer_size (
252- & mut self ,
253- this : Resource < udp:: UdpSocket > ,
254- value : u64 ,
255- ) -> SocketResult < ( ) > {
156+ fn set_send_buffer_size ( & mut self , this : Resource < UdpSocket > , value : u64 ) -> SocketResult < ( ) > {
256157 let socket = self . table . get ( & this) ?;
257-
258- set_send_buffer_size ( socket. udp_socket ( ) , value) ?;
158+ socket. set_send_buffer_size ( value) ?;
259159 Ok ( ( ) )
260160 }
261161
262- fn subscribe (
263- & mut self ,
264- this : Resource < udp:: UdpSocket > ,
265- ) -> anyhow:: Result < Resource < DynPollable > > {
162+ fn subscribe ( & mut self , this : Resource < UdpSocket > ) -> anyhow:: Result < Resource < DynPollable > > {
266163 wasmtime_wasi_io:: poll:: subscribe ( self . table , this)
267164 }
268165
@@ -276,6 +173,13 @@ impl udp::HostUdpSocket for WasiSocketsCtxView<'_> {
276173 }
277174}
278175
176+ #[ async_trait]
177+ impl Pollable for UdpSocket {
178+ async fn ready ( & mut self ) {
179+ // None of the socket-level operations block natively
180+ }
181+ }
182+
279183impl udp:: HostIncomingDatagramStream for WasiSocketsCtxView < ' _ > {
280184 fn receive (
281185 & mut self ,
@@ -504,6 +408,15 @@ impl Pollable for OutgoingDatagramStream {
504408 }
505409}
506410
411+ impl From < SocketAddressFamily > for IpAddressFamily {
412+ fn from ( family : SocketAddressFamily ) -> IpAddressFamily {
413+ match family {
414+ SocketAddressFamily :: Ipv4 => IpAddressFamily :: Ipv4 ,
415+ SocketAddressFamily :: Ipv6 => IpAddressFamily :: Ipv6 ,
416+ }
417+ }
418+ }
419+
507420pub mod sync {
508421 use wasmtime:: component:: Resource ;
509422
0 commit comments