Skip to content

Commit 00dd668

Browse files
authored
Use the same UdpSocket in WASIp{2,3} (#11384)
* Use the same `UdpSocket` in WASIp{2,3} This commit refactors the implementation of `wasi:sockets` for WASIp2 and WASIp3 to use the same underlying host data structure for the `UdpSocket` resource in WIT. Previously each version of WASI had its own socket which resulted in duplicated code. There's some minor differences between WASIp2 and WASIp3 but it's easy enough to paper over with the same socket type. This is intended to help with the maintainability of this going forward to only have one type to operate on rather than two (which also ensures that bugfixes for one should affect the other). One other change made in this commit is that sprinkled checks for whether or not UDP is allowed are all removed and canonicalized during UDP socket creation. This means that UDP socket creation is the only location that checks for whether UDP is allowed. Once a UDP socket is created it can be used freely regardless of whether the UDP setting is enabled or disabled. This is not intended to have a large practical effect but it does mean the behavior of hosts that deny UDP but manually give access to a UDP socket resource to a component may behave subtly differently. * Review comments * Fix p3-less warnings * Update UDP denial test * Fix some clippy issues * Fix no-udp test warnings
1 parent 4ac219f commit 00dd668

File tree

13 files changed

+176
-304
lines changed

13 files changed

+176
-304
lines changed
Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,11 @@
11
//! This test assumes that it will be run without udp support enabled
2-
use test_programs::wasi::sockets::{
3-
network::IpAddress,
4-
udp::{ErrorCode, IpAddressFamily, IpSocketAddress, Network, UdpSocket},
5-
};
62
7-
fn main() {
8-
let net = Network::default();
9-
let family = IpAddressFamily::Ipv4;
10-
let remote1 = IpSocketAddress::new(IpAddress::new_loopback(family), 4321);
11-
let sock = UdpSocket::new(family).unwrap();
3+
#![deny(warnings)]
4+
use test_programs::wasi::sockets::udp::{ErrorCode, IpAddressFamily, UdpSocket};
125

13-
let bind = sock.blocking_bind(&net, remote1);
14-
eprintln!("Result of binding: {bind:?}");
15-
assert!(matches!(bind, Err(ErrorCode::AccessDenied)));
6+
fn main() {
7+
assert!(matches!(
8+
UdpSocket::new(IpAddressFamily::Ipv4),
9+
Err(ErrorCode::AccessDenied)
10+
));
1611
}

crates/wasi/src/p2/bindings.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ pub mod sync {
173173
"wasi:sockets/tcp/tcp-socket": super::super::sockets::tcp::TcpSocket,
174174
"wasi:sockets/udp/incoming-datagram-stream": super::super::sockets::udp::IncomingDatagramStream,
175175
"wasi:sockets/udp/outgoing-datagram-stream": super::super::sockets::udp::OutgoingDatagramStream,
176-
"wasi:sockets/udp/udp-socket": super::super::sockets::udp::UdpSocket,
176+
"wasi:sockets/udp/udp-socket": crate::sockets::UdpSocket,
177177

178178
// Error host trait from wasmtime-wasi-io is synchronous, so we can alias it
179179
"wasi:io/error": wasmtime_wasi_io::bindings::wasi::io::error,
@@ -394,7 +394,7 @@ mod async_io {
394394
// this crate
395395
"wasi:sockets/network/network": crate::p2::network::Network,
396396
"wasi:sockets/tcp/tcp-socket": crate::p2::tcp::TcpSocket,
397-
"wasi:sockets/udp/udp-socket": crate::p2::udp::UdpSocket,
397+
"wasi:sockets/udp/udp-socket": crate::sockets::UdpSocket,
398398
"wasi:sockets/udp/incoming-datagram-stream": crate::p2::udp::IncomingDatagramStream,
399399
"wasi:sockets/udp/outgoing-datagram-stream": crate::p2::udp::OutgoingDatagramStream,
400400
"wasi:sockets/ip-name-lookup/resolve-address-stream": crate::p2::ip_name_lookup::ResolveAddressStream,

crates/wasi/src/p2/host/udp.rs

Lines changed: 46 additions & 133 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,13 @@
11
use crate::p2::bindings::sockets::network::{ErrorCode, IpAddressFamily, IpSocketAddress, Network};
22
use crate::p2::bindings::sockets::udp;
3-
use crate::p2::udp::{IncomingDatagramStream, OutgoingDatagramStream, SendState, UdpState};
3+
use crate::p2::udp::{IncomingDatagramStream, OutgoingDatagramStream, SendState};
44
use crate::p2::{Pollable, SocketError, SocketResult};
5-
use crate::sockets::util::{
6-
get_ip_ttl, get_ipv6_unicast_hops, is_valid_address_family, is_valid_remote_address,
7-
receive_buffer_size, send_buffer_size, set_receive_buffer_size, set_send_buffer_size,
8-
set_unicast_hop_limit, udp_bind, udp_disconnect,
9-
};
5+
use crate::sockets::util::{is_valid_address_family, is_valid_remote_address};
106
use crate::sockets::{
11-
MAX_UDP_DATAGRAM_SIZE, SocketAddrUse, SocketAddressFamily, WasiSocketsCtxView,
7+
MAX_UDP_DATAGRAM_SIZE, SocketAddrUse, SocketAddressFamily, UdpSocket, WasiSocketsCtxView,
128
};
139
use anyhow::anyhow;
1410
use async_trait::async_trait;
15-
use io_lifetimes::AsSocketlike;
16-
use rustix::io::Errno;
1711
use std::net::SocketAddr;
1812
use tokio::io::Interest;
1913
use wasmtime::component::Resource;
@@ -28,51 +22,20 @@ impl udp::HostUdpSocket for WasiSocketsCtxView<'_> {
2822
network: Resource<Network>,
2923
local_address: IpSocketAddress,
3024
) -> SocketResult<()> {
31-
self.ctx.allowed_network_uses.check_allowed_udp()?;
32-
33-
match self.table.get(&this)?.udp_state {
34-
UdpState::Default => {}
35-
UdpState::BindStarted => return Err(ErrorCode::ConcurrencyConflict.into()),
36-
UdpState::Bound | UdpState::Connected => return Err(ErrorCode::InvalidState.into()),
37-
}
38-
39-
// Set the socket addr check on the socket so later functions have access to it through the socket handle
25+
let local_address = SocketAddr::from(local_address);
4026
let check = self.table.get(&network)?.socket_addr_check.clone();
41-
self.table
42-
.get_mut(&this)?
43-
.socket_addr_check
44-
.replace(check.clone());
45-
46-
let socket = self.table.get(&this)?;
47-
let local_address: SocketAddr = local_address.into();
48-
49-
if !is_valid_address_family(local_address.ip(), socket.family) {
50-
return Err(ErrorCode::InvalidArgument.into());
51-
}
52-
53-
{
54-
check.check(local_address, SocketAddrUse::UdpBind).await?;
55-
56-
// Perform the OS bind call.
57-
udp_bind(socket.udp_socket(), local_address)?;
58-
}
27+
check.check(local_address, SocketAddrUse::UdpBind).await?;
5928

6029
let socket = self.table.get_mut(&this)?;
61-
socket.udp_state = UdpState::BindStarted;
30+
socket.bind(local_address)?;
31+
socket.set_socket_addr_check(Some(check));
6232

6333
Ok(())
6434
}
6535

6636
fn finish_bind(&mut self, this: Resource<udp::UdpSocket>) -> SocketResult<()> {
67-
let socket = self.table.get_mut(&this)?;
68-
69-
match socket.udp_state {
70-
UdpState::BindStarted => {
71-
socket.udp_state = UdpState::Bound;
72-
Ok(())
73-
}
74-
_ => Err(ErrorCode::NotInProgress.into()),
75-
}
37+
self.table.get_mut(&this)?.finish_bind()?;
38+
Ok(())
7639
}
7740

7841
async fn stream(
@@ -95,9 +58,8 @@ impl udp::HostUdpSocket for WasiSocketsCtxView<'_> {
9558
let socket = self.table.get_mut(&this)?;
9659
let remote_address = remote_address.map(SocketAddr::from);
9760

98-
match socket.udp_state {
99-
UdpState::Bound | UdpState::Connected => {}
100-
_ => return Err(ErrorCode::InvalidState.into()),
61+
if !socket.is_bound() {
62+
return Err(ErrorCode::InvalidState.into());
10163
}
10264

10365
// We disconnect & (re)connect in two distinct steps for two reasons:
@@ -107,48 +69,29 @@ impl udp::HostUdpSocket for WasiSocketsCtxView<'_> {
10769
// if there isn't a disconnect in between.
10870

10971
// Step #1: Disconnect
110-
if let UdpState::Connected = socket.udp_state {
111-
udp_disconnect(socket.udp_socket())?;
112-
socket.udp_state = UdpState::Bound;
72+
if socket.is_connected() {
73+
socket.disconnect()?;
11374
}
11475

11576
// Step #2: (Re)connect
11677
if let Some(connect_addr) = remote_address {
117-
let Some(check) = socket.socket_addr_check.as_ref() else {
78+
let Some(check) = socket.socket_addr_check() else {
11879
return Err(ErrorCode::InvalidState.into());
11980
};
120-
if !is_valid_remote_address(connect_addr)
121-
|| !is_valid_address_family(connect_addr.ip(), socket.family)
122-
{
123-
return Err(ErrorCode::InvalidArgument.into());
124-
}
12581
check.check(connect_addr, SocketAddrUse::UdpConnect).await?;
126-
127-
rustix::net::connect(socket.udp_socket(), &connect_addr).map_err(
128-
|error| match error {
129-
Errno::AFNOSUPPORT => ErrorCode::InvalidArgument, // See `bind` implementation.
130-
Errno::INPROGRESS => {
131-
tracing::debug!(
132-
"UDP connect returned EINPROGRESS, which should never happen"
133-
);
134-
ErrorCode::Unknown
135-
}
136-
_ => ErrorCode::from(error),
137-
},
138-
)?;
139-
socket.udp_state = UdpState::Connected;
82+
socket.connect(connect_addr)?;
14083
}
14184

14285
let incoming_stream = IncomingDatagramStream {
143-
inner: socket.inner.clone(),
86+
inner: socket.socket().clone(),
14487
remote_address,
14588
};
14689
let outgoing_stream = OutgoingDatagramStream {
147-
inner: socket.inner.clone(),
90+
inner: socket.socket().clone(),
14891
remote_address,
149-
family: socket.family,
92+
family: socket.address_family(),
15093
send_state: SendState::Idle,
151-
socket_addr_check: socket.socket_addr_check.clone(),
94+
socket_addr_check: socket.socket_addr_check().cloned(),
15295
};
15396

15497
Ok((
@@ -159,56 +102,25 @@ impl udp::HostUdpSocket for WasiSocketsCtxView<'_> {
159102

160103
fn local_address(&mut self, this: Resource<udp::UdpSocket>) -> SocketResult<IpSocketAddress> {
161104
let socket = self.table.get(&this)?;
162-
163-
match socket.udp_state {
164-
UdpState::Default => return Err(ErrorCode::InvalidState.into()),
165-
UdpState::BindStarted => return Err(ErrorCode::ConcurrencyConflict.into()),
166-
_ => {}
167-
}
168-
169-
let addr = socket
170-
.udp_socket()
171-
.as_socketlike_view::<std::net::UdpSocket>()
172-
.local_addr()?;
173-
Ok(addr.into())
105+
Ok(socket.local_address()?.into())
174106
}
175107

176108
fn remote_address(&mut self, this: Resource<udp::UdpSocket>) -> SocketResult<IpSocketAddress> {
177109
let socket = self.table.get(&this)?;
178-
179-
match socket.udp_state {
180-
UdpState::Connected => {}
181-
_ => return Err(ErrorCode::InvalidState.into()),
182-
}
183-
184-
let addr = socket
185-
.udp_socket()
186-
.as_socketlike_view::<std::net::UdpSocket>()
187-
.peer_addr()?;
188-
Ok(addr.into())
110+
Ok(socket.remote_address()?.into())
189111
}
190112

191113
fn address_family(
192114
&mut self,
193115
this: Resource<udp::UdpSocket>,
194116
) -> Result<IpAddressFamily, anyhow::Error> {
195117
let socket = self.table.get(&this)?;
196-
197-
match socket.family {
198-
SocketAddressFamily::Ipv4 => Ok(IpAddressFamily::Ipv4),
199-
SocketAddressFamily::Ipv6 => Ok(IpAddressFamily::Ipv6),
200-
}
118+
Ok(socket.address_family().into())
201119
}
202120

203121
fn unicast_hop_limit(&mut self, this: Resource<udp::UdpSocket>) -> SocketResult<u8> {
204122
let socket = self.table.get(&this)?;
205-
206-
let ttl = match socket.family {
207-
SocketAddressFamily::Ipv4 => get_ip_ttl(socket.udp_socket())?,
208-
SocketAddressFamily::Ipv6 => get_ipv6_unicast_hops(socket.udp_socket())?,
209-
};
210-
211-
Ok(ttl)
123+
Ok(socket.unicast_hop_limit()?)
212124
}
213125

214126
fn set_unicast_hop_limit(
@@ -217,17 +129,13 @@ impl udp::HostUdpSocket for WasiSocketsCtxView<'_> {
217129
value: u8,
218130
) -> SocketResult<()> {
219131
let socket = self.table.get(&this)?;
220-
221-
set_unicast_hop_limit(socket.udp_socket(), socket.family, value)?;
222-
132+
socket.set_unicast_hop_limit(value)?;
223133
Ok(())
224134
}
225135

226136
fn receive_buffer_size(&mut self, this: Resource<udp::UdpSocket>) -> SocketResult<u64> {
227137
let socket = self.table.get(&this)?;
228-
229-
let value = receive_buffer_size(socket.udp_socket())?;
230-
Ok(value)
138+
Ok(socket.receive_buffer_size()?)
231139
}
232140

233141
fn set_receive_buffer_size(
@@ -236,33 +144,22 @@ impl udp::HostUdpSocket for WasiSocketsCtxView<'_> {
236144
value: u64,
237145
) -> SocketResult<()> {
238146
let socket = self.table.get(&this)?;
239-
240-
set_receive_buffer_size(socket.udp_socket(), value)?;
147+
socket.set_receive_buffer_size(value)?;
241148
Ok(())
242149
}
243150

244151
fn send_buffer_size(&mut self, this: Resource<udp::UdpSocket>) -> SocketResult<u64> {
245152
let socket = self.table.get(&this)?;
246-
247-
let value = send_buffer_size(socket.udp_socket())?;
248-
Ok(value)
153+
Ok(socket.send_buffer_size()?)
249154
}
250155

251-
fn set_send_buffer_size(
252-
&mut self,
253-
this: Resource<udp::UdpSocket>,
254-
value: u64,
255-
) -> SocketResult<()> {
156+
fn set_send_buffer_size(&mut self, this: Resource<UdpSocket>, value: u64) -> SocketResult<()> {
256157
let socket = self.table.get(&this)?;
257-
258-
set_send_buffer_size(socket.udp_socket(), value)?;
158+
socket.set_send_buffer_size(value)?;
259159
Ok(())
260160
}
261161

262-
fn subscribe(
263-
&mut self,
264-
this: Resource<udp::UdpSocket>,
265-
) -> anyhow::Result<Resource<DynPollable>> {
162+
fn subscribe(&mut self, this: Resource<UdpSocket>) -> anyhow::Result<Resource<DynPollable>> {
266163
wasmtime_wasi_io::poll::subscribe(self.table, this)
267164
}
268165

@@ -276,6 +173,13 @@ impl udp::HostUdpSocket for WasiSocketsCtxView<'_> {
276173
}
277174
}
278175

176+
#[async_trait]
177+
impl Pollable for UdpSocket {
178+
async fn ready(&mut self) {
179+
// None of the socket-level operations block natively
180+
}
181+
}
182+
279183
impl udp::HostIncomingDatagramStream for WasiSocketsCtxView<'_> {
280184
fn receive(
281185
&mut self,
@@ -504,6 +408,15 @@ impl Pollable for OutgoingDatagramStream {
504408
}
505409
}
506410

411+
impl From<SocketAddressFamily> for IpAddressFamily {
412+
fn from(family: SocketAddressFamily) -> IpAddressFamily {
413+
match family {
414+
SocketAddressFamily::Ipv4 => IpAddressFamily::Ipv4,
415+
SocketAddressFamily::Ipv6 => IpAddressFamily::Ipv6,
416+
}
417+
}
418+
}
419+
507420
pub mod sync {
508421
use wasmtime::component::Resource;
509422

crates/wasi/src/p2/host/udp_create_socket.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use crate::p2::SocketResult;
22
use crate::p2::bindings::{sockets::network::IpAddressFamily, sockets::udp_create_socket};
3-
use crate::p2::udp::UdpSocket;
3+
use crate::sockets::UdpSocket;
44
use crate::sockets::WasiSocketsCtxView;
55
use wasmtime::component::Resource;
66

@@ -9,7 +9,7 @@ impl udp_create_socket::Host for WasiSocketsCtxView<'_> {
99
&mut self,
1010
address_family: IpAddressFamily,
1111
) -> SocketResult<Resource<UdpSocket>> {
12-
let socket = UdpSocket::new(address_family.into())?;
12+
let socket = UdpSocket::new(self.ctx, address_family.into())?;
1313
let socket = self.table.push(socket)?;
1414
Ok(socket)
1515
}

crates/wasi/src/p2/network.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ impl From<crate::sockets::util::ErrorCode> for ErrorCode {
4848
crate::sockets::util::ErrorCode::ConnectionReset => Self::ConnectionReset,
4949
crate::sockets::util::ErrorCode::ConnectionAborted => Self::ConnectionAborted,
5050
crate::sockets::util::ErrorCode::DatagramTooLarge => Self::DatagramTooLarge,
51+
crate::sockets::util::ErrorCode::NotInProgress => Self::NotInProgress,
5152
}
5253
}
5354
}

0 commit comments

Comments
 (0)