diff --git a/Cargo.lock b/Cargo.lock index 8d431b62..a25a768b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2140,6 +2140,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "doc-comment" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10" + [[package]] name = "documented" version = "0.9.2" @@ -2194,6 +2200,7 @@ dependencies = [ "nix", "parcelona", "pin-project", + "proxy-protocol", "ra-rpc", "ra-tls", "rand 0.8.5", @@ -5288,6 +5295,16 @@ dependencies = [ "prost 0.13.5", ] +[[package]] +name = "proxy-protocol" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e50c72c21c738f5c5f350cc33640aee30bf7cd20f9d9da20ed41bce2671d532" +dependencies = [ + "bytes", + "snafu", +] + [[package]] name = "prpc" version = "0.6.0" @@ -6845,6 +6862,27 @@ dependencies = [ "serde", ] +[[package]] +name = "snafu" +version = "0.6.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eab12d3c261b2308b0d80c26fffb58d17eba81a4be97890101f416b478c79ca7" +dependencies = [ + "doc-comment", + "snafu-derive", +] + +[[package]] +name = "snafu-derive" +version = "0.6.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1508efa03c362e23817f96cde18abed596a25219a8b2c66e8db33c03543d315b" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "socket2" version = "0.5.10" diff --git a/Cargo.toml b/Cargo.toml index 302295de..0909870a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -218,3 +218,4 @@ serde_yaml2 = "0.1.2" luks2 = "0.5.0" scopeguard = "1.2.0" +proxy-protocol = "0.5.0" diff --git a/gateway/Cargo.toml b/gateway/Cargo.toml index d150126c..5579cf0b 100644 --- a/gateway/Cargo.toml +++ b/gateway/Cargo.toml @@ -50,6 +50,7 @@ reqwest = { workspace = true, features = ["json"] } hyper = { workspace = true, features = ["server", "http1"] } hyper-util = { version = "0.1", features = ["tokio"] } jemallocator.workspace = true +proxy-protocol.workspace = true [target.'cfg(unix)'.dependencies] nix = { workspace = true, features = ["resource"] } diff --git a/gateway/gateway.toml b/gateway/gateway.toml index a89ff348..d2007119 100644 --- a/gateway/gateway.toml +++ b/gateway/gateway.toml @@ -67,6 +67,7 @@ app_address_ns_prefix = "_dstack-app-address" app_address_ns_compat = true workers = 32 external_port = 443 +inbound_pp_enabled = false [core.proxy.timeouts] # Timeout for establishing a connection to the target app. @@ -88,6 +89,8 @@ write = "5s" shutdown = "5s" # Timeout for total connection duration. total = "5h" +# Timeout for proxy protocol header +pp_header = "5s" [core.recycle] enabled = true diff --git a/gateway/src/config.rs b/gateway/src/config.rs index 3a3d88db..9eef5eb0 100644 --- a/gateway/src/config.rs +++ b/gateway/src/config.rs @@ -85,6 +85,7 @@ pub struct ProxyConfig { pub workers: usize, pub app_address_ns_prefix: String, pub app_address_ns_compat: bool, + pub inbound_pp_enabled: bool, } #[derive(Debug, Clone, Deserialize)] @@ -106,6 +107,8 @@ pub struct Timeouts { pub write: Duration, #[serde(with = "serde_duration")] pub shutdown: Duration, + #[serde(with = "serde_duration")] + pub pp_header: Duration, } #[derive(Debug, Clone, Deserialize, Serialize)] diff --git a/gateway/src/main.rs b/gateway/src/main.rs index 61d25632..805fe285 100644 --- a/gateway/src/main.rs +++ b/gateway/src/main.rs @@ -21,6 +21,7 @@ mod admin_service; mod config; mod main_service; mod models; +mod pp; mod proxy; mod web_routes; diff --git a/gateway/src/pp.rs b/gateway/src/pp.rs new file mode 100644 index 00000000..aa61fdab --- /dev/null +++ b/gateway/src/pp.rs @@ -0,0 +1,175 @@ +// SPDX-FileCopyrightText: © 2024-2025 Phala Network +// +// SPDX-License-Identifier: Apache-2.0 + +use std::net::SocketAddr; + +use anyhow::{bail, Context, Result}; +use proxy_protocol::{version1 as v1, version2 as v2, ProxyHeader}; +use tokio::{ + io::{AsyncRead, AsyncReadExt}, + net::TcpStream, +}; + +use crate::config::ProxyConfig; + +const V1_PROTOCOL_PREFIX: &str = "PROXY"; +const V1_PREFIX_LEN: usize = 5; +const V1_MAX_LENGTH: usize = 107; +const V1_TERMINATOR: &[u8] = b"\r\n"; + +const V2_PROTOCOL_PREFIX: &[u8] = b"\r\n\r\n\0\r\nQUIT\n"; +const V2_PREFIX_LEN: usize = 12; +const V2_MINIMUM_LEN: usize = 16; +const V2_LENGTH_INDEX: usize = 14; +const READ_BUFFER_LEN: usize = 512; +const V2_MAX_LENGTH: usize = 2048; + +pub(crate) async fn get_inbound_pp_header( + inbound: TcpStream, + config: &ProxyConfig, +) -> Result<(TcpStream, ProxyHeader)> { + if config.inbound_pp_enabled { + read_proxy_header(inbound).await + } else { + let header = create_inbound_pp_header(&inbound); + Ok((inbound, header)) + } +} + +pub struct DisplayAddr<'a>(pub &'a ProxyHeader); + +impl std::fmt::Display for DisplayAddr<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self.0 { + ProxyHeader::Version2 { addresses, .. } => match addresses { + v2::ProxyAddresses::Ipv4 { source, .. } => write!(f, "{}", source), + v2::ProxyAddresses::Ipv6 { source, .. } => write!(f, "{}", source), + v2::ProxyAddresses::Unix { .. } => write!(f, ""), + v2::ProxyAddresses::Unspec => write!(f, ""), + }, + ProxyHeader::Version1 { addresses, .. } => match addresses { + v1::ProxyAddresses::Ipv4 { source, .. } => write!(f, "{}", source), + v1::ProxyAddresses::Ipv6 { source, .. } => write!(f, "{}", source), + v1::ProxyAddresses::Unknown => write!(f, ""), + }, + _ => write!(f, ""), + } + } +} + +fn create_inbound_pp_header(inbound: &TcpStream) -> ProxyHeader { + // When PROXY protocol is disabled, create a synthetic header from the actual TCP connection + let peer_addr = inbound.peer_addr().ok(); + let local_addr = inbound.local_addr().ok(); + + match (peer_addr, local_addr) { + (Some(SocketAddr::V4(source)), Some(SocketAddr::V4(destination))) => { + ProxyHeader::Version2 { + command: v2::ProxyCommand::Proxy, + transport_protocol: v2::ProxyTransportProtocol::Stream, + addresses: v2::ProxyAddresses::Ipv4 { + source, + destination, + }, + } + } + (Some(SocketAddr::V6(source)), Some(SocketAddr::V6(destination))) => { + ProxyHeader::Version2 { + command: v2::ProxyCommand::Proxy, + transport_protocol: v2::ProxyTransportProtocol::Stream, + addresses: v2::ProxyAddresses::Ipv6 { + source, + destination, + }, + } + } + _ => ProxyHeader::Version2 { + command: v2::ProxyCommand::Proxy, + transport_protocol: v2::ProxyTransportProtocol::Stream, + addresses: v2::ProxyAddresses::Unspec, + }, + } +} + +async fn read_proxy_header(mut stream: I) -> Result<(I, ProxyHeader)> +where + I: AsyncRead + Unpin, +{ + let mut buffer = [0; READ_BUFFER_LEN]; + let mut dynamic_buffer = None; + + stream.read_exact(&mut buffer[..V1_PREFIX_LEN]).await?; + + if &buffer[..V1_PREFIX_LEN] == V1_PROTOCOL_PREFIX.as_bytes() { + read_v1_header(&mut stream, &mut buffer).await?; + } else { + stream + .read_exact(&mut buffer[V1_PREFIX_LEN..V2_MINIMUM_LEN]) + .await?; + if &buffer[..V2_PREFIX_LEN] == V2_PROTOCOL_PREFIX { + dynamic_buffer = read_v2_header(&mut stream, &mut buffer).await?; + } else { + bail!("No valid Proxy Protocol header detected"); + } + } + + let mut buffer = dynamic_buffer.as_deref().unwrap_or(&buffer[..]); + + let header = + proxy_protocol::parse(&mut buffer).context("failed to parse proxy protocol header")?; + Ok((stream, header)) +} + +async fn read_v2_header( + mut stream: I, + buffer: &mut [u8; READ_BUFFER_LEN], +) -> Result>> +where + I: AsyncRead + Unpin, +{ + let length = + u16::from_be_bytes([buffer[V2_LENGTH_INDEX], buffer[V2_LENGTH_INDEX + 1]]) as usize; + let full_length = V2_MINIMUM_LEN + length; + + if full_length > V2_MAX_LENGTH { + bail!("V2 Proxy Protocol header is too long"); + } + + if full_length > READ_BUFFER_LEN { + let mut dynamic_buffer = Vec::with_capacity(full_length); + dynamic_buffer.extend_from_slice(&buffer[..V2_MINIMUM_LEN]); + dynamic_buffer.resize(full_length, 0); + stream + .read_exact(&mut dynamic_buffer[V2_MINIMUM_LEN..full_length]) + .await?; + + Ok(Some(dynamic_buffer)) + } else { + stream + .read_exact(&mut buffer[V2_MINIMUM_LEN..full_length]) + .await?; + + Ok(None) + } +} + +async fn read_v1_header(mut stream: I, buffer: &mut [u8; READ_BUFFER_LEN]) -> Result<()> +where + I: AsyncRead + Unpin, +{ + let mut end_found = false; + for i in V1_PREFIX_LEN..V1_MAX_LENGTH { + buffer[i] = stream.read_u8().await?; + + if [buffer[i - 1], buffer[i]] == V1_TERMINATOR { + end_found = true; + break; + } + } + if !end_found { + bail!("No valid Proxy Protocol header detected"); + } + + Ok(()) +} diff --git a/gateway/src/proxy.rs b/gateway/src/proxy.rs index 75cc286e..0aec6998 100644 --- a/gateway/src/proxy.rs +++ b/gateway/src/proxy.rs @@ -20,7 +20,12 @@ use tokio::{ }; use tracing::{debug, error, info, info_span, Instrument}; -use crate::{config::ProxyConfig, main_service::Proxy, models::EnteredCounter}; +use crate::{ + config::ProxyConfig, + main_service::Proxy, + models::EnteredCounter, + pp::{get_inbound_pp_header, DisplayAddr}, +}; #[derive(Debug, Clone)] pub(crate) struct AddressInfo { @@ -70,6 +75,79 @@ struct DstInfo { port: u16, is_tls: bool, is_h2: bool, + is_pp: bool, +} + +fn parse_app_addr(addr: &str) -> Result { + let (app_id, port_part) = addr + .rsplit_once('-') + .or_else(|| addr.rsplit_once(':')) + .unwrap_or((addr, "")); + if app_id.is_empty() { + bail!("app id is empty"); + } + let mut dst = DstInfo { + app_id: app_id.to_owned(), + port: 80, + is_tls: false, + is_h2: false, + is_pp: false, + }; + + if port_part.is_empty() { + return Ok(dst); + }; + + // Parse suffixes from right to left: g, s, p + let part_bytes = port_part.as_bytes(); + let mut end_idx = part_bytes.len(); + + // Parse from right to left until we hit a digit + while end_idx > 0 { + let ch = part_bytes[end_idx - 1] as char; + match ch { + c if c.is_ascii_digit() => { + break; + } + 'g' => { + if dst.is_h2 { + bail!("invalid app address: duplicate suffix 'g'"); + } + dst.is_h2 = true; + end_idx -= 1; + } + 's' => { + if dst.is_tls { + bail!("invalid app address: duplicate suffix 's'"); + } + dst.is_tls = true; + end_idx -= 1; + } + 'p' => { + if dst.is_pp { + bail!("invalid app address: duplicate suffix 'p'"); + } + dst.is_pp = true; + end_idx -= 1; + } + _ => { + bail!("invalid app address: unrecognized suffix character '{ch}'"); + } + } + } + + if dst.is_h2 && dst.is_tls { + bail!("invalid app address: both 's' and 'g' suffixes are present"); + } + + let port_str = &port_part[..end_idx]; + let port = if port_str.is_empty() { + None + } else { + Some(port_str.parse::().context("invalid port")?) + }; + dst.port = port.unwrap_or(if dst.is_tls { 443 } else { 80 }); + Ok(dst) } fn parse_destination(sni: &str, dotted_base_domain: &str) -> Result { @@ -80,63 +158,25 @@ fn parse_destination(sni: &str, dotted_base_domain: &str) -> Result { if subdomain.contains('.') { bail!("only one level of subdomain is supported, got sni={sni}, subdomain={subdomain}"); } - let mut parts = subdomain.split('-'); - let app_id = parts.next().context("no app id found")?.to_owned(); - if app_id.is_empty() { - bail!("app id is empty"); - } - let last_part = parts.next(); - let is_tls; - let port; - let is_h2; - match last_part { - None => { - is_tls = false; - is_h2 = false; - port = None; - } - Some(last_part) => { - let (port_str, has_g) = match last_part.strip_suffix('g') { - Some(without_g) => (without_g, true), - None => (last_part, false), - }; - - let (port_str, has_s) = match port_str.strip_suffix('s') { - Some(without_s) => (without_s, true), - None => (port_str, false), - }; - if has_g && has_s { - bail!("invalid sni format: `gs` is not allowed"); - } - is_h2 = has_g; - is_tls = has_s; - port = if port_str.is_empty() { - None - } else { - Some(port_str.parse::().context("invalid port")?) - }; - } - }; - let port = port.unwrap_or(if is_tls { 443 } else { 80 }); - if parts.next().is_some() { - bail!("invalid sni format"); - } - Ok(DstInfo { - app_id, - port, - is_tls, - is_h2, - }) + parse_app_addr(subdomain) } pub static NUM_CONNECTIONS: AtomicU64 = AtomicU64::new(0); async fn handle_connection( - mut inbound: TcpStream, + inbound: TcpStream, state: Proxy, dotted_base_domain: &str, ) -> Result<()> { let timeouts = &state.config.proxy.timeouts; + + let pp_timeout = timeouts.pp_header; + let pp_fut = get_inbound_pp_header(inbound, &state.config.proxy); + let (mut inbound, pp_header) = timeout(pp_timeout, pp_fut) + .await + .context("take proxy protocol header timeout")? + .context("failed to take proxy protocol header")?; + info!("client address: {}", DisplayAddr(&pp_header)); let (sni, buffer) = timeout(timeouts.handshake, take_sni(&mut inbound)) .await .context("take sni timeout")? @@ -148,14 +188,12 @@ async fn handle_connection( let dst = parse_destination(&sni, dotted_base_domain)?; debug!("dst: {dst:?}"); if dst.is_tls { - tls_passthough::proxy_to_app(state, inbound, buffer, &dst.app_id, dst.port).await + tls_passthough::proxy_to_app(state, inbound, pp_header, buffer, &dst).await } else { - state - .proxy(inbound, buffer, &dst.app_id, dst.port, dst.is_h2) - .await + state.proxy(inbound, pp_header, buffer, &dst).await } } else { - tls_passthough::proxy_with_sni(state, inbound, buffer, &sni).await + tls_passthough::proxy_with_sni(state, inbound, pp_header, buffer, &sni).await } } diff --git a/gateway/src/proxy/tls_passthough.rs b/gateway/src/proxy/tls_passthough.rs index e2cea9d0..c2004250 100644 --- a/gateway/src/proxy/tls_passthough.rs +++ b/gateway/src/proxy/tls_passthough.rs @@ -3,37 +3,24 @@ // SPDX-License-Identifier: Apache-2.0 use anyhow::{Context, Result}; -use std::fmt::Debug; +use proxy_protocol::ProxyHeader; use tokio::{io::AsyncWriteExt, net::TcpStream, task::JoinSet, time::timeout}; use tracing::{debug, info}; use crate::{ main_service::Proxy, models::{Counting, EnteredCounter}, + proxy::{parse_app_addr, DstInfo}, }; use super::{io_bridge::bridge, AddressGroup}; -#[derive(Debug)] -struct AppAddress { - app_id: String, - port: u16, -} - -impl AppAddress { - fn parse(data: &[u8]) -> Result { - // format: "3327603e03f5bd1f830812ca4a789277fc31f577:555" - let data = String::from_utf8(data.to_vec()).context("invalid app address")?; - let (app_id, port) = data.split_once(':').context("invalid app address")?; - Ok(Self { - app_id: app_id.to_string(), - port: port.parse().context("invalid port")?, - }) - } +fn parse_txt_addr(addr: &[u8]) -> Result { + parse_app_addr(core::str::from_utf8(addr).context("invalid app address")?) } /// resolve app address by sni -async fn resolve_app_address(prefix: &str, sni: &str, compat: bool) -> Result { +async fn resolve_app_address(prefix: &str, sni: &str, compat: bool) -> Result { let txt_domain = format!("{prefix}.{sni}"); let resolver = hickory_resolver::AsyncResolver::tokio_from_system_conf() .context("failed to create dns resolver")?; @@ -54,7 +41,7 @@ async fn resolve_app_address(prefix: &str, sni: &str, compat: bool) -> Result Result, sni: &str, ) -> Result<()> { @@ -83,7 +71,7 @@ pub(crate) async fn proxy_with_sni( .await .context("failed to resolve app address")?; debug!("target address is {}:{}", addr.app_id, addr.port); - proxy_to_app(state, inbound, buffer, &addr.app_id, addr.port).await + proxy_to_app(state, inbound, pp_header, buffer, &addr).await } /// connect to multiple hosts simultaneously and return the first successful connection @@ -120,10 +108,12 @@ pub(crate) async fn connect_multiple_hosts( pub(crate) async fn proxy_to_app( state: Proxy, inbound: TcpStream, + pp_header: ProxyHeader, buffer: Vec, - app_id: &str, - port: u16, + dst: &DstInfo, ) -> Result<()> { + let app_id = &dst.app_id; + let port = dst.port; let addresses = state.lock().select_top_n_hosts(app_id)?; let (mut outbound, _counter) = timeout( state.config.proxy.timeouts.connect, @@ -132,6 +122,12 @@ pub(crate) async fn proxy_to_app( .await .with_context(|| format!("connecting timeout to app {app_id}: {addresses:?}:{port}"))? .with_context(|| format!("failed to connect to app {app_id}: {addresses:?}:{port}"))?; + + if dst.is_pp { + let pp_header_bin = + proxy_protocol::encode(pp_header).context("failed to encode pp header")?; + outbound.write_all(&pp_header_bin).await?; + } outbound .write_all(&buffer) .await diff --git a/gateway/src/proxy/tls_terminate.rs b/gateway/src/proxy/tls_terminate.rs index 9c159492..96ec5c8c 100644 --- a/gateway/src/proxy/tls_terminate.rs +++ b/gateway/src/proxy/tls_terminate.rs @@ -14,11 +14,12 @@ use hyper::server::conn::http1; use hyper::service::service_fn; use hyper::{Request, Response, StatusCode}; use hyper_util::rt::tokio::TokioIo; +use proxy_protocol::ProxyHeader; use rustls::pki_types::pem::PemObject; use rustls::pki_types::{CertificateDer, PrivateKeyDer}; use rustls::version::{TLS12, TLS13}; use serde::Serialize; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt as _, ReadBuf}; use tokio::net::TcpStream; use tokio::time::timeout; use tokio_rustls::{rustls, server::TlsStream, TlsAcceptor}; @@ -26,6 +27,7 @@ use tracing::{debug, info}; use crate::config::{CryptoProvider, ProxyConfig, TlsVersion}; use crate::main_service::Proxy; +use crate::proxy::DstInfo; use super::io_bridge::bridge; use super::tls_passthough::connect_multiple_hosts; @@ -296,14 +298,16 @@ impl Proxy { Ok(tls_stream) } - pub(crate) async fn proxy( + pub(super) async fn proxy( &self, inbound: TcpStream, + pp_header: ProxyHeader, buffer: Vec, - app_id: &str, - port: u16, - h2: bool, + dst: &DstInfo, ) -> Result<()> { + let app_id = &dst.app_id; + let port = dst.port; + let h2 = dst.is_h2; if app_id == "health" { return self.handle_health_check(inbound, buffer, port, h2).await; } @@ -316,13 +320,19 @@ impl Proxy { .with_context(|| format!("app {app_id} not found"))?; debug!("selected top n hosts: {addresses:?}"); let tls_stream = self.tls_accept(inbound, buffer, h2).await?; - let (outbound, _counter) = timeout( + let (mut outbound, _counter) = timeout( self.config.proxy.timeouts.connect, connect_multiple_hosts(addresses, port), ) .await .map_err(|_| anyhow!("connecting timeout"))? .context("failed to connect to app")?; + if dst.is_pp { + debug!("sending pp header: {pp_header:?}"); + let pp_header_bin = + proxy_protocol::encode(pp_header).context("failed to encode pp header")?; + outbound.write_all(&pp_header_bin).await?; + } bridge( IgnoreUnexpectedEofStream::new(tls_stream), outbound,