Skip to main content

moqtap_client/draft14/
connection.rs

1use std::sync::Arc;
2
3use bytes::{Buf, Bytes, BytesMut};
4
5use crate::draft14::endpoint::{Endpoint, EndpointError};
6use crate::draft14::event::{ClientEvent, Direction, StreamKind};
7use crate::draft14::observer::ConnectionObserver;
8use crate::draft14::session::request_id::Role;
9use crate::transport::quic::QuicTransport;
10use crate::transport::{RecvStream, SendStream, Transport, TransportError};
11use moqtap_codec::dispatch::{
12    AnyControlMessage, AnyDatagramHeader, AnyFetchHeader, AnySubgroupHeader,
13};
14use moqtap_codec::draft14::data_stream::{FetchObject, SubgroupObject, SubgroupObjectReader};
15use moqtap_codec::draft14::message::ControlMessage;
16use moqtap_codec::error::CodecError;
17use moqtap_codec::types::*;
18use moqtap_codec::varint::VarInt;
19use moqtap_codec::version::DraftVersion;
20
/// MoQT ALPN identifier (`moq-00`) intended for the raw QUIC transport.
///
/// NOTE(review): `ClientConfig::alpn` builds the QUIC ALPN list from
/// `DraftVersion::quic_alpn()` and does not consult this constant — confirm
/// whether anything still relies on it.
pub const MOQT_ALPN: &[u8] = b"moq-00";
23
24/// Errors from the connection layer.
25#[derive(Debug, thiserror::Error)]
26pub enum ConnectionError {
27    /// Endpoint state machine error.
28    #[error("endpoint error: {0}")]
29    Endpoint(#[from] EndpointError),
30    /// Wire codec error.
31    #[error("codec error: {0}")]
32    Codec(#[from] CodecError),
33    /// Transport-level error.
34    #[error("transport error: {0}")]
35    Transport(#[from] TransportError),
36    /// Variable-length integer decoding error.
37    #[error("varint error: {0}")]
38    VarInt(#[from] moqtap_codec::varint::VarIntError),
39    /// Control stream was not opened.
40    #[error("control stream not open")]
41    NoControlStream,
42    /// Stream ended before a complete message was read.
43    #[error("unexpected end of stream")]
44    UnexpectedEnd,
45    /// Stream was finished by the peer.
46    #[error("stream finished")]
47    StreamFinished,
48    /// Invalid server address string.
49    #[error("invalid server address: {0}")]
50    InvalidAddress(String),
51    /// TLS configuration error.
52    #[error("TLS config error: {0}")]
53    TlsConfig(String),
54    /// Data stream used out of order (e.g. object before header).
55    #[error("data stream state error: {0}")]
56    DataStreamState(&'static str),
57}
58
/// Transport used to carry the MoQT session.
#[derive(Debug, Clone)]
pub enum TransportType {
    /// Raw QUIC via quinn.
    ///
    /// This variant carries no address of its own. The server address is the
    /// `addr` argument of `Connection::connect`, which is parsed as a
    /// [`std::net::SocketAddr`] (a literal `ip:port`); host names are not
    /// resolved.
    Quic,
    /// WebTransport via wtransport.
    ///
    /// The `addr` argument of `Connection::connect` is ignored for this
    /// variant; the session is opened to `url` instead.
    WebTransport {
        /// The WebTransport endpoint URL (e.g., `https://host:port/path`).
        url: String,
    },
}
70
71/// Configuration for a MoQT client connection.
72///
73/// Both `draft` and `transport` are required — there is no `Default` impl.
74pub struct ClientConfig {
75    /// The MoQT draft version to use (primary, determines codec/framing).
76    pub draft: DraftVersion,
77    /// Additional draft versions to offer in CLIENT_SETUP.
78    /// The primary `draft` is always included first.
79    pub additional_versions: Vec<DraftVersion>,
80    /// The transport type (QUIC or WebTransport).
81    pub transport: TransportType,
82    /// Whether to skip TLS certificate verification (for testing).
83    pub skip_cert_verification: bool,
84    /// Custom CA certificates to trust (DER-encoded).
85    pub ca_certs: Vec<Vec<u8>>,
86    /// Setup parameters to include in CLIENT_SETUP (e.g., auth tokens).
87    pub setup_parameters: Vec<moqtap_codec::kvp::KeyValuePair>,
88}
89
90impl ClientConfig {
91    /// Returns the MoQT version varints for the CLIENT_SETUP message.
92    /// Primary draft first, then any additional versions.
93    pub fn supported_versions(&self) -> Vec<VarInt> {
94        let mut versions = vec![self.draft.version_varint()];
95        for v in &self.additional_versions {
96            let varint = v.version_varint();
97            if !versions.contains(&varint) {
98                versions.push(varint);
99            }
100        }
101        versions
102    }
103
104    /// Returns the ALPN protocol identifiers for the transport.
105    pub fn alpn(&self) -> Vec<Vec<u8>> {
106        match &self.transport {
107            TransportType::Quic => vec![self.draft.quic_alpn().to_vec()],
108            TransportType::WebTransport { .. } => vec![b"h3".to_vec()],
109        }
110    }
111}
112
113/// A framed writer for a send stream. Handles MoQT length-prefixed framing.
114pub struct FramedSendStream {
115    inner: SendStream,
116    draft: DraftVersion,
117    /// Stateful subgroup object encoder. Initialized by
118    /// [`FramedSendStream::write_subgroup_header`] and used by
119    /// [`FramedSendStream::write_subgroup_object`] to track the delta-encoded
120    /// object ID state.
121    subgroup_io: Option<SubgroupObjectReader>,
122}
123
124impl FramedSendStream {
125    /// Create a new framed send stream for the given draft version.
126    pub fn new(inner: SendStream, draft: DraftVersion) -> Self {
127        Self { inner, draft, subgroup_io: None }
128    }
129
130    /// Get the transport-level stream ID.
131    pub fn stream_id(&self) -> u64 {
132        self.inner.stream_id()
133    }
134
135    /// Write a control message to the stream with type+length framing.
136    /// Returns the raw bytes that were written (for event capture).
137    pub async fn write_control(
138        &mut self,
139        msg: &AnyControlMessage,
140    ) -> Result<Vec<u8>, ConnectionError> {
141        let mut buf = Vec::new();
142        msg.encode(&mut buf)?;
143        self.inner.write_all(&buf).await?;
144        Ok(buf)
145    }
146
147    /// Write a subgroup stream header. Also initializes the internal
148    /// delta-encoding state used by [`FramedSendStream::write_subgroup_object`].
149    pub async fn write_subgroup_header(
150        &mut self,
151        header: &AnySubgroupHeader,
152    ) -> Result<(), ConnectionError> {
153        let mut buf = Vec::new();
154        header.encode(&mut buf);
155        self.inner.write_all(&buf).await?;
156        if let AnySubgroupHeader::Draft14(ref d14) = header {
157            self.subgroup_io = Some(SubgroupObjectReader::new(d14));
158        }
159        Ok(())
160    }
161
162    /// Write a fetch response header.
163    pub async fn write_fetch_header(
164        &mut self,
165        header: &AnyFetchHeader,
166    ) -> Result<(), ConnectionError> {
167        let mut buf = Vec::new();
168        header.encode(&mut buf);
169        self.inner.write_all(&buf).await?;
170        Ok(())
171    }
172
173    /// Append a draft-14 subgroup object to the stream. Uses the stateful
174    /// reader installed by [`FramedSendStream::write_subgroup_header`] to
175    /// produce the correct delta-encoded object ID. Returns an error if
176    /// called before a subgroup header was written.
177    pub async fn write_subgroup_object(
178        &mut self,
179        object: &SubgroupObject,
180    ) -> Result<(), ConnectionError> {
181        let writer = self
182            .subgroup_io
183            .as_mut()
184            .ok_or(ConnectionError::DataStreamState("subgroup header not written yet"))?;
185        let mut buf = Vec::new();
186        writer.write_object(object, &mut buf)?;
187        self.inner.write_all(&buf).await?;
188        Ok(())
189    }
190
191    /// Append a draft-14 fetch object to the stream. Fetch objects are
192    /// self-contained (each carries its own group/subgroup/object IDs and
193    /// priority), so no prior `write_fetch_header` bookkeeping is required.
194    pub async fn write_fetch_object(
195        &mut self,
196        object: &FetchObject,
197    ) -> Result<(), ConnectionError> {
198        let mut buf = Vec::new();
199        object.encode(&mut buf);
200        self.inner.write_all(&buf).await?;
201        Ok(())
202    }
203
204    /// Finish the stream (send FIN).
205    pub async fn finish(&mut self) -> Result<(), ConnectionError> {
206        self.inner.finish()?;
207        Ok(())
208    }
209
210    /// Returns the draft version this stream is framed for.
211    pub fn draft(&self) -> DraftVersion {
212        self.draft
213    }
214}
215
/// A framed reader for a recv stream. Handles MoQT varint-length decoding.
///
/// Bytes read from the transport accumulate in an internal buffer. Each
/// `read_*` method decodes one unit (control message, data-stream header or
/// object) from the front of that buffer. If the decoder runs out of data, it
/// pulls more from the stream and retries from the start of the buffer. Once
/// a decode succeeds, exactly the consumed bytes are dropped.
pub struct FramedRecvStream {
    /// Underlying transport receive stream.
    inner: RecvStream,
    /// Bytes received from `inner` that no decoder has consumed yet.
    buf: BytesMut,
    /// Draft version handed to the codec when decoding.
    draft: DraftVersion,
    /// Stateful subgroup object decoder. Initialized by
    /// [`FramedRecvStream::read_subgroup_header`] and used by
    /// [`FramedRecvStream::read_subgroup_object`] to track delta-encoded
    /// object IDs and whether extension headers are present.
    subgroup_io: Option<SubgroupObjectReader>,
}

impl FramedRecvStream {
    /// Create a new framed receive stream for the given draft version.
    pub fn new(inner: RecvStream, draft: DraftVersion) -> Self {
        Self { inner, buf: BytesMut::with_capacity(4096), draft, subgroup_io: None }
    }

    /// Get the transport-level stream ID.
    pub fn stream_id(&self) -> u64 {
        self.inner.stream_id()
    }

    /// Read more data from the stream into the internal buffer.
    ///
    /// Performs a single read of at most 4096 bytes. Returns `Ok(false)` when
    /// the transport read yields `None` (end of stream) and `Ok(true)`
    /// otherwise.
    async fn fill(&mut self) -> Result<bool, ConnectionError> {
        let mut tmp = [0u8; 4096];
        match self.inner.read(&mut tmp).await {
            Ok(Some(n)) => {
                self.buf.extend_from_slice(&tmp[..n]);
                Ok(true)
            }
            Ok(None) => Ok(false),
            Err(e) => Err(ConnectionError::Transport(e)),
        }
    }

    /// Ensure at least `n` bytes are available in the buffer.
    ///
    /// Returns `UnexpectedEnd` if the stream ends before `n` bytes arrive.
    async fn ensure(&mut self, n: usize) -> Result<(), ConnectionError> {
        while self.buf.len() < n {
            if !self.fill().await? {
                return Err(ConnectionError::UnexpectedEnd);
            }
        }
        Ok(())
    }

    /// Read a control message from the stream.
    ///
    /// When `capture_raw` is true, the returned tuple includes a clone of the
    /// framed wire bytes (for observer emission). When false, the second
    /// element is `None` and the payload clone is skipped.
    ///
    /// The frame length is computed from the type varint, the length field
    /// and the payload length. The whole frame is then decoded by
    /// `AnyControlMessage::decode`, and exactly that many bytes are consumed
    /// even if the decoder reads fewer.
    pub async fn read_control(
        &mut self,
        capture_raw: bool,
    ) -> Result<(AnyControlMessage, Option<Vec<u8>>), ConnectionError> {
        // Read type ID varint.
        // `varint_len` (defined elsewhere in this file) presumably maps a
        // varint's first byte to its total encoded length — confirm.
        self.ensure(1).await?;
        let type_len = varint_len(self.buf[0]);
        self.ensure(type_len).await?;

        // Decoded only to validate the varint; the full frame (type included)
        // is decoded again below.
        let mut cursor = &self.buf[..type_len];
        let _type_id = VarInt::decode(&mut cursor)?;

        // Read payload length (16-bit BE for draft-11+, varint for earlier drafts)
        let (payload_len, len_field_size) = if self.draft.uses_fixed_length_framing() {
            self.ensure(type_len + 2).await?;
            let hi = self.buf[type_len] as usize;
            let lo = self.buf[type_len + 1] as usize;
            ((hi << 8) | lo, 2)
        } else {
            self.ensure(type_len + 1).await?;
            let payload_len_start = type_len;
            let payload_len_varint_len = varint_len(self.buf[payload_len_start]);
            self.ensure(type_len + payload_len_varint_len).await?;
            let mut cursor = &self.buf[payload_len_start..type_len + payload_len_varint_len];
            // NOTE(review): this length is peer-controlled and unbounded, so
            // `ensure(total)` below will buffer as much as the peer claims.
            // `as usize` also silently truncates on 32-bit targets.
            let payload_len = VarInt::decode(&mut cursor)?.into_inner() as usize;
            (payload_len, payload_len_varint_len)
        };

        // Read full payload
        let total = type_len + len_field_size + payload_len;
        self.ensure(total).await?;

        // Capture raw bytes only if requested (observer attached).
        let raw = capture_raw.then(|| self.buf[..total].to_vec());

        // Now decode the whole message
        let mut frame = &self.buf[..total];
        let msg = AnyControlMessage::decode(self.draft, &mut frame)?;
        self.buf.advance(total);
        Ok((msg, raw))
    }

    /// Read a subgroup stream header. Also initializes the internal
    /// delta-decoding state used by
    /// [`FramedRecvStream::read_subgroup_object`].
    ///
    /// The decoding state is only (re)installed for draft-14 headers. For any
    /// other variant, `subgroup_io` is left as it was.
    pub async fn read_subgroup_header(&mut self) -> Result<AnySubgroupHeader, ConnectionError> {
        self.ensure(1).await?;
        loop {
            let mut cursor = &self.buf[..];
            match AnySubgroupHeader::decode(self.draft, &mut cursor) {
                Ok(header) => {
                    // Bytes the decoder actually read from the front of `buf`.
                    let consumed = self.buf.len() - cursor.remaining();
                    self.buf.advance(consumed);
                    if let AnySubgroupHeader::Draft14(ref d14) = header {
                        self.subgroup_io = Some(SubgroupObjectReader::new(d14));
                    }
                    return Ok(header);
                }
                // Not enough data yet: read more and retry the decode from the
                // start of the buffer.
                Err(CodecError::UnexpectedEnd) => {
                    if !self.fill().await? {
                        return Err(ConnectionError::UnexpectedEnd);
                    }
                }
                Err(e) => return Err(ConnectionError::Codec(e)),
            }
        }
    }

    /// Read a fetch response header.
    ///
    /// Uses the same decode / read-more / retry loop as
    /// [`FramedRecvStream::read_subgroup_header`].
    pub async fn read_fetch_header(&mut self) -> Result<AnyFetchHeader, ConnectionError> {
        self.ensure(1).await?;
        loop {
            let mut cursor = &self.buf[..];
            match AnyFetchHeader::decode(self.draft, &mut cursor) {
                Ok(header) => {
                    let consumed = self.buf.len() - cursor.remaining();
                    self.buf.advance(consumed);
                    return Ok(header);
                }
                Err(CodecError::UnexpectedEnd) => {
                    if !self.fill().await? {
                        return Err(ConnectionError::UnexpectedEnd);
                    }
                }
                Err(e) => return Err(ConnectionError::Codec(e)),
            }
        }
    }

    /// Read the next draft-14 subgroup object from this stream. Uses the
    /// stateful reader installed by
    /// [`FramedRecvStream::read_subgroup_header`] to decode the delta-
    /// encoded object ID and handle extension headers per the stream type.
    /// Returns an error if called before a subgroup header was read.
    ///
    /// NOTE(review): end of stream between two objects is reported as
    /// `UnexpectedEnd`, the same as a truncated object. `StreamFinished` is
    /// not used here, so callers cannot distinguish a clean FIN.
    pub async fn read_subgroup_object(&mut self) -> Result<SubgroupObject, ConnectionError> {
        if self.subgroup_io.is_none() {
            return Err(ConnectionError::DataStreamState("subgroup header not read yet"));
        }
        loop {
            // Cannot panic: checked above, and nothing in this loop clears
            // `subgroup_io`.
            let reader = self.subgroup_io.as_mut().unwrap();
            // Decode with a copy of the state so a partial read that fails
            // with UnexpectedEnd does not advance the committed delta state.
            let mut probe = reader.clone();
            let mut cursor = &self.buf[..];
            match probe.read_object(&mut cursor) {
                Ok(obj) => {
                    let consumed = self.buf.len() - cursor.remaining();
                    self.buf.advance(consumed);
                    // Commit state from the successful probe decode.
                    *reader = probe;
                    return Ok(obj);
                }
                Err(CodecError::UnexpectedEnd) => {
                    if !self.fill().await? {
                        return Err(ConnectionError::UnexpectedEnd);
                    }
                }
                Err(e) => return Err(ConnectionError::Codec(e)),
            }
        }
    }

    /// Read the next draft-14 fetch object from this stream. Fetch objects
    /// are self-describing, so no prior `read_fetch_header` state is needed
    /// to decode each object.
    pub async fn read_fetch_object(&mut self) -> Result<FetchObject, ConnectionError> {
        loop {
            let mut cursor = &self.buf[..];
            match FetchObject::decode(&mut cursor) {
                Ok(obj) => {
                    let consumed = self.buf.len() - cursor.remaining();
                    self.buf.advance(consumed);
                    return Ok(obj);
                }
                Err(CodecError::UnexpectedEnd) => {
                    if !self.fill().await? {
                        return Err(ConnectionError::UnexpectedEnd);
                    }
                }
                Err(e) => return Err(ConnectionError::Codec(e)),
            }
        }
    }

    /// Returns the draft version this stream is framed for.
    pub fn draft(&self) -> DraftVersion {
        self.draft
    }
}
414
415/// A live MoQT connection over QUIC or WebTransport, combining the endpoint
416/// state machine with actual network I/O.
417pub struct Connection {
418    transport: Transport,
419    endpoint: Endpoint,
420    draft: DraftVersion,
421    control_send: Option<FramedSendStream>,
422    control_recv: Option<FramedRecvStream>,
423    observer: Option<Box<dyn ConnectionObserver>>,
424}
425
426impl Connection {
427    /// Connect to a MoQT server as a client.
428    ///
429    /// Establishes a QUIC or WebTransport connection (based on `config.transport`),
430    /// opens a bidirectional control stream, performs the CLIENT_SETUP /
431    /// SERVER_SETUP handshake, and returns a ready-to-use connection.
432    pub async fn connect(addr: &str, config: ClientConfig) -> Result<Self, ConnectionError> {
433        let draft = config.draft;
434        let transport = match &config.transport {
435            TransportType::Quic => Self::connect_quic(addr, &config).await?,
436            TransportType::WebTransport { url } => {
437                let url = url.clone();
438                Self::connect_webtransport(&url, &config).await?
439            }
440        };
441
442        // Open bidirectional control stream
443        let (send, recv) = transport.open_bi().await?;
444        let mut control_send = FramedSendStream::new(send, draft);
445        let mut control_recv = FramedRecvStream::new(recv, draft);
446
447        // Perform setup handshake
448        let mut endpoint = Endpoint::new(Role::Client);
449        endpoint.connect()?;
450        let setup_msg = endpoint
451            .send_client_setup(config.supported_versions(), config.setup_parameters.clone())?;
452        let any_setup = AnyControlMessage::Draft14(setup_msg);
453        let _raw_setup = control_send.write_control(&any_setup).await?;
454
455        let (server_setup, _raw_server_setup) = control_recv.read_control(false).await?;
456        // Unwrap to draft-14 for the endpoint (which is draft-14 only)
457        match &server_setup {
458            AnyControlMessage::Draft14(ControlMessage::ServerSetup(ref ss)) => {
459                endpoint.receive_server_setup(ss)?;
460            }
461            _ => {
462                return Err(ConnectionError::Endpoint(EndpointError::NotActive));
463            }
464        }
465
466        let conn = Self {
467            transport,
468            endpoint,
469            draft,
470            control_send: Some(control_send),
471            control_recv: Some(control_recv),
472            observer: None,
473        };
474
475        // Emit SetupComplete event
476        if let Some(v) = conn.endpoint.negotiated_version() {
477            conn.emit(ClientEvent::SetupComplete { negotiated_version: v.into_inner() });
478        }
479
480        Ok(conn)
481    }
482
483    /// Establish a raw QUIC connection.
484    async fn connect_quic(addr: &str, config: &ClientConfig) -> Result<Transport, ConnectionError> {
485        let server_addr = addr.parse().map_err(|e: std::net::AddrParseError| {
486            ConnectionError::InvalidAddress(e.to_string())
487        })?;
488
489        // Build TLS config
490        let mut tls_config = if config.skip_cert_verification {
491            rustls::ClientConfig::builder()
492                .dangerous()
493                .with_custom_certificate_verifier(Arc::new(SkipVerification))
494                .with_no_client_auth()
495        } else {
496            let mut roots = rustls::RootCertStore::empty();
497            roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
498            for der in &config.ca_certs {
499                roots
500                    .add(rustls::pki_types::CertificateDer::from(der.clone()))
501                    .map_err(|e| ConnectionError::TlsConfig(format!("bad CA cert: {e}")))?;
502            }
503            rustls::ClientConfig::builder().with_root_certificates(roots).with_no_client_auth()
504        };
505
506        tls_config.alpn_protocols = config.alpn();
507
508        let quic_config: quinn::crypto::rustls::QuicClientConfig =
509            tls_config.try_into().map_err(|e| ConnectionError::TlsConfig(format!("{e}")))?;
510        let client_config = quinn::ClientConfig::new(Arc::new(quic_config));
511
512        let mut quinn_endpoint = quinn::Endpoint::client("0.0.0.0:0".parse().unwrap())
513            .map_err(|e| ConnectionError::InvalidAddress(e.to_string()))?;
514        quinn_endpoint.set_default_client_config(client_config);
515
516        let server_name = addr.split(':').next().unwrap_or("localhost").to_string();
517
518        let quic = quinn_endpoint
519            .connect(server_addr, &server_name)
520            .map_err(TransportError::from)?
521            .await
522            .map_err(TransportError::from)?;
523
524        Ok(Transport::Quic(QuicTransport::new(quic)))
525    }
526
527    /// Establish a WebTransport connection.
528    #[cfg(feature = "webtransport")]
529    async fn connect_webtransport(
530        url: &str,
531        config: &ClientConfig,
532    ) -> Result<Transport, ConnectionError> {
533        use crate::transport::webtransport::WebTransportTransport;
534
535        let wt_config = if config.skip_cert_verification {
536            wtransport::ClientConfig::builder()
537                .with_bind_default()
538                .with_no_cert_validation()
539                .build()
540        } else {
541            wtransport::ClientConfig::builder().with_bind_default().with_native_certs().build()
542        };
543
544        let endpoint = wtransport::Endpoint::client(wt_config)
545            .map_err(|e| ConnectionError::Transport(TransportError::Connect(e.to_string())))?;
546
547        let connection = endpoint
548            .connect(url)
549            .await
550            .map_err(|e| ConnectionError::Transport(TransportError::Connect(e.to_string())))?;
551
552        Ok(Transport::WebTransport(WebTransportTransport::new(connection)))
553    }
554
555    /// Stub for when the webtransport feature is not enabled.
556    #[cfg(not(feature = "webtransport"))]
557    async fn connect_webtransport(
558        _url: &str,
559        _config: &ClientConfig,
560    ) -> Result<Transport, ConnectionError> {
561        Err(ConnectionError::Transport(TransportError::Connect(
562            "webtransport feature not enabled".into(),
563        )))
564    }
565
566    // ── Observer ───────────────────────────────────────────────
567
568    /// Attach an observer to receive connection events.
569    pub fn set_observer(&mut self, observer: Box<dyn ConnectionObserver>) {
570        self.observer = Some(observer);
571    }
572
573    /// Remove the observer.
574    pub fn clear_observer(&mut self) {
575        self.observer = None;
576    }
577
578    /// Emit an event to the observer, if one is attached.
579    fn emit(&self, event: ClientEvent) {
580        if let Some(ref obs) = self.observer {
581            obs.on_event_owned(event);
582        }
583    }
584
585    // ── Control message I/O ─────────────────────────────────
586
587    /// Send a control message on the control stream.
588    ///
589    /// Wraps the draft-14 message in `AnyControlMessage::Draft14` for framing.
590    pub async fn send_control(&mut self, msg: &ControlMessage) -> Result<(), ConnectionError> {
591        let any = AnyControlMessage::Draft14(msg.clone());
592        let send = self.control_send.as_mut().ok_or(ConnectionError::NoControlStream)?;
593        let raw = send.write_control(&any).await?;
594        self.emit(ClientEvent::ControlMessage {
595            direction: Direction::Send,
596            message: any,
597            raw: Some(raw),
598        });
599        Ok(())
600    }
601
602    /// Read the next control message from the control stream.
603    ///
604    /// Returns the `AnyControlMessage` and also extracts the draft-14
605    /// `ControlMessage` for internal endpoint dispatch.
606    pub async fn recv_control(&mut self) -> Result<ControlMessage, ConnectionError> {
607        let recv = self.control_recv.as_mut().ok_or(ConnectionError::NoControlStream)?;
608        let capture_raw = self.observer.is_some();
609        let (any, raw) = recv.read_control(capture_raw).await?;
610        if capture_raw {
611            self.emit(ClientEvent::ControlMessage {
612                direction: Direction::Receive,
613                message: any.clone(),
614                raw,
615            });
616        }
617        // Unwrap to draft-14 for the endpoint
618        match any {
619            AnyControlMessage::Draft14(msg) => Ok(msg),
620            _ => Err(ConnectionError::Codec(CodecError::UnknownMessageType(0))),
621        }
622    }
623
624    /// Read and dispatch the next incoming control message through the endpoint
625    /// state machine. Returns the decoded message for inspection.
626    pub async fn recv_and_dispatch(&mut self) -> Result<ControlMessage, ConnectionError> {
627        let msg = self.recv_control().await?;
628        self.endpoint.receive_message(msg.clone())?;
629
630        // Emit draining event if this was a GoAway
631        if let ControlMessage::GoAway(ref ga) = msg {
632            self.emit(ClientEvent::Draining { new_session_uri: ga.new_session_uri.clone() });
633        }
634
635        Ok(msg)
636    }
637
638    // ── Subscribe flow ──────────────────────────────────────
639
640    /// Send a SUBSCRIBE and return the allocated request ID.
641    pub async fn subscribe(
642        &mut self,
643        track_namespace: TrackNamespace,
644        track_name: Vec<u8>,
645        subscriber_priority: u8,
646        group_order: GroupOrder,
647        filter_type: FilterType,
648    ) -> Result<VarInt, ConnectionError> {
649        let (req_id, msg) = self.endpoint.subscribe(
650            track_namespace,
651            track_name,
652            subscriber_priority,
653            group_order,
654            filter_type,
655        )?;
656        self.send_control(&msg).await?;
657        Ok(req_id)
658    }
659
660    /// Send an UNSUBSCRIBE for the given request ID.
661    pub async fn unsubscribe(&mut self, request_id: VarInt) -> Result<(), ConnectionError> {
662        let msg = self.endpoint.unsubscribe(request_id)?;
663        self.send_control(&msg).await
664    }
665
666    // ── Fetch flow ──────────────────────────────────────────
667
668    /// Send a FETCH and return the allocated request ID.
669    pub async fn fetch(
670        &mut self,
671        track_namespace: TrackNamespace,
672        track_name: Vec<u8>,
673        start_group: VarInt,
674        start_object: VarInt,
675    ) -> Result<VarInt, ConnectionError> {
676        let (req_id, msg) =
677            self.endpoint.fetch(track_namespace, track_name, start_group, start_object)?;
678        self.send_control(&msg).await?;
679        Ok(req_id)
680    }
681
682    /// Send a FETCH_CANCEL for the given request ID.
683    pub async fn fetch_cancel(&mut self, request_id: VarInt) -> Result<(), ConnectionError> {
684        let msg = self.endpoint.fetch_cancel(request_id)?;
685        self.send_control(&msg).await
686    }
687
688    // ── Namespace flows ─────────────────────────────────────
689
690    /// Send a SUBSCRIBE_NAMESPACE and return the request ID.
691    pub async fn subscribe_namespace(
692        &mut self,
693        track_namespace: TrackNamespace,
694    ) -> Result<VarInt, ConnectionError> {
695        let (req_id, msg) = self.endpoint.subscribe_namespace(track_namespace)?;
696        self.send_control(&msg).await?;
697        Ok(req_id)
698    }
699
700    /// Send a PUBLISH_NAMESPACE and return the request ID.
701    pub async fn publish_namespace(
702        &mut self,
703        track_namespace: TrackNamespace,
704    ) -> Result<VarInt, ConnectionError> {
705        let (req_id, msg) = self.endpoint.publish_namespace(track_namespace)?;
706        self.send_control(&msg).await?;
707        Ok(req_id)
708    }
709
710    // ── Track Status flow ────────────────────────────────────
711
712    /// Send a TRACK_STATUS and return the allocated request ID.
713    pub async fn track_status(
714        &mut self,
715        track_namespace: TrackNamespace,
716        track_name: Vec<u8>,
717    ) -> Result<VarInt, ConnectionError> {
718        let (req_id, msg) = self.endpoint.track_status(track_namespace, track_name)?;
719        self.send_control(&msg).await?;
720        Ok(req_id)
721    }
722
723    // ── Publish flow (publisher side) ───────────────────────
724
725    /// Send a PUBLISH and return the allocated request ID.
726    pub async fn publish(
727        &mut self,
728        track_namespace: TrackNamespace,
729        track_name: Vec<u8>,
730        forward: Forward,
731    ) -> Result<VarInt, ConnectionError> {
732        let (req_id, msg) = self.endpoint.publish(track_namespace, track_name, forward)?;
733        self.send_control(&msg).await?;
734        Ok(req_id)
735    }
736
737    /// Send a PUBLISH_DONE for the given request ID.
738    pub async fn publish_done(
739        &mut self,
740        request_id: VarInt,
741        status_code: VarInt,
742        reason_phrase: Vec<u8>,
743    ) -> Result<(), ConnectionError> {
744        let msg = self.endpoint.send_publish_done(request_id, status_code, reason_phrase)?;
745        self.send_control(&msg).await
746    }
747
748    // ── Data streams ────────────────────────────────────────
749
750    /// Open a new unidirectional stream for sending subgroup data.
751    pub async fn open_subgroup_stream(
752        &self,
753        header: &AnySubgroupHeader,
754    ) -> Result<FramedSendStream, ConnectionError> {
755        let send = self.transport.open_uni().await?;
756        let mut framed = FramedSendStream::new(send, self.draft);
757        let sid = framed.stream_id();
758        framed.write_subgroup_header(header).await?;
759        self.emit(ClientEvent::StreamOpened {
760            direction: Direction::Send,
761            stream_kind: StreamKind::Subgroup,
762            stream_id: sid,
763        });
764        self.emit(ClientEvent::DataStreamHeader {
765            stream_id: sid,
766            direction: Direction::Send,
767            header: header.clone(),
768        });
769        Ok(framed)
770    }
771
772    /// Accept an incoming unidirectional data stream and read its subgroup
773    /// header.
774    pub async fn accept_subgroup_stream(
775        &self,
776    ) -> Result<(AnySubgroupHeader, FramedRecvStream), ConnectionError> {
777        let recv = self.transport.accept_uni().await?;
778        let mut framed = FramedRecvStream::new(recv, self.draft);
779        let sid = framed.stream_id();
780        let header = framed.read_subgroup_header().await?;
781        self.emit(ClientEvent::StreamOpened {
782            direction: Direction::Receive,
783            stream_kind: StreamKind::Subgroup,
784            stream_id: sid,
785        });
786        self.emit(ClientEvent::DataStreamHeader {
787            stream_id: sid,
788            direction: Direction::Receive,
789            header: header.clone(),
790        });
791        Ok((header, framed))
792    }
793
794    /// Send an object via datagram.
795    pub fn send_datagram(
796        &self,
797        header: &AnyDatagramHeader,
798        payload: &[u8],
799    ) -> Result<(), ConnectionError> {
800        let mut buf = Vec::new();
801        header.encode(&mut buf);
802        buf.extend_from_slice(payload);
803        self.emit(ClientEvent::DatagramReceived {
804            direction: Direction::Send,
805            header: header.clone(),
806            payload_len: payload.len(),
807        });
808        self.transport.send_datagram(bytes::Bytes::from(buf))?;
809        Ok(())
810    }
811
812    /// Receive a datagram and decode its header.
813    pub async fn recv_datagram(&self) -> Result<(AnyDatagramHeader, Bytes), ConnectionError> {
814        let data = self.transport.recv_datagram().await?;
815        let mut cursor = &data[..];
816        let header = AnyDatagramHeader::decode(self.draft, &mut cursor)?;
817        let consumed = data.len() - cursor.len();
818        let payload = data.slice(consumed..);
819        self.emit(ClientEvent::DatagramReceived {
820            direction: Direction::Receive,
821            header: header.clone(),
822            payload_len: payload.len(),
823        });
824        Ok((header, payload))
825    }
826
827    // ── Accessors ───────────────────────────────────────────
828
829    /// Access the underlying endpoint state machine.
830    pub fn endpoint(&self) -> &Endpoint {
831        &self.endpoint
832    }
833
834    /// Mutable access to the endpoint state machine.
835    pub fn endpoint_mut(&mut self) -> &mut Endpoint {
836        &mut self.endpoint
837    }
838
839    /// Get the negotiated MoQT version.
840    pub fn negotiated_version(&self) -> Option<VarInt> {
841        self.endpoint.negotiated_version()
842    }
843
844    /// Returns the draft version this connection is using.
845    pub fn draft(&self) -> DraftVersion {
846        self.draft
847    }
848
849    /// Close the connection.
850    pub fn close(&self, code: u32, reason: &[u8]) {
851        self.emit(ClientEvent::Closed { code, reason: reason.to_vec() });
852        self.transport.close(code, reason);
853    }
854}
855
/// Number of bytes a varint occupies, derived from its first byte.
///
/// The two most significant bits of the first byte select the length:
/// `00` → 1, `01` → 2, `10` → 4, `11` → 8. The remaining six bits are
/// ignored.
fn varint_len(first_byte: u8) -> usize {
    match first_byte >> 6 {
        0b00 => 1,
        0b01 => 2,
        0b10 => 4,
        _ => 8,
    }
}
860
861/// TLS certificate verifier that skips all verification (for testing only).
862#[derive(Debug)]
863struct SkipVerification;
864
865impl rustls::client::danger::ServerCertVerifier for SkipVerification {
866    fn verify_server_cert(
867        &self,
868        _end_entity: &rustls::pki_types::CertificateDer<'_>,
869        _intermediates: &[rustls::pki_types::CertificateDer<'_>],
870        _server_name: &rustls::pki_types::ServerName<'_>,
871        _ocsp_response: &[u8],
872        _now: rustls::pki_types::UnixTime,
873    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
874        Ok(rustls::client::danger::ServerCertVerified::assertion())
875    }
876
877    fn verify_tls12_signature(
878        &self,
879        _message: &[u8],
880        _cert: &rustls::pki_types::CertificateDer<'_>,
881        _dcs: &rustls::DigitallySignedStruct,
882    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
883        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
884    }
885
886    fn verify_tls13_signature(
887        &self,
888        _message: &[u8],
889        _cert: &rustls::pki_types::CertificateDer<'_>,
890        _dcs: &rustls::DigitallySignedStruct,
891    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
892        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
893    }
894
895    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
896        vec![
897            rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
898            rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
899            rustls::SignatureScheme::ED25519,
900            rustls::SignatureScheme::RSA_PSS_SHA256,
901            rustls::SignatureScheme::RSA_PSS_SHA384,
902            rustls::SignatureScheme::RSA_PSS_SHA512,
903        ]
904    }
905}
906
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal `ClientConfig` for `draft` over `transport`, with every other
    /// option empty or disabled.
    fn config_for(draft: DraftVersion, transport: TransportType) -> ClientConfig {
        ClientConfig {
            draft,
            additional_versions: Vec::new(),
            transport,
            skip_cert_verification: false,
            ca_certs: Vec::new(),
            setup_parameters: Vec::new(),
        }
    }

    /// WebTransport transport pointing at a placeholder URL.
    fn example_webtransport() -> TransportType {
        TransportType::WebTransport { url: "https://example.com".to_string() }
    }

    #[test]
    fn varint_len_single_byte() {
        // Prefix bits 0b00 -> 1 byte.
        for first in [0x00, 0x3F] {
            assert_eq!(varint_len(first), 1);
        }
    }

    #[test]
    fn varint_len_two_bytes() {
        // Prefix bits 0b01 -> 2 bytes.
        for first in [0x40, 0x7F] {
            assert_eq!(varint_len(first), 2);
        }
    }

    #[test]
    fn varint_len_four_bytes() {
        // Prefix bits 0b10 -> 4 bytes.
        for first in [0x80, 0xBF] {
            assert_eq!(varint_len(first), 4);
        }
    }

    #[test]
    fn varint_len_eight_bytes() {
        // Prefix bits 0b11 -> 8 bytes.
        for first in [0xC0, 0xFF] {
            assert_eq!(varint_len(first), 8);
        }
    }

    #[test]
    fn client_config_supported_versions_draft14() {
        let config = config_for(DraftVersion::Draft14, TransportType::Quic);
        let versions = config.supported_versions();
        assert_eq!(versions.len(), 1);
        assert_eq!(versions[0].into_inner(), 0xff00_0000 + 14);
    }

    #[test]
    fn client_config_supported_versions_draft07() {
        let config = config_for(DraftVersion::Draft07, TransportType::Quic);
        let versions = config.supported_versions();
        assert_eq!(versions.len(), 1);
        assert_eq!(versions[0].into_inner(), 0xff00_0000 + 7);
    }

    #[test]
    fn client_config_alpn_quic() {
        let config = config_for(DraftVersion::Draft14, TransportType::Quic);
        assert_eq!(config.alpn(), vec![b"moq-00".to_vec()]);
    }

    #[test]
    fn client_config_alpn_webtransport() {
        let config = config_for(DraftVersion::Draft14, example_webtransport());
        assert_eq!(config.alpn(), vec![b"h3".to_vec()]);
    }

    #[test]
    fn moqt_alpn_value() {
        assert_eq!(MOQT_ALPN, b"moq-00");
    }

    #[test]
    fn transport_type_debug() {
        let rendered_quic = format!("{:?}", TransportType::Quic);
        assert!(rendered_quic.contains("Quic"));

        let rendered_wt = format!("{:?}", example_webtransport());
        assert!(rendered_wt.contains("WebTransport"));
    }
}