Skip to content
Merged
1 change: 1 addition & 0 deletions data/src/data_channel/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ pub struct Config {
pub reliability_parameter: u32,
pub label: String,
pub protocol: String,
pub max_message_size: u32,
}

/// DataChannel represents a data channel
Expand Down
6 changes: 6 additions & 0 deletions sctp/src/association/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,12 @@ pub struct Config {
pub struct Association {
name: String,
state: Arc<AtomicU8>,
// TODO: Convert into `u32`, as there is no reason why the `max_message_size` should need to be
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am trying to keep refactors to the minimum in order to keep this PR as small as possible. If you do agree with this comment, I can tackle it in another PR, otherwise, I will remove said comment.

// changed after the Assocaition has been created. Note that even if there was a use case for
// this, it is not used anywhere in the code base.
//
// Using atomics where not necessary -- especially in a hot path such as `prepare_write` -- may
// negatively impact performance, and adds unneeded complexity to the code.
max_message_size: Arc<AtomicU32>,
inflight_queue_length: Arc<AtomicUsize>,
will_send_shutdown: Arc<AtomicBool>,
Expand Down
1 change: 1 addition & 0 deletions sdp/src/description/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ pub const ATTR_KEY_SEND_ONLY: &str = "sendonly";
pub const ATTR_KEY_SEND_RECV: &str = "sendrecv";
pub const ATTR_KEY_EXT_MAP: &str = "extmap";
pub const ATTR_KEY_EXTMAP_ALLOW_MIXED: &str = "extmap-allow-mixed";
pub const ATTR_KEY_MAX_MESSAGE_SIZE: &str = "max-message-size";

/// Constants for semantic tokens used in JSEP
pub const SEMANTIC_TOKEN_LIP_SYNCHRONIZATION: &str = "LS";
Expand Down
33 changes: 33 additions & 0 deletions webrtc/src/api/setting_engine/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,30 @@ pub struct ReplayProtection {
pub srtcp: usize,
}

#[derive(Clone)]
pub enum SctpMaxMessageSize {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am personally not a huge fan of how the crate makes use of 0 as a sentinel value for some configuration being unset, so I decided to make use of an enum to give the user a clear distiction between having a bounded SctpMaxMessageSize, or an unbounded one (which is typically represented as a zero in this code base).

Bounded(u32),
Unbounded,
}

impl SctpMaxMessageSize {
pub const DEFAULT_MESSAGE_SIZE: u32 = 65536;
pub fn as_u32(&self) -> u32 {
match self {
Self::Bounded(result) => *result,
Self::Unbounded => 0,
}
}
}

impl Default for SctpMaxMessageSize {
fn default() -> Self {
// https://datatracker.ietf.org/doc/html/rfc8841#section-6.1-4
// > If the SDP "max-message-size" attribute is not present, the default value is 64K.
Self::Bounded(Self::DEFAULT_MESSAGE_SIZE)
}
}

/// SettingEngine allows influencing behavior in ways that are not
/// supported by the WebRTC API. This allows us to support additional
/// use-cases without deviating from the WebRTC API elsewhere.
Expand All @@ -79,6 +103,8 @@ pub struct SettingEngine {
pub(crate) receive_mtu: usize,
pub(crate) mid_generator: Option<Arc<dyn Fn(isize) -> String + Send + Sync>>,
pub(crate) enable_sender_rtx: bool,
/// Determines the max size of any message that may be sent through an SCTP transport.
pub(crate) sctp_max_message_size_can_send: SctpMaxMessageSize,
}

impl SettingEngine {
Expand Down Expand Up @@ -342,4 +368,11 @@ impl SettingEngine {
pub fn enable_sender_rtx(&mut self, is_enabled: bool) {
self.enable_sender_rtx = is_enabled;
}

pub fn set_sctp_max_message_size_can_send(
&mut self,
max_message_size_can_send: SctpMaxMessageSize,
) {
self.sctp_max_message_size_can_send = max_message_size_can_send
}
}
198 changes: 198 additions & 0 deletions webrtc/src/data_channel/data_channel_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use waitgroup::WaitGroup;

use super::*;
use crate::api::media_engine::MediaEngine;
use crate::api::setting_engine::SctpMaxMessageSize;
use crate::api::{APIBuilder, API};
use crate::data_channel::data_channel_init::RTCDataChannelInit;
//use log::LevelFilter;
Expand Down Expand Up @@ -1375,6 +1376,203 @@ async fn test_data_channel_non_standard_session_description() -> Result<()> {
Ok(())
}

async fn create_data_channel_with_max_message_size(
remote_max_message_size: Option<u32>,
can_send_max_message_size: Option<SctpMaxMessageSize>,
) -> Result<Arc<RTCDataChannel>> {
let mut m = MediaEngine::default();
let mut s: SettingEngine = SettingEngine::default();
s.detach_data_channels();
m.register_default_codecs()?;
let api_builder = APIBuilder::new().with_media_engine(m);

if let Some(can_send_max_message_size) = can_send_max_message_size {
s.set_sctp_max_message_size_can_send(can_send_max_message_size);
}

let api = api_builder.with_setting_engine(s).build();

let (offer_pc, answer_pc) = new_pair(&api).await?;
let (data_channel_tx, mut data_channel_rx) = mpsc::channel::<Arc<RTCDataChannel>>(1);
let data_channel_tx = Arc::new(data_channel_tx);
answer_pc.on_data_channel(Box::new(move |dc: Arc<RTCDataChannel>| {
let data_channel_tx2 = Arc::clone(&data_channel_tx);
Box::pin(async move {
data_channel_tx2.send(dc).await.unwrap();
})
}));

let _ = offer_pc.create_data_channel("foo", None).await?;

let offer = offer_pc.create_offer(None).await?;
let mut offer_gathering_complete = offer_pc.gathering_complete_promise().await;
offer_pc.set_local_description(offer).await?;
let _ = offer_gathering_complete.recv().await;
let mut offer = offer_pc.local_description().await.unwrap();

if let Some(remote_max_message_size) = remote_max_message_size {
offer
.sdp
.push_str(format!("a=max-message-size:{}\r\n", remote_max_message_size).as_str());
}

answer_pc.set_remote_description(offer).await?;

let answer = answer_pc.create_answer(None).await?;

let mut answer_gathering_complete = answer_pc.gathering_complete_promise().await;
answer_pc.set_local_description(answer).await?;
let _ = answer_gathering_complete.recv().await;

let answer = answer_pc.local_description().await.unwrap();
offer_pc.set_remote_description(answer).await?;

Ok(data_channel_rx.recv().await.unwrap())
}

// 128 KB
const EXPECTED_MAX_MESSAGE_SIZE: u32 = 131072;

#[tokio::test]
async fn test_data_channel_max_message_size_respected_on_send() -> Result<()> {
let data_channel = create_data_channel_with_max_message_size(
Some(EXPECTED_MAX_MESSAGE_SIZE),
Some(SctpMaxMessageSize::Unbounded),
)
.await?;

// A buffer with a size greater than the default size of 64KB.
let buffer = vec![0; 68000];
let bytes = bytes::Bytes::copy_from_slice(buffer.as_slice());
data_channel.send(&bytes).await.unwrap();

Ok(())
}

#[tokio::test]
async fn test_given_remote_max_message_size_is_none_when_data_channel_can_send_max_message_size_respected_on_send(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like using the given, when, then naming pattern for certain tests. If you're not a fan of this, feel free to suggest alternative test names.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The following tests whether the max-message-size attribute, or the sctp_max_message_size_can_send attribute in the SettingsEngine is actually respected (in an end to end manner).

) -> Result<()> {
const EXPECTED_CAN_SEND_MAX_MESSAGE_SIZE: u32 = 1024;
let data_channel = create_data_channel_with_max_message_size(
None,
Some(SctpMaxMessageSize::Bounded(
EXPECTED_CAN_SEND_MAX_MESSAGE_SIZE,
)),
)
.await?;

let buffer = vec![0; 65536];
let bytes = bytes::Bytes::copy_from_slice(buffer.as_slice());

let actual = data_channel.send(&bytes).await;

assert!(matches!(
actual,
Err(Error::Data(data::Error::Sctp(
sctp::Error::ErrOutboundPacketTooLarge
)))
));

Ok(())
}

async fn run_data_channel_config_max_message_size(
remote_max_message_size: Option<u32>,
can_send_max_message_size: Option<SctpMaxMessageSize>,
) -> Result<u32> {
let data_channel = create_data_channel_with_max_message_size(
remote_max_message_size,
can_send_max_message_size,
)
.await?;
let data_channel = data_channel.detach().await?;
Ok(data_channel.config.max_message_size)
}

#[tokio::test]
async fn test_data_channel_max_message_size_reflected_on_data_channel_config() -> Result<()> {
assert_eq!(
run_data_channel_config_max_message_size(
Some(EXPECTED_MAX_MESSAGE_SIZE),
Some(SctpMaxMessageSize::Unbounded)
)
.await?,
EXPECTED_MAX_MESSAGE_SIZE
);

Ok(())
}

#[tokio::test]
async fn test_can_send_max_message_size_unspecified_then_remote_default_value_is_respected(
) -> Result<()> {
assert_eq!(
run_data_channel_config_max_message_size(Some(EXPECTED_MAX_MESSAGE_SIZE), None).await?,
SctpMaxMessageSize::DEFAULT_MESSAGE_SIZE
);

Ok(())
}

#[tokio::test]
async fn test_given_can_send_channel_max_message_size_less_than_remote_max_message_size_respect_send_channel_max_message_size(
) -> Result<()> {
let remote_max_message_size = 1024;
let can_send_channel_max_message_size = 256;
assert_eq!(
run_data_channel_config_max_message_size(
Some(remote_max_message_size),
Some(SctpMaxMessageSize::Bounded(
can_send_channel_max_message_size
))
)
.await?,
can_send_channel_max_message_size
);

Ok(())
}

#[tokio::test]
async fn test_can_send_max_message_size_respected_on_data_channel_config() -> Result<()> {
let can_send_channel_max_message_size = 1024;
assert_eq!(
run_data_channel_config_max_message_size(
None,
Some(SctpMaxMessageSize::Bounded(
can_send_channel_max_message_size
))
)
.await?,
can_send_channel_max_message_size
);

Ok(())
}

#[tokio::test]
async fn test_given_no_remote_message_size_or_can_send_max_message_size_max_size_is_65536(
) -> Result<()> {
assert_eq!(
run_data_channel_config_max_message_size(None, None).await?,
SctpMaxMessageSize::DEFAULT_MESSAGE_SIZE
);

Ok(())
}

#[tokio::test]
async fn test_respect_default_remote_max_message_size_when_can_send_max_message_size_is_greater_than_default(
) -> Result<()> {
assert_eq!(
run_data_channel_config_max_message_size(None, Some(SctpMaxMessageSize::Bounded(70000)))
.await?,
SctpMaxMessageSize::DEFAULT_MESSAGE_SIZE
);

Ok(())
}

struct TestOrtcStack {
//api *API
gatherer: Arc<RTCIceGatherer>,
Expand Down
1 change: 1 addition & 0 deletions webrtc/src/data_channel/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ impl RTCDataChannel {
label: self.label.clone(),
protocol: self.protocol.clone(),
negotiated: self.negotiated,
max_message_size: association.max_message_size(),
};

if !self.negotiated {
Expand Down
27 changes: 18 additions & 9 deletions webrtc/src/peer_connection/peer_connection_internal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::collections::VecDeque;
use std::sync::Weak;

use super::*;
use crate::api::setting_engine::SctpMaxMessageSize;
use crate::rtp_transceiver::{create_stream_info, PayloadType};
use crate::stats::stats_collector::StatsCollector;
use crate::stats::{
Expand Down Expand Up @@ -290,7 +291,16 @@ impl PeerConnectionInternal {
if let Some(remote_port) = get_application_media_section_sctp_port(parsed_remote) {
if let Some(local_port) = get_application_media_section_sctp_port(parsed_local)
{
self.start_sctp(local_port, remote_port).await;
// TODO: Reuse the MediaDescription retrieved when looking for the message size.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A small nitpick of the code from my end, if you agree with the proposal, which is retrieving the MediaDescription once, and reusing it, I can implement this in another PR.

let max_message_size =
get_application_media_section_max_message_size(parsed_remote)
.unwrap_or(SctpMaxMessageSize::DEFAULT_MESSAGE_SIZE);
self.start_sctp(
local_port,
remote_port,
SCTPTransportCapabilities { max_message_size },
)
.await;
}
}
}
Expand Down Expand Up @@ -460,17 +470,16 @@ impl PeerConnectionInternal {
}

/// Start SCTP subsystem
async fn start_sctp(&self, local_port: u16, remote_port: u16) {
async fn start_sctp(
&self,
local_port: u16,
remote_port: u16,
sctp_transport_capabilities: SCTPTransportCapabilities,
) {
// Start sctp
if let Err(err) = self
.sctp_transport
.start(
SCTPTransportCapabilities {
max_message_size: 0,
},
local_port,
remote_port,
)
.start(sctp_transport_capabilities, local_port, remote_port)
.await
{
log::warn!("Failed to start SCTP: {err}");
Expand Down
24 changes: 16 additions & 8 deletions webrtc/src/peer_connection/sdp/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1049,6 +1049,15 @@ pub(crate) fn get_application_media_section_sctp_port(desc: &SessionDescription)
None
}

pub(crate) fn get_application_media_section_max_message_size(
desc: &SessionDescription,
) -> Option<u32> {
get_application_media(desc)?
.attribute(ATTR_KEY_MAX_MESSAGE_SIZE)??
.parse()
.ok()
}

pub(crate) fn get_by_mid<'a>(
search_mid: &str,
desc: &'a session_description::RTCSessionDescription,
Expand All @@ -1065,18 +1074,17 @@ pub(crate) fn get_by_mid<'a>(
None
}

pub(crate) fn get_application_media(desc: &SessionDescription) -> Option<&MediaDescription> {
desc.media_descriptions
.iter()
.find(|media_description| media_description.media_name.media == MEDIA_SECTION_APPLICATION)
}

/// have_data_channel return MediaDescription with MediaName equal application
pub(crate) fn have_data_channel(
desc: &session_description::RTCSessionDescription,
) -> Option<&MediaDescription> {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've rewritten this function in particular to use get_application_media, but I did not refactor the other functions in this file which could also make use of get_application_media in order to keep this PR to a minimum.

If you agree with the refactor, I can do it in another PR.

if let Some(parsed) = &desc.parsed {
for d in &parsed.media_descriptions {
if d.media_name.media == MEDIA_SECTION_APPLICATION {
return Some(d);
}
}
}
None
get_application_media(desc.parsed.as_ref()?)
}

pub(crate) fn codecs_from_media_description(
Expand Down
Loading
Loading