Skip to main content

moqtap_client/draft12/
connection.rs

1use std::sync::Arc;
2
3use bytes::{Buf, Bytes, BytesMut};
4
5use crate::draft12::endpoint::{Endpoint, EndpointError};
6use crate::draft12::event::{ClientEvent, Direction, FetchObject, StreamKind, SubgroupObject};
7use crate::draft12::observer::ConnectionObserver;
8use crate::transport::quic::QuicTransport;
9use crate::transport::{RecvStream, SendStream, Transport, TransportError};
10use moqtap_codec::dispatch::{
11    AnyControlMessage, AnyDatagramHeader, AnyFetchHeader, AnySubgroupHeader,
12};
13use moqtap_codec::draft12::data_stream::{FetchObjectHeader, ObjectHeader};
14use moqtap_codec::draft12::message::ControlMessage;
15use moqtap_codec::error::CodecError;
16use moqtap_codec::types::*;
17use moqtap_codec::varint::VarInt;
18use moqtap_codec::version::DraftVersion;
19
/// MoQT ALPN identifier (used by raw QUIC transport).
///
/// NOTE(review): within this file only the unit tests reference this
/// constant; `ClientConfig::alpn` takes the QUIC ALPN from
/// `DraftVersion::Draft12.quic_alpn()` instead. Confirm the two are meant to
/// agree.
pub const MOQT_ALPN: &[u8] = b"moq-00";
22
/// Errors from the draft-12 connection layer.
///
/// The `#[from]` variants let `?` convert the underlying endpoint, codec,
/// transport and varint errors into this type automatically.
#[derive(Debug, thiserror::Error)]
pub enum ConnectionError {
    /// Endpoint state machine error.
    #[error("endpoint error: {0}")]
    Endpoint(#[from] EndpointError),
    /// Wire codec error (encoding or decoding a message/header failed).
    #[error("codec error: {0}")]
    Codec(#[from] CodecError),
    /// Transport-level error (stream or datagram I/O, connection setup).
    #[error("transport error: {0}")]
    Transport(#[from] TransportError),
    /// Variable-length integer decoding error.
    #[error("varint error: {0}")]
    VarInt(#[from] moqtap_codec::varint::VarIntError),
    /// Control stream was not opened.
    #[error("control stream not open")]
    NoControlStream,
    /// Stream ended before a complete message was read.
    #[error("unexpected end of stream")]
    UnexpectedEnd,
    /// Stream was finished by the peer.
    ///
    /// NOTE(review): no code in this file constructs this variant; the framed
    /// readers report every end-of-stream as `UnexpectedEnd`.
    #[error("stream finished")]
    StreamFinished,
    /// Invalid server address string. Also used when the local QUIC endpoint
    /// cannot be created.
    #[error("invalid server address: {0}")]
    InvalidAddress(String),
    /// TLS configuration error (bad CA certificate or unusable TLS config).
    #[error("TLS config error: {0}")]
    TlsConfig(String),
}
54
/// Transport type for the connection.
#[derive(Debug, Clone)]
pub enum TransportType {
    /// Raw QUIC via quinn. The server is taken from the `addr` argument of
    /// `Connection::connect`, which is parsed as a socket address
    /// (`ip:port`).
    Quic,
    /// WebTransport via wtransport. The `url` field is the WebTransport URL;
    /// the `addr` argument of `Connection::connect` is not used on this path.
    WebTransport {
        /// The WebTransport endpoint URL (e.g., `https://host:port/path`).
        url: String,
    },
}
66
/// Configuration for a draft-12 MoQT client connection.
pub struct ClientConfig {
    /// Additional draft versions to offer in CLIENT_SETUP (draft-12 is always
    /// offered first). Entries that duplicate an already-offered version are
    /// skipped by `supported_versions`.
    pub additional_versions: Vec<DraftVersion>,
    /// The transport type (QUIC or WebTransport).
    pub transport: TransportType,
    /// Whether to skip TLS certificate verification (for testing).
    pub skip_cert_verification: bool,
    /// Custom CA certificates to trust (DER-encoded).
    ///
    /// Only the raw QUIC path reads this list, and only when
    /// `skip_cert_verification` is false; the WebTransport path uses the
    /// native certificate store and ignores it.
    pub ca_certs: Vec<Vec<u8>>,
    /// Setup parameters to include in CLIENT_SETUP (e.g., auth tokens).
    pub setup_parameters: Vec<moqtap_codec::kvp::KeyValuePair>,
}
81
82impl ClientConfig {
83    /// Returns the MoQT version varints for the CLIENT_SETUP message.
84    /// Draft-12 first, then any additional versions.
85    pub fn supported_versions(&self) -> Vec<VarInt> {
86        let mut versions = vec![DraftVersion::Draft12.version_varint()];
87        for v in &self.additional_versions {
88            let varint = v.version_varint();
89            if !versions.contains(&varint) {
90                versions.push(varint);
91            }
92        }
93        versions
94    }
95
96    /// Returns the ALPN protocol identifiers for the transport.
97    pub fn alpn(&self) -> Vec<Vec<u8>> {
98        match &self.transport {
99            TransportType::Quic => vec![DraftVersion::Draft12.quic_alpn().to_vec()],
100            TransportType::WebTransport { .. } => vec![b"h3".to_vec()],
101        }
102    }
103}
104
/// A framed writer for a send stream. Handles MoQT length-prefixed framing.
///
/// Each `write_*` method encodes its argument into a fresh buffer and writes
/// that buffer to the wrapped stream in a single `write_all` call.
pub struct FramedSendStream {
    /// The underlying transport send stream.
    inner: SendStream,
}
109
110impl FramedSendStream {
111    /// Create a new framed send stream.
112    pub fn new(inner: SendStream) -> Self {
113        Self { inner }
114    }
115
116    /// Get the transport-level stream ID.
117    pub fn stream_id(&self) -> u64 {
118        self.inner.stream_id()
119    }
120
121    /// Write a control message to the stream with type+length framing.
122    /// Returns the raw bytes that were written (for event capture).
123    pub async fn write_control(
124        &mut self,
125        msg: &AnyControlMessage,
126    ) -> Result<Vec<u8>, ConnectionError> {
127        let mut buf = Vec::new();
128        msg.encode(&mut buf)?;
129        self.inner.write_all(&buf).await?;
130        Ok(buf)
131    }
132
133    /// Write a subgroup stream header.
134    pub async fn write_subgroup_header(
135        &mut self,
136        header: &AnySubgroupHeader,
137    ) -> Result<(), ConnectionError> {
138        let mut buf = Vec::new();
139        header.encode(&mut buf);
140        self.inner.write_all(&buf).await?;
141        Ok(())
142    }
143
144    /// Write a fetch response header.
145    pub async fn write_fetch_header(
146        &mut self,
147        header: &AnyFetchHeader,
148    ) -> Result<(), ConnectionError> {
149        let mut buf = Vec::new();
150        header.encode(&mut buf);
151        self.inner.write_all(&buf).await?;
152        Ok(())
153    }
154
155    /// Append a draft-12 subgroup object (header + payload) to the stream.
156    /// Draft-12 subgroup objects may carry extension headers but the writer
157    /// path defers to the codec `ObjectHeader::encode` helper.
158    pub async fn write_subgroup_object(
159        &mut self,
160        object: &SubgroupObject,
161    ) -> Result<(), ConnectionError> {
162        let mut buf = Vec::new();
163        object.header.encode(&mut buf);
164        buf.extend_from_slice(&object.payload);
165        self.inner.write_all(&buf).await?;
166        Ok(())
167    }
168
169    /// Append a draft-12 fetch object (header + payload) to the stream.
170    pub async fn write_fetch_object(
171        &mut self,
172        object: &FetchObject,
173    ) -> Result<(), ConnectionError> {
174        let mut buf = Vec::new();
175        object.header.encode(&mut buf);
176        buf.extend_from_slice(&object.payload);
177        self.inner.write_all(&buf).await?;
178        Ok(())
179    }
180
181    /// Finish the stream (send FIN).
182    pub async fn finish(&mut self) -> Result<(), ConnectionError> {
183        self.inner.finish()?;
184        Ok(())
185    }
186}
187
/// A framed reader for a recv stream. Handles MoQT varint-length decoding.
///
/// Bytes read from the transport accumulate in `buf`; each `read_*` method
/// decodes from the front of that buffer and advances past what it consumed.
pub struct FramedRecvStream {
    /// The underlying transport receive stream.
    inner: RecvStream,
    /// Bytes received from `inner` that have not been consumed yet.
    buf: BytesMut,
}
193
194impl FramedRecvStream {
195    /// Create a new framed receive stream.
196    pub fn new(inner: RecvStream) -> Self {
197        Self { inner, buf: BytesMut::with_capacity(4096) }
198    }
199
200    /// Get the transport-level stream ID.
201    pub fn stream_id(&self) -> u64 {
202        self.inner.stream_id()
203    }
204
205    /// Read more data from the stream into the internal buffer.
206    async fn fill(&mut self) -> Result<bool, ConnectionError> {
207        let mut tmp = [0u8; 4096];
208        match self.inner.read(&mut tmp).await {
209            Ok(Some(n)) => {
210                self.buf.extend_from_slice(&tmp[..n]);
211                Ok(true)
212            }
213            Ok(None) => Ok(false),
214            Err(e) => Err(ConnectionError::Transport(e)),
215        }
216    }
217
218    /// Ensure at least `n` bytes are available in the buffer.
219    async fn ensure(&mut self, n: usize) -> Result<(), ConnectionError> {
220        while self.buf.len() < n {
221            if !self.fill().await? {
222                return Err(ConnectionError::UnexpectedEnd);
223            }
224        }
225        Ok(())
226    }
227
228    /// Read a control message from the stream.
229    ///
230    /// When `capture_raw` is true, the returned tuple includes a clone of the
231    /// framed wire bytes (for observer emission). When false, the second
232    /// element is `None` and the payload clone is skipped.
233    pub async fn read_control(
234        &mut self,
235        capture_raw: bool,
236    ) -> Result<(AnyControlMessage, Option<Vec<u8>>), ConnectionError> {
237        // Read type ID varint
238        self.ensure(1).await?;
239        let type_len = varint_len(self.buf[0]);
240        self.ensure(type_len).await?;
241
242        let mut cursor = &self.buf[..type_len];
243        let _type_id = VarInt::decode(&mut cursor)?;
244
245        // Draft-12 uses a fixed 16-bit length after the type id.
246        let len_field_size = 2;
247        self.ensure(type_len + len_field_size).await?;
248        let payload_len = u16::from_be_bytes([self.buf[type_len], self.buf[type_len + 1]]) as usize;
249
250        // Read full payload
251        let total = type_len + len_field_size + payload_len;
252        self.ensure(total).await?;
253
254        // Capture raw bytes only if requested (observer attached).
255        let raw = capture_raw.then(|| self.buf[..total].to_vec());
256
257        // Now decode the whole message using the draft-12 dispatcher
258        let mut frame = &self.buf[..total];
259        let msg = AnyControlMessage::decode(DraftVersion::Draft12, &mut frame)?;
260        self.buf.advance(total);
261        Ok((msg, raw))
262    }
263
264    /// Read a subgroup stream header.
265    pub async fn read_subgroup_header(&mut self) -> Result<AnySubgroupHeader, ConnectionError> {
266        self.ensure(1).await?;
267        loop {
268            let mut cursor = &self.buf[..];
269            match AnySubgroupHeader::decode(DraftVersion::Draft12, &mut cursor) {
270                Ok(header) => {
271                    let consumed = self.buf.len() - cursor.remaining();
272                    self.buf.advance(consumed);
273                    return Ok(header);
274                }
275                Err(CodecError::UnexpectedEnd) => {
276                    if !self.fill().await? {
277                        return Err(ConnectionError::UnexpectedEnd);
278                    }
279                }
280                Err(e) => return Err(ConnectionError::Codec(e)),
281            }
282        }
283    }
284
285    /// Read a fetch response header.
286    pub async fn read_fetch_header(&mut self) -> Result<AnyFetchHeader, ConnectionError> {
287        self.ensure(1).await?;
288        loop {
289            let mut cursor = &self.buf[..];
290            match AnyFetchHeader::decode(DraftVersion::Draft12, &mut cursor) {
291                Ok(header) => {
292                    let consumed = self.buf.len() - cursor.remaining();
293                    self.buf.advance(consumed);
294                    return Ok(header);
295                }
296                Err(CodecError::UnexpectedEnd) => {
297                    if !self.fill().await? {
298                        return Err(ConnectionError::UnexpectedEnd);
299                    }
300                }
301                Err(e) => return Err(ConnectionError::Codec(e)),
302            }
303        }
304    }
305
306    /// Read the next draft-12 subgroup object (header + payload). The codec
307    /// helper `ObjectHeader::decode` is called without extensions by default;
308    /// callers that care about extensions should drive decoding explicitly.
309    pub async fn read_subgroup_object(&mut self) -> Result<SubgroupObject, ConnectionError> {
310        loop {
311            let mut cursor = &self.buf[..];
312            match ObjectHeader::decode(&mut cursor) {
313                Ok(header) => {
314                    let header_consumed = self.buf.len() - cursor.remaining();
315                    let payload_len = header.payload_length.into_inner() as usize;
316                    let total = header_consumed + payload_len;
317                    if self.buf.len() < total {
318                        if !self.fill().await? {
319                            return Err(ConnectionError::UnexpectedEnd);
320                        }
321                        continue;
322                    }
323                    let payload = self.buf[header_consumed..total].to_vec();
324                    self.buf.advance(total);
325                    return Ok(SubgroupObject { header, payload });
326                }
327                Err(CodecError::UnexpectedEnd) => {
328                    if !self.fill().await? {
329                        return Err(ConnectionError::UnexpectedEnd);
330                    }
331                }
332                Err(e) => return Err(ConnectionError::Codec(e)),
333            }
334        }
335    }
336
337    /// Read the next draft-12 fetch object (header + payload).
338    pub async fn read_fetch_object(&mut self) -> Result<FetchObject, ConnectionError> {
339        loop {
340            let mut cursor = &self.buf[..];
341            match FetchObjectHeader::decode(&mut cursor) {
342                Ok(header) => {
343                    let header_consumed = self.buf.len() - cursor.remaining();
344                    let payload_len = header.payload_length.into_inner() as usize;
345                    let total = header_consumed + payload_len;
346                    if self.buf.len() < total {
347                        if !self.fill().await? {
348                            return Err(ConnectionError::UnexpectedEnd);
349                        }
350                        continue;
351                    }
352                    let payload = self.buf[header_consumed..total].to_vec();
353                    self.buf.advance(total);
354                    return Ok(FetchObject { header, payload });
355                }
356                Err(CodecError::UnexpectedEnd) => {
357                    if !self.fill().await? {
358                        return Err(ConnectionError::UnexpectedEnd);
359                    }
360                }
361                Err(e) => return Err(ConnectionError::Codec(e)),
362            }
363        }
364    }
365}
366
/// A live draft-12 MoQT connection over QUIC or WebTransport.
///
/// Owns the transport session, the endpoint state machine, both halves of
/// the bidirectional control stream, and an optional event observer.
pub struct Connection {
    /// The underlying transport session (QUIC or WebTransport).
    transport: Transport,
    /// Protocol state machine that builds and validates control messages.
    endpoint: Endpoint,
    /// Send half of the control stream; `None` yields `NoControlStream`.
    control_send: Option<FramedSendStream>,
    /// Receive half of the control stream; `None` yields `NoControlStream`.
    control_recv: Option<FramedRecvStream>,
    /// Observer that receives `ClientEvent`s, if one is attached.
    observer: Option<Box<dyn ConnectionObserver>>,
}
375
376impl Connection {
377    /// Connect to a draft-12 MoQT server as a client.
378    pub async fn connect(addr: &str, config: ClientConfig) -> Result<Self, ConnectionError> {
379        let transport = match &config.transport {
380            TransportType::Quic => Self::connect_quic(addr, &config).await?,
381            TransportType::WebTransport { url } => {
382                let url = url.clone();
383                Self::connect_webtransport(&url, &config).await?
384            }
385        };
386
387        // Open bidirectional control stream
388        let (send, recv) = transport.open_bi().await?;
389        let mut control_send = FramedSendStream::new(send);
390        let mut control_recv = FramedRecvStream::new(recv);
391
392        // Perform setup handshake
393        let mut endpoint = Endpoint::new();
394        endpoint.connect()?;
395        let setup_msg = endpoint
396            .send_client_setup(config.supported_versions(), config.setup_parameters.clone())?;
397        let any_setup = AnyControlMessage::Draft12(setup_msg);
398        let _raw_setup = control_send.write_control(&any_setup).await?;
399
400        let (server_setup, _raw_server_setup) = control_recv.read_control(false).await?;
401        match &server_setup {
402            AnyControlMessage::Draft12(ControlMessage::ServerSetup(ref ss)) => {
403                endpoint.receive_server_setup(ss)?;
404            }
405            _ => {
406                return Err(ConnectionError::Endpoint(EndpointError::NotActive));
407            }
408        }
409
410        let conn = Self {
411            transport,
412            endpoint,
413            control_send: Some(control_send),
414            control_recv: Some(control_recv),
415            observer: None,
416        };
417
418        if let Some(v) = conn.endpoint.negotiated_version() {
419            conn.emit(ClientEvent::SetupComplete { negotiated_version: v.into_inner() });
420        }
421
422        Ok(conn)
423    }
424
425    /// Establish a raw QUIC connection.
426    async fn connect_quic(addr: &str, config: &ClientConfig) -> Result<Transport, ConnectionError> {
427        let server_addr = addr.parse().map_err(|e: std::net::AddrParseError| {
428            ConnectionError::InvalidAddress(e.to_string())
429        })?;
430
431        let mut tls_config = if config.skip_cert_verification {
432            rustls::ClientConfig::builder()
433                .dangerous()
434                .with_custom_certificate_verifier(Arc::new(SkipVerification))
435                .with_no_client_auth()
436        } else {
437            let mut roots = rustls::RootCertStore::empty();
438            roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
439            for der in &config.ca_certs {
440                roots
441                    .add(rustls::pki_types::CertificateDer::from(der.clone()))
442                    .map_err(|e| ConnectionError::TlsConfig(format!("bad CA cert: {e}")))?;
443            }
444            rustls::ClientConfig::builder().with_root_certificates(roots).with_no_client_auth()
445        };
446
447        tls_config.alpn_protocols = config.alpn();
448
449        let quic_config: quinn::crypto::rustls::QuicClientConfig =
450            tls_config.try_into().map_err(|e| ConnectionError::TlsConfig(format!("{e}")))?;
451        let client_config = quinn::ClientConfig::new(Arc::new(quic_config));
452
453        let mut quinn_endpoint = quinn::Endpoint::client("0.0.0.0:0".parse().unwrap())
454            .map_err(|e| ConnectionError::InvalidAddress(e.to_string()))?;
455        quinn_endpoint.set_default_client_config(client_config);
456
457        let server_name = addr.split(':').next().unwrap_or("localhost").to_string();
458
459        let quic = quinn_endpoint
460            .connect(server_addr, &server_name)
461            .map_err(TransportError::from)?
462            .await
463            .map_err(TransportError::from)?;
464
465        Ok(Transport::Quic(QuicTransport::new(quic)))
466    }
467
468    /// Establish a WebTransport connection.
469    #[cfg(feature = "webtransport")]
470    async fn connect_webtransport(
471        url: &str,
472        config: &ClientConfig,
473    ) -> Result<Transport, ConnectionError> {
474        use crate::transport::webtransport::WebTransportTransport;
475
476        let wt_config = if config.skip_cert_verification {
477            wtransport::ClientConfig::builder()
478                .with_bind_default()
479                .with_no_cert_validation()
480                .build()
481        } else {
482            wtransport::ClientConfig::builder().with_bind_default().with_native_certs().build()
483        };
484
485        let endpoint = wtransport::Endpoint::client(wt_config)
486            .map_err(|e| ConnectionError::Transport(TransportError::Connect(e.to_string())))?;
487
488        let connection = endpoint
489            .connect(url)
490            .await
491            .map_err(|e| ConnectionError::Transport(TransportError::Connect(e.to_string())))?;
492
493        Ok(Transport::WebTransport(WebTransportTransport::new(connection)))
494    }
495
496    /// Stub for when the webtransport feature is not enabled.
497    #[cfg(not(feature = "webtransport"))]
498    async fn connect_webtransport(
499        _url: &str,
500        _config: &ClientConfig,
501    ) -> Result<Transport, ConnectionError> {
502        Err(ConnectionError::Transport(TransportError::Connect(
503            "webtransport feature not enabled".into(),
504        )))
505    }
506
507    // ── Observer ───────────────────────────────────────────────
508
509    /// Attach an observer to receive connection events.
510    pub fn set_observer(&mut self, observer: Box<dyn ConnectionObserver>) {
511        self.observer = Some(observer);
512    }
513
514    /// Remove the observer.
515    pub fn clear_observer(&mut self) {
516        self.observer = None;
517    }
518
519    /// Emit an event to the observer, if one is attached.
520    fn emit(&self, event: ClientEvent) {
521        if let Some(ref obs) = self.observer {
522            obs.on_event_owned(event);
523        }
524    }
525
526    // ── Control message I/O ─────────────────────────────────
527
528    /// Send a control message on the control stream.
529    pub async fn send_control(&mut self, msg: &ControlMessage) -> Result<(), ConnectionError> {
530        let any = AnyControlMessage::Draft12(msg.clone());
531        let send = self.control_send.as_mut().ok_or(ConnectionError::NoControlStream)?;
532        let raw = send.write_control(&any).await?;
533        self.emit(ClientEvent::ControlMessage {
534            direction: Direction::Send,
535            message: any,
536            raw: Some(raw),
537        });
538        Ok(())
539    }
540
541    /// Read the next control message from the control stream.
542    pub async fn recv_control(&mut self) -> Result<ControlMessage, ConnectionError> {
543        let recv = self.control_recv.as_mut().ok_or(ConnectionError::NoControlStream)?;
544        let capture_raw = self.observer.is_some();
545        let (any, raw) = recv.read_control(capture_raw).await?;
546        if capture_raw {
547            self.emit(ClientEvent::ControlMessage {
548                direction: Direction::Receive,
549                message: any.clone(),
550                raw,
551            });
552        }
553        match any {
554            AnyControlMessage::Draft12(msg) => Ok(msg),
555            _ => Err(ConnectionError::Codec(CodecError::UnknownMessageType(0))),
556        }
557    }
558
559    /// Read and dispatch the next incoming control message through the endpoint
560    /// state machine. Returns the decoded message for inspection.
561    pub async fn recv_and_dispatch(&mut self) -> Result<ControlMessage, ConnectionError> {
562        let msg = self.recv_control().await?;
563        self.endpoint.receive_message(msg.clone())?;
564
565        if let ControlMessage::GoAway(ref ga) = msg {
566            self.emit(ClientEvent::Draining { new_session_uri: ga.new_session_uri.clone() });
567        }
568
569        Ok(msg)
570    }
571
572    // ── Subscribe flow ──────────────────────────────────────
573
574    /// Send a SUBSCRIBE and return the allocated request ID.
575    ///
576    /// Draft-12: the subscriber no longer chooses a track alias. The alias is
577    /// returned by the publisher in SUBSCRIBE_OK and can be retrieved later
578    /// from the endpoint via `endpoint().track_alias_for(request_id)`.
579    pub async fn subscribe(
580        &mut self,
581        track_namespace: TrackNamespace,
582        track_name: Vec<u8>,
583        subscriber_priority: u8,
584        group_order: VarInt,
585        filter_type: VarInt,
586    ) -> Result<VarInt, ConnectionError> {
587        let (req_id, msg) = self.endpoint.subscribe(
588            track_namespace,
589            track_name,
590            subscriber_priority,
591            group_order,
592            filter_type,
593        )?;
594        self.send_control(&msg).await?;
595        Ok(req_id)
596    }
597
598    /// Send an UNSUBSCRIBE for the given request ID.
599    pub async fn unsubscribe(&mut self, request_id: VarInt) -> Result<(), ConnectionError> {
600        let msg = self.endpoint.unsubscribe(request_id)?;
601        self.send_control(&msg).await
602    }
603
604    // ── Fetch flow ──────────────────────────────────────────
605
606    /// Send a FETCH and return the allocated request ID.
607    #[allow(clippy::too_many_arguments)]
608    pub async fn fetch(
609        &mut self,
610        track_namespace: TrackNamespace,
611        track_name: Vec<u8>,
612        subscriber_priority: u8,
613        group_order: VarInt,
614        start_group: VarInt,
615        start_object: VarInt,
616        end_group: VarInt,
617        end_object: VarInt,
618    ) -> Result<VarInt, ConnectionError> {
619        let (req_id, msg) = self.endpoint.fetch(
620            track_namespace,
621            track_name,
622            subscriber_priority,
623            group_order,
624            start_group,
625            start_object,
626            end_group,
627            end_object,
628        )?;
629        self.send_control(&msg).await?;
630        Ok(req_id)
631    }
632
633    /// Send a FETCH_CANCEL for the given request ID.
634    pub async fn fetch_cancel(&mut self, request_id: VarInt) -> Result<(), ConnectionError> {
635        let msg = self.endpoint.fetch_cancel(request_id)?;
636        self.send_control(&msg).await
637    }
638
639    // ── Namespace flows ─────────────────────────────────────
640
641    /// Send a SUBSCRIBE_ANNOUNCES. Returns the allocated request ID.
642    pub async fn subscribe_announces(
643        &mut self,
644        track_namespace_prefix: TrackNamespace,
645    ) -> Result<VarInt, ConnectionError> {
646        let (req_id, msg) = self.endpoint.subscribe_announces(track_namespace_prefix)?;
647        self.send_control(&msg).await?;
648        Ok(req_id)
649    }
650
651    /// Send an ANNOUNCE. Returns the allocated request ID.
652    pub async fn announce(
653        &mut self,
654        track_namespace: TrackNamespace,
655    ) -> Result<VarInt, ConnectionError> {
656        let (req_id, msg) = self.endpoint.announce(track_namespace)?;
657        self.send_control(&msg).await?;
658        Ok(req_id)
659    }
660
661    /// Send an UNANNOUNCE.
662    pub async fn unannounce(
663        &mut self,
664        track_namespace: TrackNamespace,
665    ) -> Result<(), ConnectionError> {
666        let msg = self.endpoint.unannounce(track_namespace)?;
667        self.send_control(&msg).await
668    }
669
670    // ── Track Status flow ────────────────────────────────────
671
672    /// Send a TRACK_STATUS_REQUEST. Returns the allocated request ID.
673    pub async fn track_status_request(
674        &mut self,
675        track_namespace: TrackNamespace,
676        track_name: Vec<u8>,
677    ) -> Result<VarInt, ConnectionError> {
678        let (req_id, msg) = self.endpoint.track_status_request(track_namespace, track_name)?;
679        self.send_control(&msg).await?;
680        Ok(req_id)
681    }
682
683    // ── Data streams ────────────────────────────────────────
684
685    /// Open a new unidirectional stream for sending subgroup data.
686    pub async fn open_subgroup_stream(
687        &self,
688        header: &AnySubgroupHeader,
689    ) -> Result<FramedSendStream, ConnectionError> {
690        let send = self.transport.open_uni().await?;
691        let mut framed = FramedSendStream::new(send);
692        let sid = framed.stream_id();
693        framed.write_subgroup_header(header).await?;
694        self.emit(ClientEvent::StreamOpened {
695            direction: Direction::Send,
696            stream_kind: StreamKind::Subgroup,
697            stream_id: sid,
698        });
699        self.emit(ClientEvent::DataStreamHeader {
700            stream_id: sid,
701            direction: Direction::Send,
702            header: header.clone(),
703        });
704        Ok(framed)
705    }
706
707    /// Accept an incoming unidirectional data stream and read its subgroup
708    /// header.
709    pub async fn accept_subgroup_stream(
710        &self,
711    ) -> Result<(AnySubgroupHeader, FramedRecvStream), ConnectionError> {
712        let recv = self.transport.accept_uni().await?;
713        let mut framed = FramedRecvStream::new(recv);
714        let sid = framed.stream_id();
715        let header = framed.read_subgroup_header().await?;
716        self.emit(ClientEvent::StreamOpened {
717            direction: Direction::Receive,
718            stream_kind: StreamKind::Subgroup,
719            stream_id: sid,
720        });
721        self.emit(ClientEvent::DataStreamHeader {
722            stream_id: sid,
723            direction: Direction::Receive,
724            header: header.clone(),
725        });
726        Ok((header, framed))
727    }
728
729    /// Send an object via datagram.
730    pub fn send_datagram(
731        &self,
732        header: &AnyDatagramHeader,
733        payload: &[u8],
734    ) -> Result<(), ConnectionError> {
735        let mut buf = Vec::new();
736        header.encode(&mut buf);
737        buf.extend_from_slice(payload);
738        self.emit(ClientEvent::DatagramReceived {
739            direction: Direction::Send,
740            header: header.clone(),
741            payload_len: payload.len(),
742        });
743        self.transport.send_datagram(bytes::Bytes::from(buf))?;
744        Ok(())
745    }
746
747    /// Receive a datagram and decode its header.
748    pub async fn recv_datagram(&self) -> Result<(AnyDatagramHeader, Bytes), ConnectionError> {
749        let data = self.transport.recv_datagram().await?;
750        let mut cursor = &data[..];
751        let header = AnyDatagramHeader::decode(DraftVersion::Draft12, &mut cursor)?;
752        let consumed = data.len() - cursor.len();
753        let payload = data.slice(consumed..);
754        self.emit(ClientEvent::DatagramReceived {
755            direction: Direction::Receive,
756            header: header.clone(),
757            payload_len: payload.len(),
758        });
759        Ok((header, payload))
760    }
761
762    // ── Accessors ───────────────────────────────────────────
763
764    /// Access the underlying endpoint state machine.
765    pub fn endpoint(&self) -> &Endpoint {
766        &self.endpoint
767    }
768
769    /// Mutable access to the endpoint state machine.
770    pub fn endpoint_mut(&mut self) -> &mut Endpoint {
771        &mut self.endpoint
772    }
773
774    /// Get the negotiated MoQT version.
775    pub fn negotiated_version(&self) -> Option<VarInt> {
776        self.endpoint.negotiated_version()
777    }
778
779    /// Close the connection.
780    pub fn close(&self, code: u32, reason: &[u8]) {
781        self.emit(ClientEvent::Closed { code, reason: reason.to_vec() });
782        self.transport.close(code, reason);
783    }
784}
785
/// Encoded length in bytes of a varint, given its first byte.
///
/// The top two bits of the first byte select the length: `00` → 1, `01` → 2,
/// `10` → 4, `11` → 8.
fn varint_len(first_byte: u8) -> usize {
    match first_byte >> 6 {
        0 => 1,
        1 => 2,
        2 => 4,
        _ => 8,
    }
}
790
/// TLS certificate verifier that skips all verification (for testing only).
///
/// Installed by the raw QUIC path when
/// `ClientConfig::skip_cert_verification` is set. Every certificate and
/// handshake signature is accepted, so the peer is not authenticated.
#[derive(Debug)]
struct SkipVerification;
794
795impl rustls::client::danger::ServerCertVerifier for SkipVerification {
796    fn verify_server_cert(
797        &self,
798        _end_entity: &rustls::pki_types::CertificateDer<'_>,
799        _intermediates: &[rustls::pki_types::CertificateDer<'_>],
800        _server_name: &rustls::pki_types::ServerName<'_>,
801        _ocsp_response: &[u8],
802        _now: rustls::pki_types::UnixTime,
803    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
804        Ok(rustls::client::danger::ServerCertVerified::assertion())
805    }
806
807    fn verify_tls12_signature(
808        &self,
809        _message: &[u8],
810        _cert: &rustls::pki_types::CertificateDer<'_>,
811        _dcs: &rustls::DigitallySignedStruct,
812    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
813        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
814    }
815
816    fn verify_tls13_signature(
817        &self,
818        _message: &[u8],
819        _cert: &rustls::pki_types::CertificateDer<'_>,
820        _dcs: &rustls::DigitallySignedStruct,
821    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
822        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
823    }
824
825    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
826        vec![
827            rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
828            rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
829            rustls::SignatureScheme::ED25519,
830            rustls::SignatureScheme::RSA_PSS_SHA256,
831            rustls::SignatureScheme::RSA_PSS_SHA384,
832            rustls::SignatureScheme::RSA_PSS_SHA512,
833        ]
834    }
835}
836
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal raw-QUIC configuration shared by the tests below.
    fn quic_config() -> ClientConfig {
        ClientConfig {
            additional_versions: Vec::new(),
            transport: TransportType::Quic,
            skip_cert_verification: false,
            ca_certs: Vec::new(),
            setup_parameters: Vec::new(),
        }
    }

    /// With no additional versions only draft-12 is offered.
    #[test]
    fn client_config_supported_versions_default() {
        let versions = quic_config().supported_versions();
        assert_eq!(versions.len(), 1);
        assert_eq!(versions[0].into_inner(), 0xff000000 + 12);
    }

    /// Raw QUIC advertises exactly the draft-12 QUIC ALPN.
    #[test]
    fn client_config_alpn_quic() {
        let expected = vec![DraftVersion::Draft12.quic_alpn().to_vec()];
        assert_eq!(quic_config().alpn(), expected);
    }

    /// The exported ALPN constant keeps its wire value.
    #[test]
    fn moqt_alpn_value() {
        assert_eq!(MOQT_ALPN, b"moq-00");
    }
}
871}