diff --git a/data/src/data_channel/mod.rs b/data/src/data_channel/mod.rs index 1b1046cca..e2c28d02c 100644 --- a/data/src/data_channel/mod.rs +++ b/data/src/data_channel/mod.rs @@ -34,6 +34,7 @@ pub struct Config { pub reliability_parameter: u32, pub label: String, pub protocol: String, + pub max_message_size: u32, } /// DataChannel represents a data channel diff --git a/sctp/src/association/mod.rs b/sctp/src/association/mod.rs index 0d21bc105..1135b6d9d 100644 --- a/sctp/src/association/mod.rs +++ b/sctp/src/association/mod.rs @@ -202,6 +202,12 @@ pub struct Config { pub struct Association { name: String, state: Arc, + // TODO: Convert into `u32`, as there is no reason why the `max_message_size` should need to be + // changed after the Assocaition has been created. Note that even if there was a use case for + // this, it is not used anywhere in the code base. + // + // Using atomics where not necessary -- especially in a hot path such as `prepare_write` -- may + // negatively impact performance, and adds unneeded complexity to the code. max_message_size: Arc, inflight_queue_length: Arc, will_send_shutdown: Arc, diff --git a/sdp/src/description/session.rs b/sdp/src/description/session.rs index ea6439c3b..a17809e89 100644 --- a/sdp/src/description/session.rs +++ b/sdp/src/description/session.rs @@ -31,6 +31,7 @@ pub const ATTR_KEY_SEND_ONLY: &str = "sendonly"; pub const ATTR_KEY_SEND_RECV: &str = "sendrecv"; pub const ATTR_KEY_EXT_MAP: &str = "extmap"; pub const ATTR_KEY_EXTMAP_ALLOW_MIXED: &str = "extmap-allow-mixed"; +pub const ATTR_KEY_MAX_MESSAGE_SIZE: &str = "max-message-size"; /// Constants for semantic tokens used in JSEP pub const SEMANTIC_TOKEN_LIP_SYNCHRONIZATION: &str = "LS"; diff --git a/webrtc/src/api/setting_engine/mod.rs b/webrtc/src/api/setting_engine/mod.rs index 17170a706..0c0394a8b 100644 --- a/webrtc/src/api/setting_engine/mod.rs +++ b/webrtc/src/api/setting_engine/mod.rs @@ -54,6 +54,30 @@ pub struct ReplayProtection { pub srtcp: usize, } +#[derive(Clone)] +pub enum SctpMaxMessageSize { + Bounded(u32), + Unbounded, +} + +impl SctpMaxMessageSize { + pub const DEFAULT_MESSAGE_SIZE: u32 = 65536; + pub fn as_u32(&self) -> u32 { + match self { + Self::Bounded(result) => *result, + Self::Unbounded => 0, + } + } +} + +impl Default for SctpMaxMessageSize { + fn default() -> Self { + // https://datatracker.ietf.org/doc/html/rfc8841#section-6.1-4 + // > If the SDP "max-message-size" attribute is not present, the default value is 64K. + Self::Bounded(Self::DEFAULT_MESSAGE_SIZE) + } +} + /// SettingEngine allows influencing behavior in ways that are not /// supported by the WebRTC API. This allows us to support additional /// use-cases without deviating from the WebRTC API elsewhere. @@ -79,6 +103,8 @@ pub struct SettingEngine { pub(crate) receive_mtu: usize, pub(crate) mid_generator: Option String + Send + Sync>>, pub(crate) enable_sender_rtx: bool, + /// Determines the max size of any message that may be sent through an SCTP transport. + pub(crate) sctp_max_message_size_can_send: SctpMaxMessageSize, } impl SettingEngine { @@ -342,4 +368,11 @@ impl SettingEngine { pub fn enable_sender_rtx(&mut self, is_enabled: bool) { self.enable_sender_rtx = is_enabled; } + + pub fn set_sctp_max_message_size_can_send( + &mut self, + max_message_size_can_send: SctpMaxMessageSize, + ) { + self.sctp_max_message_size_can_send = max_message_size_can_send + } } diff --git a/webrtc/src/data_channel/data_channel_test.rs b/webrtc/src/data_channel/data_channel_test.rs index 276a7f182..3998768ed 100644 --- a/webrtc/src/data_channel/data_channel_test.rs +++ b/webrtc/src/data_channel/data_channel_test.rs @@ -8,6 +8,7 @@ use waitgroup::WaitGroup; use super::*; use crate::api::media_engine::MediaEngine; +use crate::api::setting_engine::SctpMaxMessageSize; use crate::api::{APIBuilder, API}; use crate::data_channel::data_channel_init::RTCDataChannelInit; //use log::LevelFilter; @@ -1375,6 +1376,203 @@ async fn test_data_channel_non_standard_session_description() -> Result<()> { Ok(()) } +async fn create_data_channel_with_max_message_size( + remote_max_message_size: Option, + can_send_max_message_size: Option, +) -> Result> { + let mut m = MediaEngine::default(); + let mut s: SettingEngine = SettingEngine::default(); + s.detach_data_channels(); + m.register_default_codecs()?; + let api_builder = APIBuilder::new().with_media_engine(m); + + if let Some(can_send_max_message_size) = can_send_max_message_size { + s.set_sctp_max_message_size_can_send(can_send_max_message_size); + } + + let api = api_builder.with_setting_engine(s).build(); + + let (offer_pc, answer_pc) = new_pair(&api).await?; + let (data_channel_tx, mut data_channel_rx) = mpsc::channel::>(1); + let data_channel_tx = Arc::new(data_channel_tx); + answer_pc.on_data_channel(Box::new(move |dc: Arc| { + let data_channel_tx2 = Arc::clone(&data_channel_tx); + Box::pin(async move { + data_channel_tx2.send(dc).await.unwrap(); + }) + })); + + let _ = offer_pc.create_data_channel("foo", None).await?; + + let offer = offer_pc.create_offer(None).await?; + let mut offer_gathering_complete = offer_pc.gathering_complete_promise().await; + offer_pc.set_local_description(offer).await?; + let _ = offer_gathering_complete.recv().await; + let mut offer = offer_pc.local_description().await.unwrap(); + + if let Some(remote_max_message_size) = remote_max_message_size { + offer + .sdp + .push_str(format!("a=max-message-size:{}\r\n", remote_max_message_size).as_str()); + } + + answer_pc.set_remote_description(offer).await?; + + let answer = answer_pc.create_answer(None).await?; + + let mut answer_gathering_complete = answer_pc.gathering_complete_promise().await; + answer_pc.set_local_description(answer).await?; + let _ = answer_gathering_complete.recv().await; + + let answer = answer_pc.local_description().await.unwrap(); + offer_pc.set_remote_description(answer).await?; + + Ok(data_channel_rx.recv().await.unwrap()) +} + +// 128 KB +const EXPECTED_MAX_MESSAGE_SIZE: u32 = 131072; + +#[tokio::test] +async fn test_data_channel_max_message_size_respected_on_send() -> Result<()> { + let data_channel = create_data_channel_with_max_message_size( + Some(EXPECTED_MAX_MESSAGE_SIZE), + Some(SctpMaxMessageSize::Unbounded), + ) + .await?; + + // A buffer with a size greater than the default size of 64KB. + let buffer = vec![0; 68000]; + let bytes = bytes::Bytes::copy_from_slice(buffer.as_slice()); + data_channel.send(&bytes).await.unwrap(); + + Ok(()) +} + +#[tokio::test] +async fn test_given_remote_max_message_size_is_none_when_data_channel_can_send_max_message_size_respected_on_send( +) -> Result<()> { + const EXPECTED_CAN_SEND_MAX_MESSAGE_SIZE: u32 = 1024; + let data_channel = create_data_channel_with_max_message_size( + None, + Some(SctpMaxMessageSize::Bounded( + EXPECTED_CAN_SEND_MAX_MESSAGE_SIZE, + )), + ) + .await?; + + let buffer = vec![0; 65536]; + let bytes = bytes::Bytes::copy_from_slice(buffer.as_slice()); + + let actual = data_channel.send(&bytes).await; + + assert!(matches!( + actual, + Err(Error::Data(data::Error::Sctp( + sctp::Error::ErrOutboundPacketTooLarge + ))) + )); + + Ok(()) +} + +async fn run_data_channel_config_max_message_size( + remote_max_message_size: Option, + can_send_max_message_size: Option, +) -> Result { + let data_channel = create_data_channel_with_max_message_size( + remote_max_message_size, + can_send_max_message_size, + ) + .await?; + let data_channel = data_channel.detach().await?; + Ok(data_channel.config.max_message_size) +} + +#[tokio::test] +async fn test_data_channel_max_message_size_reflected_on_data_channel_config() -> Result<()> { + assert_eq!( + run_data_channel_config_max_message_size( + Some(EXPECTED_MAX_MESSAGE_SIZE), + Some(SctpMaxMessageSize::Unbounded) + ) + .await?, + EXPECTED_MAX_MESSAGE_SIZE + ); + + Ok(()) +} + +#[tokio::test] +async fn test_can_send_max_message_size_unspecified_then_remote_default_value_is_respected( +) -> Result<()> { + assert_eq!( + run_data_channel_config_max_message_size(Some(EXPECTED_MAX_MESSAGE_SIZE), None).await?, + SctpMaxMessageSize::DEFAULT_MESSAGE_SIZE + ); + + Ok(()) +} + +#[tokio::test] +async fn test_given_can_send_channel_max_message_size_less_than_remote_max_message_size_respect_send_channel_max_message_size( +) -> Result<()> { + let remote_max_message_size = 1024; + let can_send_channel_max_message_size = 256; + assert_eq!( + run_data_channel_config_max_message_size( + Some(remote_max_message_size), + Some(SctpMaxMessageSize::Bounded( + can_send_channel_max_message_size + )) + ) + .await?, + can_send_channel_max_message_size + ); + + Ok(()) +} + +#[tokio::test] +async fn test_can_send_max_message_size_respected_on_data_channel_config() -> Result<()> { + let can_send_channel_max_message_size = 1024; + assert_eq!( + run_data_channel_config_max_message_size( + None, + Some(SctpMaxMessageSize::Bounded( + can_send_channel_max_message_size + )) + ) + .await?, + can_send_channel_max_message_size + ); + + Ok(()) +} + +#[tokio::test] +async fn test_given_no_remote_message_size_or_can_send_max_message_size_max_size_is_65536( +) -> Result<()> { + assert_eq!( + run_data_channel_config_max_message_size(None, None).await?, + SctpMaxMessageSize::DEFAULT_MESSAGE_SIZE + ); + + Ok(()) +} + +#[tokio::test] +async fn test_respect_default_remote_max_message_size_when_can_send_max_message_size_is_greater_than_default( +) -> Result<()> { + assert_eq!( + run_data_channel_config_max_message_size(None, Some(SctpMaxMessageSize::Bounded(70000))) + .await?, + SctpMaxMessageSize::DEFAULT_MESSAGE_SIZE + ); + + Ok(()) +} + struct TestOrtcStack { //api *API gatherer: Arc, diff --git a/webrtc/src/data_channel/mod.rs b/webrtc/src/data_channel/mod.rs index 336b0fc31..b23886192 100644 --- a/webrtc/src/data_channel/mod.rs +++ b/webrtc/src/data_channel/mod.rs @@ -173,6 +173,7 @@ impl RTCDataChannel { label: self.label.clone(), protocol: self.protocol.clone(), negotiated: self.negotiated, + max_message_size: association.max_message_size(), }; if !self.negotiated { diff --git a/webrtc/src/peer_connection/peer_connection_internal.rs b/webrtc/src/peer_connection/peer_connection_internal.rs index 6a4f8ae3e..722a39e4b 100644 --- a/webrtc/src/peer_connection/peer_connection_internal.rs +++ b/webrtc/src/peer_connection/peer_connection_internal.rs @@ -2,6 +2,7 @@ use std::collections::VecDeque; use std::sync::Weak; use super::*; +use crate::api::setting_engine::SctpMaxMessageSize; use crate::rtp_transceiver::{create_stream_info, PayloadType}; use crate::stats::stats_collector::StatsCollector; use crate::stats::{ @@ -290,7 +291,16 @@ impl PeerConnectionInternal { if let Some(remote_port) = get_application_media_section_sctp_port(parsed_remote) { if let Some(local_port) = get_application_media_section_sctp_port(parsed_local) { - self.start_sctp(local_port, remote_port).await; + // TODO: Reuse the MediaDescription retrieved when looking for the message size. + let max_message_size = + get_application_media_section_max_message_size(parsed_remote) + .unwrap_or(SctpMaxMessageSize::DEFAULT_MESSAGE_SIZE); + self.start_sctp( + local_port, + remote_port, + SCTPTransportCapabilities { max_message_size }, + ) + .await; } } } @@ -460,17 +470,16 @@ impl PeerConnectionInternal { } /// Start SCTP subsystem - async fn start_sctp(&self, local_port: u16, remote_port: u16) { + async fn start_sctp( + &self, + local_port: u16, + remote_port: u16, + sctp_transport_capabilities: SCTPTransportCapabilities, + ) { // Start sctp if let Err(err) = self .sctp_transport - .start( - SCTPTransportCapabilities { - max_message_size: 0, - }, - local_port, - remote_port, - ) + .start(sctp_transport_capabilities, local_port, remote_port) .await { log::warn!("Failed to start SCTP: {err}"); diff --git a/webrtc/src/peer_connection/sdp/mod.rs b/webrtc/src/peer_connection/sdp/mod.rs index 3532404e2..7267e9167 100644 --- a/webrtc/src/peer_connection/sdp/mod.rs +++ b/webrtc/src/peer_connection/sdp/mod.rs @@ -1049,6 +1049,15 @@ pub(crate) fn get_application_media_section_sctp_port(desc: &SessionDescription) None } +pub(crate) fn get_application_media_section_max_message_size( + desc: &SessionDescription, +) -> Option { + get_application_media(desc)? + .attribute(ATTR_KEY_MAX_MESSAGE_SIZE)?? + .parse() + .ok() +} + pub(crate) fn get_by_mid<'a>( search_mid: &str, desc: &'a session_description::RTCSessionDescription, @@ -1065,18 +1074,17 @@ pub(crate) fn get_by_mid<'a>( None } +pub(crate) fn get_application_media(desc: &SessionDescription) -> Option<&MediaDescription> { + desc.media_descriptions + .iter() + .find(|media_description| media_description.media_name.media == MEDIA_SECTION_APPLICATION) +} + /// have_data_channel return MediaDescription with MediaName equal application pub(crate) fn have_data_channel( desc: &session_description::RTCSessionDescription, ) -> Option<&MediaDescription> { - if let Some(parsed) = &desc.parsed { - for d in &parsed.media_descriptions { - if d.media_name.media == MEDIA_SECTION_APPLICATION { - return Some(d); - } - } - } - None + get_application_media(desc.parsed.as_ref()?) } pub(crate) fn codecs_from_media_description( diff --git a/webrtc/src/sctp_transport/mod.rs b/webrtc/src/sctp_transport/mod.rs index 51313702f..a498bde54 100644 --- a/webrtc/src/sctp_transport/mod.rs +++ b/webrtc/src/sctp_transport/mod.rs @@ -69,10 +69,6 @@ pub struct RTCSctpTransport { // so we need a dedicated field is_started: AtomicBool, - // max_message_size represents the maximum size of data that can be passed to - // DataChannel's send() method. - max_message_size: usize, - // max_channels represents the maximum amount of DataChannel's that can // be used simultaneously. max_channels: u16, @@ -103,7 +99,6 @@ impl RTCSctpTransport { dtls_transport, state: AtomicU8::new(RTCSctpTransportState::Connecting as u8), is_started: AtomicBool::new(false), - max_message_size: RTCSctpTransport::calc_message_size(65536, 65536), max_channels: SCTP_MAX_CHANNELS, sctp_association: Mutex::new(None), on_error_handler: Arc::new(ArcSwapOption::empty()), @@ -138,7 +133,7 @@ impl RTCSctpTransport { /// a connection over SCTP. pub async fn start( &self, - _remote_caps: SCTPTransportCapabilities, + remote_caps: SCTPTransportCapabilities, local_port: u16, remote_port: u16, ) -> Result<()> { @@ -148,6 +143,12 @@ impl RTCSctpTransport { self.is_started.store(true, Ordering::SeqCst); let dtls_transport = self.transport(); + + let max_message_size = Self::calc_message_size( + remote_caps.max_message_size, + self.setting_engine.sctp_max_message_size_can_send.as_u32(), + ); + if let Some(net_conn) = &dtls_transport.conn().await { let sctp_association = loop { tokio::select! { @@ -162,7 +163,7 @@ impl RTCSctpTransport { association = sctp::association::Association::client(sctp::association::Config { net_conn: Arc::clone(net_conn) as Arc, max_receive_buffer_size: 0, - max_message_size: 0, + max_message_size, name: String::new(), local_port, remote_port, @@ -232,7 +233,10 @@ impl RTCSctpTransport { _ = param.notify_rx.notified() => break, result = DataChannel::accept( ¶m.sctp_association, - data::data_channel::Config::default(), + data::data_channel::Config { + max_message_size: param.sctp_association.max_message_size(), + ..data::data_channel::Config::default() + }, &existing_data_channels, ) => { match result { @@ -338,9 +342,9 @@ impl RTCSctpTransport { .store(Some(Arc::new(Mutex::new(f)))); } - fn calc_message_size(remote_max_message_size: usize, can_send_size: usize) -> usize { + fn calc_message_size(remote_max_message_size: u32, can_send_size: u32) -> u32 { if remote_max_message_size == 0 && can_send_size == 0 { - usize::MAX + u32::MAX } else if remote_max_message_size == 0 { can_send_size } else if can_send_size == 0 || can_send_size > remote_max_message_size {