Skip to main content

moqtap_proxy/
session.rs

1//! Per-connection proxy session — forwards streams between client and relay.
2
3use std::sync::Arc;
4
5use bytes::{Bytes, BytesMut};
6use tokio::task::JoinSet;
7use tokio_util::sync::CancellationToken;
8
9use moqtap_client::transport::quic::QuicTransport;
10use moqtap_client::transport::{RecvStream, SendStream, Transport};
11use moqtap_codec::dispatch::AnyDatagramHeader;
12use moqtap_codec::varint::VarInt;
13use moqtap_codec::version::DraftVersion;
14
15use crate::error::ProxyError;
16use crate::event::{ProxyEvent, ProxySide, SessionId};
17use crate::hook::ProxyHook;
18use crate::observer::ProxyObserver;
19use crate::parser::control::{ControlStreamParser, ParseResult};
20use crate::parser::data::{DataStreamParser, DataStreamType};
21
/// Selects how the proxy dials the upstream relay.
///
/// `Quic` dials `ProxySessionConfig::upstream_addr` directly, while
/// `WebTransport` dials the URL carried in the variant.
#[derive(Debug, Clone)]
pub enum UpstreamTransportType {
    /// Raw QUIC — the relay is reached at `upstream_addr` (`host:port`).
    Quic,
    /// WebTransport — the relay is reached at `url`.
    WebTransport {
        /// Full WebTransport endpoint URL, e.g. `https://host:port/path`.
        url: String,
    },
}
33
34/// Configuration for a proxy session's upstream connection.
35pub struct ProxySessionConfig {
36    /// The MoQT draft version to use for parsing.
37    pub draft: DraftVersion,
38    /// The transport type to use for the upstream connection.
39    pub upstream_transport: UpstreamTransportType,
40    /// Upstream relay address (e.g., `"192.168.1.10:4443"` for QUIC).
41    pub upstream_addr: String,
42    /// Whether to skip TLS verification for the upstream connection.
43    pub skip_upstream_cert_verify: bool,
44    /// Custom CA certificates for the upstream connection (DER-encoded).
45    pub upstream_ca_certs: Vec<Vec<u8>>,
46    /// Timeout in seconds for the upstream connection attempt. 0 means no timeout.
47    pub upstream_connect_timeout_secs: u64,
48}
49
50impl ProxySessionConfig {
51    /// Returns the ALPN protocol identifiers for the upstream connection.
52    ///
53    /// For QUIC upstreams, mirrors the negotiated client ALPN so we connect
54    /// to the relay with the same protocol the client is speaking. Falls
55    /// back to `self.draft.quic_alpn()` if the client ALPN is empty
56    /// (e.g., the listener didn't capture it).
57    pub fn upstream_alpn(&self, client_alpn: &[u8]) -> Vec<Vec<u8>> {
58        match &self.upstream_transport {
59            UpstreamTransportType::Quic => {
60                if client_alpn.is_empty() {
61                    vec![self.draft.quic_alpn().to_vec()]
62                } else {
63                    vec![client_alpn.to_vec()]
64                }
65            }
66            UpstreamTransportType::WebTransport { .. } => vec![b"h3".to_vec()],
67        }
68    }
69}
70
71impl Default for ProxySessionConfig {
72    fn default() -> Self {
73        Self {
74            draft: DraftVersion::Draft14,
75            upstream_transport: UpstreamTransportType::Quic,
76            upstream_addr: String::new(),
77            skip_upstream_cert_verify: false,
78            upstream_ca_certs: Vec::new(),
79            upstream_connect_timeout_secs: 0,
80        }
81    }
82}
83
84/// A proxy session that forwards traffic between a client and an upstream
85/// relay. One session is created per accepted client connection.
86pub struct ProxySession {
87    session_id: SessionId,
88    config: ProxySessionConfig,
89    /// The ALPN the client negotiated with us (empty for WebTransport or
90    /// when unavailable). Drives both upstream ALPN selection and initial
91    /// draft detection for drafts 15+.
92    client_alpn: Vec<u8>,
93    observer: Arc<dyn ProxyObserver>,
94    hook: Arc<dyn ProxyHook>,
95    cancel: CancellationToken,
96}
97
98impl ProxySession {
99    /// Create a new proxy session.
100    ///
101    /// `client_alpn` should be the ALPN the listener negotiated with the
102    /// client. Pass an empty slice if unavailable (e.g., WebTransport).
103    pub fn new(
104        session_id: SessionId,
105        config: ProxySessionConfig,
106        client_alpn: Vec<u8>,
107        observer: Arc<dyn ProxyObserver>,
108        hook: Arc<dyn ProxyHook>,
109        cancel: CancellationToken,
110    ) -> Self {
111        Self { session_id, config, client_alpn, observer, hook, cancel }
112    }
113
114    /// Run the proxy session with a raw QUIC client connection.
115    pub async fn run(self, client_conn: quinn::Connection) -> Result<(), ProxyError> {
116        let client = Transport::Quic(QuicTransport::new(client_conn));
117        self.run_with_transport(client).await
118    }
119
120    /// Run the proxy session with a WebTransport client connection.
121    #[cfg(feature = "webtransport")]
122    pub async fn run_webtransport(
123        self,
124        client_conn: wtransport::Connection,
125    ) -> Result<(), ProxyError> {
126        use moqtap_client::transport::webtransport::WebTransportTransport;
127        let client = Transport::WebTransport(WebTransportTransport::new(client_conn));
128        self.run_with_transport(client).await
129    }
130
131    /// The initial draft hint for this session. Drafts 15+ resolve
132    /// unambiguously from the client ALPN (`moqt-15` / `moqt-16` /
133    /// `moqt-17`); otherwise we fall back to `config.draft`, which the
134    /// control-stream parser may refine once it peeks at CLIENT_SETUP /
135    /// SERVER_SETUP for the moq-00 cohort (drafts 07–14).
136    fn initial_draft(&self) -> DraftVersion {
137        DraftVersion::from_alpn(&self.client_alpn).unwrap_or(self.config.draft)
138    }
139
140    /// Whether the initial draft is fixed (ALPN-derived) or should be
141    /// refined from CLIENT_SETUP / SERVER_SETUP peek.
142    fn draft_is_fixed(&self) -> bool {
143        DraftVersion::from_alpn(&self.client_alpn).is_some()
144    }
145
146    /// Run the proxy session with an already-wrapped transport.
147    ///
148    /// Connects to the upstream relay, then forwards all streams and
149    /// datagrams bidirectionally between the client and relay. Parses
150    /// MoQT frames inline and emits events via the observer.
151    async fn run_with_transport(self, client: Transport) -> Result<(), ProxyError> {
152        // Connect to upstream relay
153        let relay = self.connect_upstream().await?;
154
155        let client = Arc::new(client);
156        let relay = Arc::new(relay);
157
158        let mut tasks: JoinSet<Result<(), ProxyError>> = JoinSet::new();
159
160        let initial_draft = self.initial_draft();
161        let draft_is_fixed = self.draft_is_fixed();
162
163        let observer_enabled = self.observer.wants_events();
164        let control_mutation = self.hook.wants_control_mutation();
165        let base_ctx = ForwardCtx {
166            session_id: self.session_id,
167            draft: initial_draft,
168            draft_is_fixed,
169            observer: Arc::clone(&self.observer),
170            hook: Arc::clone(&self.hook),
171            cancel: self.cancel.clone(),
172            observer_enabled,
173            control_mutation,
174        };
175
176        // Control stream: accept bi from client, open bi to relay
177        {
178            let client = Arc::clone(&client);
179            let relay = Arc::clone(&relay);
180            let ctx = base_ctx.clone();
181            tasks.spawn(async move { forward_control_stream(&client, &relay, &ctx).await });
182        }
183
184        // Client → Relay uni streams
185        {
186            let client = Arc::clone(&client);
187            let relay = Arc::clone(&relay);
188            let ctx = base_ctx.clone();
189            tasks.spawn(async move {
190                forward_uni_streams(&client, &relay, ProxySide::ClientToProxy, &ctx).await
191            });
192        }
193
194        // Relay → Client uni streams
195        {
196            let client = Arc::clone(&client);
197            let relay = Arc::clone(&relay);
198            let ctx = base_ctx.clone();
199            tasks.spawn(async move {
200                forward_uni_streams(&relay, &client, ProxySide::RelayToProxy, &ctx).await
201            });
202        }
203
204        // Datagram forwarding: client → relay
205        {
206            let client = Arc::clone(&client);
207            let relay = Arc::clone(&relay);
208            let ctx = base_ctx.clone();
209            tasks.spawn(async move {
210                forward_datagrams(&client, &relay, ProxySide::ClientToProxy, &ctx).await
211            });
212        }
213
214        // Datagram forwarding: relay → client
215        {
216            let client = Arc::clone(&client);
217            let relay = Arc::clone(&relay);
218            let ctx = base_ctx.clone();
219            tasks.spawn(async move {
220                forward_datagrams(&relay, &client, ProxySide::RelayToProxy, &ctx).await
221            });
222        }
223
224        // Wait for first task to finish (signals session is done)
225        let first_result = tasks.join_next().await;
226
227        // Cancel remaining tasks
228        self.cancel.cancel();
229        tasks.shutdown().await;
230
231        // Determine reason from first result
232        let reason = match &first_result {
233            Some(Ok(Ok(()))) => "completed".to_string(),
234            Some(Ok(Err(e))) => format!("{e}"),
235            Some(Err(e)) => format!("task panic: {e}"),
236            None => "no tasks".to_string(),
237        };
238        if self.observer.wants_events() {
239            self.observer
240                .on_event(&ProxyEvent::SessionEnded { session_id: self.session_id, reason });
241        }
242
243        // Close both sides
244        client.close(0, b"proxy session ended");
245        relay.close(0, b"proxy session ended");
246
247        match first_result {
248            Some(Ok(Ok(()))) | None => Ok(()),
249            Some(Ok(Err(e))) => Err(e),
250            Some(Err(e)) => Err(ProxyError::SessionClosed(format!("task panic: {e}"))),
251        }
252    }
253
254    /// Connect to the upstream relay (with optional timeout).
255    async fn connect_upstream(&self) -> Result<Transport, ProxyError> {
256        let timeout_secs = self.config.upstream_connect_timeout_secs;
257        if timeout_secs > 0 {
258            tokio::time::timeout(
259                std::time::Duration::from_secs(timeout_secs),
260                self.connect_upstream_inner(),
261            )
262            .await
263            .map_err(|_| {
264                ProxyError::UpstreamConnect(format!("connection timed out after {timeout_secs}s"))
265            })?
266        } else {
267            self.connect_upstream_inner().await
268        }
269    }
270
271    async fn connect_upstream_inner(&self) -> Result<Transport, ProxyError> {
272        match &self.config.upstream_transport {
273            UpstreamTransportType::Quic => self.connect_upstream_quic().await,
274            #[cfg(feature = "webtransport")]
275            UpstreamTransportType::WebTransport { url } => {
276                let url = url.clone();
277                self.connect_upstream_webtransport(&url).await
278            }
279            #[cfg(not(feature = "webtransport"))]
280            UpstreamTransportType::WebTransport { .. } => {
281                Err(ProxyError::UpstreamConnect("webtransport feature not enabled".to_string()))
282            }
283        }
284    }
285
286    /// Connect to the upstream relay via QUIC.
287    async fn connect_upstream_quic(&self) -> Result<Transport, ProxyError> {
288        let server_addr =
289            self.config.upstream_addr.parse().map_err(|e: std::net::AddrParseError| {
290                ProxyError::UpstreamConnect(e.to_string())
291            })?;
292
293        let mut tls_config = self.build_upstream_tls_config()?;
294        tls_config.alpn_protocols = self.config.upstream_alpn(&self.client_alpn);
295
296        let quic_config: quinn::crypto::rustls::QuicClientConfig =
297            tls_config.try_into().map_err(|e| ProxyError::TlsConfig(format!("{e}")))?;
298        let client_config = quinn::ClientConfig::new(Arc::new(quic_config));
299
300        let mut endpoint = quinn::Endpoint::client("0.0.0.0:0".parse().unwrap())
301            .map_err(|e| ProxyError::UpstreamConnect(e.to_string()))?;
302        endpoint.set_default_client_config(client_config);
303
304        let server_name =
305            self.config.upstream_addr.split(':').next().unwrap_or("localhost").to_string();
306
307        let conn = endpoint
308            .connect(server_addr, &server_name)
309            .map_err(|e| ProxyError::UpstreamConnect(e.to_string()))?
310            .await
311            .map_err(|e| ProxyError::UpstreamConnect(e.to_string()))?;
312
313        Ok(Transport::Quic(QuicTransport::new(conn)))
314    }
315
316    /// Connect to the upstream relay via WebTransport.
317    #[cfg(feature = "webtransport")]
318    async fn connect_upstream_webtransport(&self, url: &str) -> Result<Transport, ProxyError> {
319        use moqtap_client::transport::webtransport::WebTransportTransport;
320
321        let wt_config = if self.config.skip_upstream_cert_verify {
322            wtransport::ClientConfig::builder()
323                .with_bind_default()
324                .with_no_cert_validation()
325                .build()
326        } else {
327            wtransport::ClientConfig::builder().with_bind_default().with_native_certs().build()
328        };
329
330        let endpoint = wtransport::Endpoint::client(wt_config)
331            .map_err(|e| ProxyError::UpstreamConnect(e.to_string()))?;
332
333        let connection =
334            endpoint.connect(url).await.map_err(|e| ProxyError::UpstreamConnect(e.to_string()))?;
335
336        Ok(Transport::WebTransport(WebTransportTransport::new(connection)))
337    }
338
339    /// Build a rustls `ClientConfig` for the upstream connection.
340    fn build_upstream_tls_config(&self) -> Result<rustls::ClientConfig, ProxyError> {
341        if self.config.skip_upstream_cert_verify {
342            Ok(rustls::ClientConfig::builder()
343                .dangerous()
344                .with_custom_certificate_verifier(Arc::new(SkipVerification))
345                .with_no_client_auth())
346        } else {
347            let mut roots = rustls::RootCertStore::empty();
348            roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
349            for der in &self.config.upstream_ca_certs {
350                roots
351                    .add(rustls::pki_types::CertificateDer::from(der.clone()))
352                    .map_err(|e| ProxyError::TlsConfig(format!("bad CA cert: {e}")))?;
353            }
354            Ok(rustls::ClientConfig::builder().with_root_certificates(roots).with_no_client_auth())
355        }
356    }
357}
358
359// ── Forwarding helpers ──────────────────────────────────────────
360
361/// Shared context for forwarding helpers, avoiding repeated parameter lists.
362#[derive(Clone)]
363struct ForwardCtx {
364    session_id: SessionId,
365    /// The current best draft guess for this session. For control streams
366    /// this may be refined after observing CLIENT_SETUP / SERVER_SETUP when
367    /// `draft_is_fixed` is false.
368    draft: DraftVersion,
369    /// Whether `draft` is fixed (from ALPN) and should not be refined by
370    /// peeking at SETUP messages.
371    draft_is_fixed: bool,
372    observer: Arc<dyn ProxyObserver>,
373    hook: Arc<dyn ProxyHook>,
374    cancel: CancellationToken,
375    /// Cached `observer.wants_events()` — gates event construction and
376    /// emission in the hot forwarding loop. When `false`, the proxy can
377    /// skip parsing for observation purposes and run as a byte pump.
378    observer_enabled: bool,
379    /// Cached `hook.wants_control_mutation()` — when `true`, the control
380    /// stream forwarder switches to a parse-then-forward mode that honors
381    /// the hook's `Some(bytes)` return. Defaults to `false` (pure
382    /// pass-through) to avoid per-frame parsing latency.
383    control_mutation: bool,
384}
385
386impl ForwardCtx {
387    /// Emit a proxy event only if the observer wants events.
388    ///
389    /// Takes a closure so the `ProxyEvent` is not constructed when
390    /// observation is disabled — avoiding clones of message payloads in
391    /// the hot path.
392    fn emit(&self, event: impl FnOnce() -> ProxyEvent) {
393        if self.observer_enabled {
394            self.observer.on_event(&event());
395        }
396    }
397}
398
399/// Forward the control stream (first bidirectional stream).
400async fn forward_control_stream(
401    client: &Transport,
402    relay: &Transport,
403    ctx: &ForwardCtx,
404) -> Result<(), ProxyError> {
405    // Accept bi from client
406    let (client_send, client_recv) = client.accept_bi().await?;
407    ctx.emit(|| ProxyEvent::BiStreamOpened {
408        session_id: ctx.session_id,
409        side: ProxySide::ClientToProxy,
410    });
411
412    // Open bi to relay
413    let (relay_send, relay_recv) = relay.open_bi().await?;
414    ctx.emit(|| ProxyEvent::BiStreamOpened {
415        session_id: ctx.session_id,
416        side: ProxySide::ProxyToRelay,
417    });
418
419    // Pipe client→relay and relay→client concurrently
420    let ctx1 = ForwardCtx { ..ctx.clone() };
421    let ctx2 = ForwardCtx { ..ctx.clone() };
422
423    let client_to_relay = tokio::spawn(async move {
424        pipe_control(client_recv, relay_send, ProxySide::ClientToProxy, &ctx1).await
425    });
426
427    let relay_to_client = tokio::spawn(async move {
428        pipe_control(relay_recv, client_send, ProxySide::RelayToProxy, &ctx2).await
429    });
430
431    tokio::select! {
432        r = client_to_relay => r.map_err(|e| ProxyError::SessionClosed(e.to_string()))?,
433        r = relay_to_client => r.map_err(|e| ProxyError::SessionClosed(e.to_string()))?,
434        _ = ctx.cancel.cancelled() => Ok(()),
435    }
436}
437
438/// Pipe a control stream direction.
439///
440/// Bytes are forwarded to the peer immediately upon receipt — the parser
441/// runs on a cloned copy purely to emit observer events. A stuck or
442/// erroring parser can never block forwarding. This matches the
443/// pass-through semantics of the data-stream and datagram paths.
444///
445/// If `ctx.draft_is_fixed` is false (moq-00 cohort, drafts 07–14), the
446/// parser start is deferred until enough bytes arrive to peek the first
447/// SETUP message and pick a concrete draft. Bytes observed during that
448/// detection window are still forwarded immediately.
449async fn pipe_control(
450    recv: RecvStream,
451    send: SendStream,
452    side: ProxySide,
453    ctx: &ForwardCtx,
454) -> Result<(), ProxyError> {
455    if ctx.control_mutation {
456        pipe_control_mutating(recv, send, side, ctx).await
457    } else {
458        pipe_control_passthrough(recv, send, side, ctx).await
459    }
460}
461
462/// Forward-first control stream pipe.
463///
464/// Bytes are forwarded to the peer the instant they arrive; the parser
465/// runs on a cloned copy purely to drive observer events. No hook can
466/// rewrite frames on this path because the bytes are already in flight.
467async fn pipe_control_passthrough(
468    mut recv: RecvStream,
469    mut send: SendStream,
470    side: ProxySide,
471    ctx: &ForwardCtx,
472) -> Result<(), ProxyError> {
473    let mut buf = [0u8; 8192];
474
475    let mut parser: Option<ControlStreamParser> =
476        if ctx.draft_is_fixed { Some(ControlStreamParser::new(ctx.draft)) } else { None };
477
478    const DETECT_BUF_MAX: usize = 64 * 1024;
479    let mut detect_buf = BytesMut::new();
480
481    loop {
482        tokio::select! {
483            result = recv.read(&mut buf) => {
484                match result? {
485                    Some(n) => {
486                        let data = &buf[..n];
487
488                        // ── Forward immediately — no gating on parse ────
489                        send.write_all(data).await?;
490
491                        // ── Observer-only parse (side path) ─────────────
492                        // Skip parsing when nobody is observing: the proxy
493                        // becomes a pure byte pump on the control stream.
494                        if ctx.observer_enabled {
495                            if let Some(ref mut p) = parser {
496                                emit_parsed_frames(p, data, side, ctx);
497                            } else {
498                                detect_buf.extend_from_slice(data);
499                                if let Some(detected) = detect_draft_from_setup(
500                                    &detect_buf,
501                                    side,
502                                    ctx.draft,
503                                ) {
504                                    let mut p = ControlStreamParser::new(detected);
505                                    let buffered = detect_buf.split().freeze();
506                                    emit_parsed_frames(&mut p, &buffered, side, ctx);
507                                    parser = Some(p);
508                                } else if detect_buf.len() >= DETECT_BUF_MAX {
509                                    ctx.emit(|| ProxyEvent::ParseError {
510                                        session_id: ctx.session_id,
511                                        side,
512                                        error: format!(
513                                            "control draft detection gave up after {} bytes; \
514                                             falling back to {}",
515                                            detect_buf.len(),
516                                            ctx.draft,
517                                        ),
518                                    });
519                                    parser = Some(ControlStreamParser::new(ctx.draft));
520                                    detect_buf = BytesMut::new();
521                                }
522                            }
523                        }
524                    }
525                    None => {
526                        ctx.emit(|| ProxyEvent::StreamClosed {
527                            session_id: ctx.session_id,
528                            side,
529                        });
530                        let _ = send.finish();
531                        return Ok(());
532                    }
533                }
534            }
535            _ = ctx.cancel.cancelled() => {
536                return Ok(());
537            }
538        }
539    }
540}
541
542/// Parse-then-forward control stream pipe.
543///
544/// Bytes are withheld until a complete control message has been parsed,
545/// at which point the hook's `on_control_message` is consulted and its
546/// `Some(bytes)` return value (if any) is forwarded in place of the
547/// original wire frame. This adds a per-frame latency cost — a hook that
548/// only observes should leave `wants_control_mutation()` at its default
549/// `false` and take the pass-through path instead.
550async fn pipe_control_mutating(
551    mut recv: RecvStream,
552    mut send: SendStream,
553    side: ProxySide,
554    ctx: &ForwardCtx,
555) -> Result<(), ProxyError> {
556    let mut buf = [0u8; 8192];
557
558    // Capturing parser — we need the original raw bytes so the hook can
559    // choose to pass them through unchanged.
560    let mut parser: Option<ControlStreamParser> =
561        if ctx.draft_is_fixed { Some(ControlStreamParser::new_capturing(ctx.draft)) } else { None };
562
563    const DETECT_BUF_MAX: usize = 64 * 1024;
564    let mut detect_buf = BytesMut::new();
565
566    loop {
567        tokio::select! {
568            result = recv.read(&mut buf) => {
569                match result? {
570                    Some(n) => {
571                        let data = &buf[..n];
572
573                        match parser.as_mut() {
574                            Some(p) => {
575                                forward_mutated_frames(p, data, &mut send, side, ctx).await?;
576                            }
577                            None => {
578                                detect_buf.extend_from_slice(data);
579                                let new_parser = if let Some(detected) = detect_draft_from_setup(
580                                    &detect_buf,
581                                    side,
582                                    ctx.draft,
583                                ) {
584                                    Some(ControlStreamParser::new_capturing(detected))
585                                } else if detect_buf.len() >= DETECT_BUF_MAX {
586                                    ctx.emit(|| ProxyEvent::ParseError {
587                                        session_id: ctx.session_id,
588                                        side,
589                                        error: format!(
590                                            "control draft detection gave up after {} bytes; \
591                                             falling back to {}",
592                                            detect_buf.len(),
593                                            ctx.draft,
594                                        ),
595                                    });
596                                    Some(ControlStreamParser::new_capturing(ctx.draft))
597                                } else {
598                                    // Still detecting; nothing to forward yet.
599                                    None
600                                };
601
602                                if let Some(mut p) = new_parser {
603                                    let buffered = detect_buf.split().freeze();
604                                    forward_mutated_frames(
605                                        &mut p,
606                                        &buffered,
607                                        &mut send,
608                                        side,
609                                        ctx,
610                                    )
611                                    .await?;
612                                    parser = Some(p);
613                                }
614                            }
615                        }
616                    }
617                    None => {
618                        ctx.emit(|| ProxyEvent::StreamClosed {
619                            session_id: ctx.session_id,
620                            side,
621                        });
622                        let _ = send.finish();
623                        return Ok(());
624                    }
625                }
626            }
627            _ = ctx.cancel.cancelled() => {
628                return Ok(());
629            }
630        }
631    }
632}
633
634/// Feed bytes into the capturing control parser, then forward each
635/// completed frame — giving the hook the chance to rewrite it.
636async fn forward_mutated_frames(
637    parser: &mut ControlStreamParser,
638    data: &[u8],
639    send: &mut SendStream,
640    side: ProxySide,
641    ctx: &ForwardCtx,
642) -> Result<(), ProxyError> {
643    let result = parser.feed(data);
644    if let ParseResult::Messages(frames) = result {
645        for frame in frames {
646            let raw = frame.raw_bytes.as_deref().expect("capturing parser must populate raw_bytes");
647            let replacement =
648                ctx.hook.on_control_message(ctx.session_id, side, &frame.message, raw);
649            let out: &[u8] = match replacement.as_deref() {
650                Some(bytes) => bytes,
651                None => raw,
652            };
653            send.write_all(out).await?;
654
655            if ctx.observer_enabled {
656                let event = if frame.message.is_setup() {
657                    ProxyEvent::SetupMessage {
658                        session_id: ctx.session_id,
659                        side,
660                        message: frame.message.clone(),
661                    }
662                } else {
663                    ProxyEvent::ControlMessage {
664                        session_id: ctx.session_id,
665                        side,
666                        message: frame.message.clone(),
667                    }
668                };
669                ctx.observer.on_event(&event);
670            }
671        }
672    }
673    Ok(())
674}
675
676/// Feed bytes to the control parser and emit observer events for any
677/// completed frames.
678///
679/// The hook is deliberately not invoked here — on the pass-through path
680/// the bytes have already been forwarded, so a mutating `Some(bytes)`
681/// return would be silently discarded. Hooks that need to see control
682/// messages without rewriting them should be implemented as a
683/// [`ProxyObserver`]; hooks that need to rewrite them should set
684/// `wants_control_mutation()` to `true`, which routes traffic through
685/// `pipe_control_mutating` instead.
686fn emit_parsed_frames(
687    parser: &mut ControlStreamParser,
688    data: &[u8],
689    side: ProxySide,
690    ctx: &ForwardCtx,
691) {
692    if !ctx.observer_enabled {
693        return;
694    }
695    match parser.feed(data) {
696        ParseResult::Messages(frames) => {
697            for frame in &frames {
698                let event = if frame.message.is_setup() {
699                    ProxyEvent::SetupMessage {
700                        session_id: ctx.session_id,
701                        side,
702                        message: frame.message.clone(),
703                    }
704                } else {
705                    ProxyEvent::ControlMessage {
706                        session_id: ctx.session_id,
707                        side,
708                        message: frame.message.clone(),
709                    }
710                };
711                ctx.observer.on_event(&event);
712            }
713        }
714        ParseResult::NeedMore => {}
715    }
716}
717
718/// Forward unidirectional streams from source to destination.
719async fn forward_uni_streams(
720    source: &Transport,
721    dest: &Transport,
722    side: ProxySide,
723    ctx: &ForwardCtx,
724) -> Result<(), ProxyError> {
725    loop {
726        tokio::select! {
727            result = source.accept_uni() => {
728                let recv = result?;
729                ctx.emit(|| ProxyEvent::UniStreamOpened {
730                    session_id: ctx.session_id,
731                    side,
732                });
733
734                let send = dest.open_uni().await?;
735                let ctx = ctx.clone();
736
737                tokio::spawn(async move {
738                    if let Err(e) = pipe_data(
739                        recv, send, side, &ctx,
740                    )
741                    .await
742                    {
743                        ctx.emit(|| ProxyEvent::ParseError {
744                            session_id: ctx.session_id,
745                            side,
746                            error: format!("uni stream pipe: {e}"),
747                        });
748                    }
749                });
750            }
751            _ = ctx.cancel.cancelled() => {
752                return Ok(());
753            }
754        }
755    }
756}
757
758/// Determine the data stream type from the first varint on the stream.
759///
760/// MoQT data streams start with a stream type varint:
761/// - 0x04 = Subgroup
762/// - 0x05 = Fetch
763///
764/// Returns `(stream_type, bytes_consumed)` so the caller can account for
765/// the type varint when setting up the parser.
766fn detect_stream_type(first_byte: u8) -> DataStreamType {
767    // The stream type varint is a single byte for values < 64.
768    // Subgroup = 0x04, Fetch = 0x05.
769    match first_byte {
770        0x05 => DataStreamType::Fetch,
771        // Default to Subgroup for 0x04 and anything else
772        _ => DataStreamType::Subgroup,
773    }
774}
775
776/// Pipe a unidirectional data stream with inline parsing.
777async fn pipe_data(
778    mut recv: RecvStream,
779    mut send: SendStream,
780    side: ProxySide,
781    ctx: &ForwardCtx,
782) -> Result<(), ProxyError> {
783    let mut buf = [0u8; 8192];
784    let mut parser: Option<DataStreamParser> = None;
785
786    loop {
787        tokio::select! {
788            result = recv.read(&mut buf) => {
789                match result? {
790                    Some(n) => {
791                        let data = &buf[..n];
792
793                        // Skip parsing entirely when nobody is observing —
794                        // the proxy then runs as a straight byte pump.
795                        if ctx.observer_enabled {
796                            // On first data, detect stream type from the
797                            // leading varint and create the parser. Skip the
798                            // stream type varint when feeding to the parser
799                            // since SubgroupHeader/FetchHeader::decode()
800                            // expects it to be already consumed.
801                            let feed_data = if parser.is_none() && !data.is_empty() {
802                                let stream_type = detect_stream_type(data[0]);
803                                parser = Some(DataStreamParser::new(
804                                    stream_type, ctx.draft,
805                                ));
806                                // The stream type varint is a single byte
807                                // for values 0x00..0x3F. Skip it.
808                                let skip = varint_len(data[0]);
809                                &data[skip.min(data.len())..]
810                            } else {
811                                data
812                            };
813
814                            if let Some(ref mut p) = parser {
815                                let results = p.feed(feed_data);
816                                for result in &results {
817                                    match result {
818                                        crate::parser::data::DataParseResult::Header(header) => {
819                                            ctx.observer.on_event(
820                                                &ProxyEvent::DataStreamHeader {
821                                                    session_id: ctx.session_id,
822                                                    side,
823                                                    header: header.clone(),
824                                                },
825                                            );
826                                        }
827                                        crate::parser::data::DataParseResult::Object(header) => {
828                                            ctx.observer.on_event(
829                                                &ProxyEvent::ObjectHeader {
830                                                    session_id: ctx.session_id,
831                                                    side,
832                                                    header: header.clone(),
833                                                },
834                                            );
835                                        }
836                                        crate::parser::data::DataParseResult::Error(e) => {
837                                            ctx.observer.on_event(
838                                                &ProxyEvent::ParseError {
839                                                    session_id: ctx.session_id,
840                                                    side,
841                                                    error: e.clone(),
842                                                },
843                                            );
844                                        }
845                                        crate::parser::data::DataParseResult::NeedMore => {}
846                                    }
847                                }
848                            }
849                        }
850
851                        // Always forward the raw bytes (including
852                        // stream type varint)
853                        send.write_all(data).await?;
854                    }
855                    None => {
856                        ctx.emit(|| ProxyEvent::StreamClosed {
857                            session_id: ctx.session_id,
858                            side,
859                        });
860                        let _ = send.finish();
861                        return Ok(());
862                    }
863                }
864            }
865            _ = ctx.cancel.cancelled() => {
866                return Ok(());
867            }
868        }
869    }
870}
871
872/// Forward datagrams from source to destination.
873async fn forward_datagrams(
874    source: &Transport,
875    dest: &Transport,
876    side: ProxySide,
877    ctx: &ForwardCtx,
878) -> Result<(), ProxyError> {
879    loop {
880        tokio::select! {
881            result = source.recv_datagram() => {
882                let data = result?;
883
884                // Parse only when someone cares about events or the hook
885                // may want to mutate the datagram.
886                if ctx.observer_enabled {
887                    let mut cursor = &data[..];
888                    if let Ok(header) = AnyDatagramHeader::decode(ctx.draft, &mut cursor) {
889                        let payload_len = cursor.len();
890                        ctx.observer.on_event(&ProxyEvent::Datagram {
891                            session_id: ctx.session_id,
892                            side,
893                            header: header.clone(),
894                            payload_len,
895                        });
896
897                        if let Some(replacement) = ctx.hook.on_datagram(
898                            ctx.session_id, side, &header, &data,
899                        ) {
900                            dest.send_datagram(Bytes::from(replacement))?;
901                        } else {
902                            dest.send_datagram(data)?;
903                        }
904                    } else {
905                        dest.send_datagram(data)?;
906                    }
907                } else {
908                    dest.send_datagram(data)?;
909                }
910            }
911            _ = ctx.cancel.cancelled() => {
912                return Ok(());
913            }
914        }
915    }
916}
917
/// Encoded length in bytes of a QUIC variable-length integer, derived from
/// the two most significant bits of its first byte.
fn varint_len(first_byte: u8) -> usize {
    match first_byte >> 6 {
        0b00 => 1,
        0b01 => 2,
        0b10 => 4,
        _ => 8,
    }
}
922
923/// Try to detect the concrete draft version by peeking at the first SETUP
924/// message on a control stream.
925///
926/// - On the `ClientToProxy` direction, looks at CLIENT_SETUP's
927///   `supported_versions` list and returns the highest draft in the 07–14
928///   range we support.
929/// - On the `RelayToProxy` direction, looks at SERVER_SETUP's
930///   `selected_version` and returns the matching draft.
931/// - For draft-15+ the SETUP carries no version, but those cases don't
932///   reach this function because the caller only invokes it when the
933///   draft isn't already fixed by ALPN.
934///
935/// Returns `None` if the buffer doesn't yet contain enough bytes for a
936/// decision, or if the first message isn't a SETUP we recognize. The
937/// caller keeps buffering and retries.
938///
939/// `fallback` is the session's configured default draft, used only to
940/// reject impossible answers (e.g., SERVER_SETUP selected_version outside
941/// the supported range).
942fn detect_draft_from_setup(
943    buf: &[u8],
944    side: ProxySide,
945    fallback: DraftVersion,
946) -> Option<DraftVersion> {
947    let _ = fallback;
948    if buf.is_empty() {
949        return None;
950    }
951
952    // Decode the message type varint. The first byte's top two bits give
953    // the varint length. For drafts 07–10 the type is 0x40/0x41, encoded
954    // as a 2-byte varint. For drafts 11+ it's 0x20/0x21, a 1-byte varint.
955    let type_len = varint_len(buf[0]);
956    if buf.len() < type_len {
957        return None;
958    }
959    let mut cur = &buf[..type_len];
960    let type_id = VarInt::decode(&mut cur).ok()?.into_inner();
961
962    // Distinguish framing by the type id:
963    //   0x40 = CLIENT_SETUP (drafts 07–10, varint length)
964    //   0x41 = SERVER_SETUP (drafts 07–10, varint length)
965    //   0x20 = CLIENT_SETUP (drafts 11+, u16-BE length)
966    //   0x21 = SERVER_SETUP (drafts 11+, u16-BE length)
967    let (is_client_setup, is_server_setup, uses_u16_length) = match type_id {
968        0x40 => (true, false, false),
969        0x41 => (false, true, false),
970        0x20 => (true, false, true),
971        0x21 => (false, true, true),
972        _ => return None,
973    };
974
975    // The message we peek at is the one we'd expect to see first on this
976    // direction. Anything else is probably out-of-order bytes we can't
977    // disambiguate.
978    match side {
979        ProxySide::ClientToProxy | ProxySide::ProxyToRelay if !is_client_setup => return None,
980        ProxySide::RelayToProxy | ProxySide::ProxyToClient if !is_server_setup => return None,
981        _ => {}
982    }
983
984    let (payload_start, payload_len) = if uses_u16_length {
985        if buf.len() < type_len + 2 {
986            return None;
987        }
988        let len = ((buf[type_len] as usize) << 8) | (buf[type_len + 1] as usize);
989        (type_len + 2, len)
990    } else {
991        if buf.len() <= type_len {
992            return None;
993        }
994        let vl = varint_len(buf[type_len]);
995        if buf.len() < type_len + vl {
996            return None;
997        }
998        let mut cur = &buf[type_len..type_len + vl];
999        let v = VarInt::decode(&mut cur).ok()?;
1000        (type_len + vl, v.into_inner() as usize)
1001    };
1002
1003    if buf.len() < payload_start + payload_len {
1004        return None;
1005    }
1006    let payload = &buf[payload_start..payload_start + payload_len];
1007
1008    if is_client_setup {
1009        // CLIENT_SETUP (draft 07–14): number_of_supported_versions (varint)
1010        // then that many version varints. Pick the highest draft we
1011        // support in the moq-00 cohort (07–14).
1012        let mut cur = payload;
1013        let count = VarInt::decode(&mut cur).ok()?.into_inner() as usize;
1014        let mut best: Option<DraftVersion> = None;
1015        for _ in 0..count {
1016            let v = VarInt::decode(&mut cur).ok()?.into_inner();
1017            if let Some(d) = version_varint_to_draft(v) {
1018                if (7..=14).contains(&d.number()) {
1019                    best = Some(match best {
1020                        Some(b) if b.number() >= d.number() => b,
1021                        _ => d,
1022                    });
1023                }
1024            }
1025        }
1026        best
1027    } else {
1028        // SERVER_SETUP (draft 07–14): selected_version (varint) then
1029        // parameters. We only need the first varint.
1030        let mut cur = payload;
1031        let v = VarInt::decode(&mut cur).ok()?.into_inner();
1032        let d = version_varint_to_draft(v)?;
1033        if (7..=14).contains(&d.number()) {
1034            Some(d)
1035        } else {
1036            None
1037        }
1038    }
1039}
1040
1041/// Convert an on-wire MoQT version varint (`0xff000000 + draft`) to a
1042/// `DraftVersion`, or `None` if the value is malformed or unsupported.
1043fn version_varint_to_draft(v: u64) -> Option<DraftVersion> {
1044    const BASE: u64 = 0xff000000;
1045    if !(BASE..=BASE + 255).contains(&v) {
1046        return None;
1047    }
1048    DraftVersion::from_number((v - BASE) as u8)
1049}
1050
1051/// TLS certificate verifier that skips all verification (testing only).
1052#[derive(Debug)]
1053struct SkipVerification;
1054
1055impl rustls::client::danger::ServerCertVerifier for SkipVerification {
1056    fn verify_server_cert(
1057        &self,
1058        _end_entity: &rustls::pki_types::CertificateDer<'_>,
1059        _intermediates: &[rustls::pki_types::CertificateDer<'_>],
1060        _server_name: &rustls::pki_types::ServerName<'_>,
1061        _ocsp_response: &[u8],
1062        _now: rustls::pki_types::UnixTime,
1063    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
1064        Ok(rustls::client::danger::ServerCertVerified::assertion())
1065    }
1066
1067    fn verify_tls12_signature(
1068        &self,
1069        _message: &[u8],
1070        _cert: &rustls::pki_types::CertificateDer<'_>,
1071        _dcs: &rustls::DigitallySignedStruct,
1072    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
1073        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
1074    }
1075
1076    fn verify_tls13_signature(
1077        &self,
1078        _message: &[u8],
1079        _cert: &rustls::pki_types::CertificateDer<'_>,
1080        _dcs: &rustls::DigitallySignedStruct,
1081    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
1082        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
1083    }
1084
1085    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
1086        vec![
1087            rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
1088            rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
1089            rustls::SignatureScheme::ED25519,
1090            rustls::SignatureScheme::RSA_PSS_SHA256,
1091            rustls::SignatureScheme::RSA_PSS_SHA384,
1092            rustls::SignatureScheme::RSA_PSS_SHA512,
1093        ]
1094    }
1095}
1096
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a draft-07 CLIENT_SETUP on the wire with the given supported
    /// draft numbers.
    fn encode_client_setup_d07(drafts: &[u8]) -> Vec<u8> {
        use moqtap_codec::draft07::message::{ClientSetup, ControlMessage};
        let mut supported = Vec::new();
        for &n in drafts {
            supported.push(VarInt::from_usize(0xff000000 + n as usize));
        }
        let msg = ControlMessage::ClientSetup(ClientSetup {
            supported_versions: supported,
            parameters: Vec::new(),
        });
        let mut buf = Vec::new();
        msg.encode(&mut buf).unwrap();
        buf
    }

    /// Build a draft-14 CLIENT_SETUP on the wire (u16-BE framing).
    fn encode_client_setup_d14(drafts: &[u8]) -> Vec<u8> {
        use moqtap_codec::draft14::message::{ClientSetup, ControlMessage};
        let mut supported = Vec::new();
        for &n in drafts {
            supported.push(VarInt::from_usize(0xff000000 + n as usize));
        }
        let msg = ControlMessage::ClientSetup(ClientSetup {
            supported_versions: supported,
            parameters: Vec::new(),
        });
        let mut buf = Vec::new();
        msg.encode(&mut buf).unwrap();
        buf
    }

    /// Build a draft-07 SERVER_SETUP on the wire.
    fn encode_server_setup_d07(draft: u8) -> Vec<u8> {
        use moqtap_codec::draft07::message::{ControlMessage, ServerSetup};
        let msg = ControlMessage::ServerSetup(ServerSetup {
            selected_version: VarInt::from_usize(0xff000000 + draft as usize),
            parameters: Vec::new(),
        });
        let mut buf = Vec::new();
        msg.encode(&mut buf).unwrap();
        buf
    }

    /// Build a draft-14 SERVER_SETUP on the wire (u16-BE framing).
    fn encode_server_setup_d14(draft: u8) -> Vec<u8> {
        use moqtap_codec::draft14::message::{ControlMessage, ServerSetup};
        let msg = ControlMessage::ServerSetup(ServerSetup {
            selected_version: VarInt::from_usize(0xff000000 + draft as usize),
            parameters: Vec::new(),
        });
        let mut buf = Vec::new();
        msg.encode(&mut buf).unwrap();
        buf
    }

    #[test]
    fn detect_picks_highest_draft_from_07_10_varint_framing() {
        // Drafts 07 and 09 offered; expect 09.
        let bytes = encode_client_setup_d07(&[7, 9]);
        let d = detect_draft_from_setup(&bytes, ProxySide::ClientToProxy, DraftVersion::Draft14);
        assert_eq!(d, Some(DraftVersion::Draft09));
    }

    #[test]
    fn detect_picks_highest_draft_from_11_14_u16_framing() {
        // Drafts 11, 13, 14 offered; expect 14.
        let bytes = encode_client_setup_d14(&[11, 13, 14]);
        let d = detect_draft_from_setup(&bytes, ProxySide::ClientToProxy, DraftVersion::Draft11);
        assert_eq!(d, Some(DraftVersion::Draft14));
    }

    #[test]
    fn detect_from_server_setup_varint_framing() {
        let bytes = encode_server_setup_d07(10);
        let d = detect_draft_from_setup(&bytes, ProxySide::RelayToProxy, DraftVersion::Draft07);
        assert_eq!(d, Some(DraftVersion::Draft10));
    }

    #[test]
    fn detect_from_server_setup_u16_framing() {
        let bytes = encode_server_setup_d14(14);
        let d = detect_draft_from_setup(&bytes, ProxySide::RelayToProxy, DraftVersion::Draft11);
        assert_eq!(d, Some(DraftVersion::Draft14));
    }

    #[test]
    fn detect_returns_none_on_partial_bytes() {
        let bytes = encode_client_setup_d14(&[14]);
        // Truncate to one byte — not enough to read length field.
        assert_eq!(
            detect_draft_from_setup(&bytes[..1], ProxySide::ClientToProxy, DraftVersion::Draft14),
            None
        );
    }

    #[test]
    fn detect_returns_none_on_truncated_payload() {
        // Length field is complete but the payload is one byte short.
        let bytes = encode_server_setup_d14(14);
        assert_eq!(
            detect_draft_from_setup(
                &bytes[..bytes.len() - 1],
                ProxySide::RelayToProxy,
                DraftVersion::Draft14
            ),
            None
        );
    }

    #[test]
    fn detect_returns_none_on_empty_buffer() {
        assert_eq!(
            detect_draft_from_setup(&[], ProxySide::ClientToProxy, DraftVersion::Draft14),
            None
        );
    }

    #[test]
    fn detect_returns_none_for_unrelated_first_byte() {
        // First byte 0x10 is GOAWAY — not a SETUP we can detect from.
        let bytes = [0x10u8, 0x00, 0x00];
        assert_eq!(
            detect_draft_from_setup(&bytes, ProxySide::ClientToProxy, DraftVersion::Draft14),
            None
        );
    }

    #[test]
    fn detect_ignores_15_plus_versions_in_moq_00_setup() {
        // A malformed CLIENT_SETUP advertising only draft-15 over moq-00
        // (which shouldn't happen in practice). We refuse to pick 15 here
        // because 15+ uses ALPN, not CLIENT_SETUP.
        let bytes = encode_client_setup_d14(&[15]);
        assert_eq!(
            detect_draft_from_setup(&bytes, ProxySide::ClientToProxy, DraftVersion::Draft14),
            None
        );
    }

    #[test]
    fn detect_setup_wrong_direction_returns_none() {
        // CLIENT_SETUP peeked as SERVER_SETUP → None.
        let bytes = encode_client_setup_d14(&[14]);
        assert_eq!(
            detect_draft_from_setup(&bytes, ProxySide::RelayToProxy, DraftVersion::Draft14),
            None
        );
    }

    #[test]
    fn detect_client_setup_on_proxy_to_relay_side() {
        // CLIENT_SETUP is also the expected first message on ProxyToRelay.
        let bytes = encode_client_setup_d14(&[14]);
        assert_eq!(
            detect_draft_from_setup(&bytes, ProxySide::ProxyToRelay, DraftVersion::Draft14),
            Some(DraftVersion::Draft14)
        );
    }

    #[test]
    fn detect_client_setup_on_proxy_to_client_side_returns_none() {
        // ProxyToClient expects SERVER_SETUP, so CLIENT_SETUP is rejected.
        let bytes = encode_client_setup_d14(&[14]);
        assert_eq!(
            detect_draft_from_setup(&bytes, ProxySide::ProxyToClient, DraftVersion::Draft14),
            None
        );
    }

    #[test]
    fn varint_len_follows_two_bit_prefix() {
        assert_eq!(varint_len(0x00), 1);
        assert_eq!(varint_len(0x3f), 1);
        assert_eq!(varint_len(0x40), 2);
        assert_eq!(varint_len(0x80), 4);
        assert_eq!(varint_len(0xc0), 8);
    }

    #[test]
    fn version_varint_to_draft_bounds() {
        assert_eq!(version_varint_to_draft(0xff00000e), Some(DraftVersion::Draft14));
        // Bare draft number without the 0xff000000 base.
        assert_eq!(version_varint_to_draft(14), None);
        // Just below and just above the accepted window.
        assert_eq!(version_varint_to_draft(0xfeff_ffff), None);
        assert_eq!(version_varint_to_draft(0xff00_0100), None);
    }
}