Skip to content

Commit 6dd86f1

Browse files
committed
feat: handle upstream binary messages to allow being run as a middleman
1 parent a434815 commit 6dd86f1

File tree

2 files changed

+39
-28
lines changed

2 files changed

+39
-28
lines changed

crates/websocket-proxy/src/main.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -242,16 +242,16 @@ async fn main() {
242242
let (send, _rec) = broadcast::channel(args.message_buffer_size);
243243
let sender = send.clone();
244244

245-
let listener = move |data: String| {
246-
trace!(message = "received data", data = data);
245+
let listener = move |data: Vec<u8>| {
246+
trace!(message = "received data", data = ?data);
247247
// Subtract one from receiver count, as we have to keep one receiver open at all times (see _rec)
248248
// to avoid the channel being closed. However this is not an active client connection.
249249
metrics_clone
250250
.active_connections
251251
.set((send.receiver_count() - 1) as f64);
252252

253253
let message_data = if args.enable_compression {
254-
let data_bytes = data.as_bytes();
254+
let data_bytes = data.as_slice();
255255
let mut compressed_data_bytes = Vec::new();
256256
{
257257
let mut compressor =
@@ -260,7 +260,7 @@ async fn main() {
260260
}
261261
compressed_data_bytes
262262
} else {
263-
data.into_bytes()
263+
data
264264
};
265265

266266
match send.send(message_data.into()) {

crates/websocket-proxy/src/subscriber.rs

Lines changed: 35 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ impl Default for SubscriberOptions {
6363

6464
pub struct WebsocketSubscriber<F>
6565
where
66-
F: Fn(String) + Send + Sync + 'static,
66+
F: Fn(Vec<u8>) + Send + Sync + 'static,
6767
{
6868
uri: Uri,
6969
handler: F,
@@ -74,7 +74,7 @@ where
7474

7575
impl<F> WebsocketSubscriber<F>
7676
where
77-
F: Fn(String) + Send + Sync + 'static,
77+
F: Fn(Vec<u8>) + Send + Sync + 'static,
7878
{
7979
pub fn new(uri: Uri, handler: F, metrics: Arc<Metrics>, options: SubscriberOptions) -> Self {
8080
let backoff = ExponentialBackoff {
@@ -255,14 +255,17 @@ where
255255
);
256256
self.metrics
257257
.message_received_from_upstream(self.uri.to_string().as_str());
258-
(self.handler)(text.to_string());
258+
(self.handler)(text.as_bytes().to_vec());
259259
}
260260
Message::Binary(data) => {
261-
warn!(
262-
message = "received binary message, unsupported",
261+
trace!(
262+
message = "received binary message",
263263
uri = self.uri.to_string(),
264-
size = data.len()
264+
payload = ?data.as_ref()
265265
);
266+
self.metrics
267+
.message_received_from_upstream(self.uri.to_string().as_str());
268+
(self.handler)(data.as_ref().to_vec());
266269
}
267270
Message::Pong(_) => {
268271
trace!(
@@ -300,15 +303,15 @@ mod tests {
300303

301304
struct MockServer {
302305
addr: SocketAddr,
303-
message_sender: broadcast::Sender<String>,
306+
message_sender: broadcast::Sender<Vec<u8>>,
304307
shutdown: CancellationToken,
305308
}
306309

307310
impl MockServer {
308311
async fn new() -> Self {
309312
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
310313
let addr = listener.local_addr().unwrap();
311-
let (tx, _) = broadcast::channel::<String>(100);
314+
let (tx, _) = broadcast::channel::<Vec<u8>>(100);
312315
let shutdown = CancellationToken::new();
313316
let shutdown_clone = shutdown.clone();
314317
let tx_clone = tx.clone();
@@ -347,7 +350,7 @@ mod tests {
347350

348351
async fn handle_connection(
349352
stream: TcpStream,
350-
tx: broadcast::Sender<String>,
353+
tx: broadcast::Sender<Vec<u8>>,
351354
shutdown: CancellationToken,
352355
) {
353356
let ws_stream = match accept_async(stream).await {
@@ -369,8 +372,8 @@ mod tests {
369372
}
370373
msg = rx.recv() => {
371374
match msg {
372-
Ok(text) => {
373-
if let Err(e) = ws_sender.send(Message::Text(text.into())).await {
375+
Ok(data) => {
376+
if let Err(e) = ws_sender.send(data.into()).await {
374377
eprintln!("Error sending message: {}", e);
375378
break;
376379
}
@@ -386,9 +389,9 @@ mod tests {
386389

387390
async fn send_message(
388391
&self,
389-
msg: &str,
390-
) -> Result<usize, broadcast::error::SendError<String>> {
391-
self.message_sender.send(msg.to_string())
392+
msg: &[u8],
393+
) -> Result<usize, broadcast::error::SendError<Vec<u8>>> {
394+
self.message_sender.send(msg.to_vec())
392395
}
393396

394397
async fn shutdown(self) {
@@ -440,7 +443,7 @@ mod tests {
440443
}
441444
});
442445

443-
let listener_fn = move |_data: String| {
446+
let listener_fn = move |_data: Vec<u8>| {
444447
// Handler for received messages - not needed for this test
445448
};
446449

@@ -482,7 +485,7 @@ mod tests {
482485
let received_messages = Arc::new(Mutex::new(Vec::new()));
483486
let received_clone = received_messages.clone();
484487

485-
let listener = move |data: String| {
488+
let listener = move |data: Vec<u8>| {
486489
if let Ok(mut messages) = received_clone.lock() {
487490
messages.push(data);
488491
}
@@ -526,13 +529,21 @@ mod tests {
526529

527530
sleep(Duration::from_millis(500)).await;
528531

529-
let _ = server1.send_message("Message from server 1").await;
530-
let _ = server2.send_message("Message from server 2").await;
532+
let _ = server1
533+
.send_message("Message from server 1".as_bytes())
534+
.await;
535+
let _ = server2
536+
.send_message("Message from server 2".as_bytes())
537+
.await;
531538

532539
sleep(Duration::from_millis(500)).await;
533540

534-
let _ = server1.send_message("Another message from server 1").await;
535-
let _ = server2.send_message("Another message from server 2").await;
541+
let _ = server1
542+
.send_message("Another message from server 1".as_bytes())
543+
.await;
544+
let _ = server2
545+
.send_message("Another message from server 2".as_bytes())
546+
.await;
536547

537548
// Wait for messages to be processed
538549
sleep(Duration::from_millis(500)).await;
@@ -552,10 +563,10 @@ mod tests {
552563

553564
assert_eq!(messages.len(), 4);
554565

555-
assert!(messages.contains(&"Message from server 1".to_string()));
556-
assert!(messages.contains(&"Message from server 2".to_string()));
557-
assert!(messages.contains(&"Another message from server 1".to_string()));
558-
assert!(messages.contains(&"Another message from server 2".to_string()));
566+
assert!(messages.contains(&"Message from server 1".as_bytes().to_vec()));
567+
assert!(messages.contains(&"Message from server 2".as_bytes().to_vec()));
568+
assert!(messages.contains(&"Another message from server 1".as_bytes().to_vec()));
569+
assert!(messages.contains(&"Another message from server 2".as_bytes().to_vec()));
559570

560571
assert!(!messages.is_empty());
561572
}

0 commit comments

Comments
 (0)