Skip to main content

moqtap_client/
connection.rs

1use std::sync::Arc;
2
3use bytes::{Buf, BytesMut};
4
5use moqtap_codec::data_stream::{DatagramHeader, FetchHeader, ObjectHeader, SubgroupHeader};
6use moqtap_codec::message::{CodecError, ControlMessage};
7use moqtap_codec::types::*;
8use moqtap_codec::varint::VarInt;
9use moqtap_session::request_id::Role;
10
11use crate::endpoint::{Endpoint, EndpointError};
12
/// ALPN protocol identifier offered during the TLS handshake for MoQT over
/// native QUIC (draft-14).
pub const MOQT_ALPN: &[u8] = "moq-00".as_bytes();
15
/// Errors from the connection layer.
///
/// Most variants wrap an error from a lower layer (endpoint state machine,
/// codec, or quinn) and are produced through `?` via the `#[from]` conversions.
#[derive(Debug, thiserror::Error)]
pub enum ConnectionError {
    /// Error from the MoQT endpoint state machine. Also returned (as
    /// `EndpointError::NotActive`) when the peer's first control message is not
    /// the expected SETUP message.
    #[error("endpoint error: {0}")]
    Endpoint(#[from] EndpointError),
    /// Encoding or decoding failure reported by `moqtap_codec`.
    #[error("codec error: {0}")]
    Codec(#[from] CodecError),
    /// The QUIC connection failed, e.g. while connecting, opening/accepting a
    /// stream, or reading a datagram.
    #[error("quinn connection error: {0}")]
    Connection(#[from] quinn::ConnectionError),
    /// Writing to a QUIC send stream failed.
    #[error("quinn write error: {0}")]
    Write(#[from] quinn::WriteError),
    /// Reading from a QUIC receive stream failed. Plain read errors are wrapped
    /// in `ReadExactError::ReadError` by `FramedRecvStream`.
    #[error("quinn read error: {0}")]
    Read(#[from] quinn::ReadExactError),
    /// quinn rejected the outgoing connection attempt before it was started.
    #[error("quinn connect error: {0}")]
    Connect(#[from] quinn::ConnectError),
    /// Tried to finish a send stream that was already closed.
    #[error("closed stream: {0}")]
    ClosedStream(#[from] quinn::ClosedStream),
    /// Sending an unreliable datagram failed.
    #[error("send datagram error: {0}")]
    SendDatagram(#[from] quinn::SendDatagramError),
    /// Error from `moqtap_codec::varint`.
    #[error("varint error: {0}")]
    VarInt(#[from] moqtap_codec::varint::VarIntError),
    /// The control stream is not available on this connection.
    #[error("control stream not open")]
    NoControlStream,
    /// The peer finished a stream before a complete message, header, or
    /// payload was received.
    #[error("unexpected end of stream")]
    UnexpectedEnd,
    /// The stream has finished.
    ///
    /// NOTE(review): nothing in this file constructs this variant — confirm
    /// whether other modules rely on it.
    #[error("stream finished")]
    StreamFinished,
    /// The server address could not be parsed, or the local QUIC endpoint
    /// could not be created.
    #[error("invalid server address: {0}")]
    InvalidAddress(String),
    /// The rustls client configuration could not be converted into a QUIC
    /// client configuration.
    #[error("TLS config error: {0}")]
    TlsConfig(String),
}
48
/// Configuration for a MoQT client connection, consumed by `Connection::connect`.
pub struct ClientConfig {
    /// MoQT versions the client supports; passed to the endpoint when building
    /// CLIENT_SETUP.
    pub supported_versions: Vec<VarInt>,
    /// Whether to skip TLS certificate verification (for testing). When `true`,
    /// any server certificate is accepted.
    pub skip_cert_verification: bool,
}
56
57impl Default for ClientConfig {
58    fn default() -> Self {
59        Self {
60            supported_versions: vec![VarInt::from_u64(0xff000000 + 14).unwrap()], // draft-14
61            skip_cert_verification: false,
62        }
63    }
64}
65
/// A writer for a QUIC send stream that encodes MoQT control messages and
/// data-stream headers (via the codec's `encode`) before writing them.
pub struct FramedSendStream {
    // The raw QUIC send stream all encoded bytes are written to.
    inner: quinn::SendStream,
}
70
71impl FramedSendStream {
72    pub fn new(inner: quinn::SendStream) -> Self {
73        Self { inner }
74    }
75
76    /// Write a control message to the stream with type+length framing.
77    pub async fn write_control(&mut self, msg: &ControlMessage) -> Result<(), ConnectionError> {
78        let mut buf = Vec::new();
79        msg.encode(&mut buf)?;
80        self.inner.write_all(&buf).await?;
81        Ok(())
82    }
83
84    /// Write a subgroup stream header.
85    pub async fn write_subgroup_header(
86        &mut self,
87        header: &SubgroupHeader,
88    ) -> Result<(), ConnectionError> {
89        let mut buf = Vec::new();
90        header.encode(&mut buf);
91        self.inner.write_all(&buf).await?;
92        Ok(())
93    }
94
95    /// Write a fetch response header.
96    pub async fn write_fetch_header(
97        &mut self,
98        header: &FetchHeader,
99    ) -> Result<(), ConnectionError> {
100        let mut buf = Vec::new();
101        header.encode(&mut buf);
102        self.inner.write_all(&buf).await?;
103        Ok(())
104    }
105
106    /// Write an object header followed by payload.
107    pub async fn write_object(
108        &mut self,
109        header: &ObjectHeader,
110        payload: &[u8],
111    ) -> Result<(), ConnectionError> {
112        let mut buf = Vec::new();
113        header.encode(&mut buf);
114        self.inner.write_all(&buf).await?;
115        if !payload.is_empty() {
116            self.inner.write_all(payload).await?;
117        }
118        Ok(())
119    }
120
121    /// Finish the stream (send FIN).
122    pub async fn finish(&mut self) -> Result<(), ConnectionError> {
123        self.inner.finish()?;
124        Ok(())
125    }
126}
127
/// A buffered reader for a QUIC receive stream that decodes MoQT control
/// messages, data-stream headers and object payloads.
///
/// Bytes read from the stream but not yet consumed by a decode stay in the
/// internal buffer for the next `read_*` call.
pub struct FramedRecvStream {
    // The raw QUIC receive stream.
    inner: quinn::RecvStream,
    // Bytes received but not yet consumed by a `read_*` call.
    buf: BytesMut,
}
133
134impl FramedRecvStream {
135    pub fn new(inner: quinn::RecvStream) -> Self {
136        Self { inner, buf: BytesMut::with_capacity(4096) }
137    }
138
139    /// Read more data from the stream into the internal buffer.
140    async fn fill(&mut self) -> Result<bool, ConnectionError> {
141        let mut tmp = [0u8; 4096];
142        match self.inner.read(&mut tmp).await {
143            Ok(Some(n)) => {
144                self.buf.extend_from_slice(&tmp[..n]);
145                Ok(true)
146            }
147            Ok(None) => Ok(false),
148            Err(e) => Err(ConnectionError::Read(quinn::ReadExactError::ReadError(e))),
149        }
150    }
151
152    /// Ensure at least `n` bytes are available in the buffer.
153    async fn ensure(&mut self, n: usize) -> Result<(), ConnectionError> {
154        while self.buf.len() < n {
155            if !self.fill().await? {
156                return Err(ConnectionError::UnexpectedEnd);
157            }
158        }
159        Ok(())
160    }
161
162    /// Read a control message from the stream.
163    pub async fn read_control(&mut self) -> Result<ControlMessage, ConnectionError> {
164        // Read type ID varint
165        self.ensure(1).await?;
166        let type_len = varint_len(self.buf[0]);
167        self.ensure(type_len).await?;
168
169        let mut cursor = &self.buf[..type_len];
170        let _type_id = VarInt::decode(&mut cursor)?;
171
172        // Read payload length varint
173        self.ensure(type_len + 1).await?;
174        let payload_len_start = type_len;
175        let payload_len_varint_len = varint_len(self.buf[payload_len_start]);
176        self.ensure(type_len + payload_len_varint_len).await?;
177
178        let mut cursor = &self.buf[payload_len_start..type_len + payload_len_varint_len];
179        let payload_len = VarInt::decode(&mut cursor)?.into_inner() as usize;
180
181        // Read full payload
182        let total = type_len + payload_len_varint_len + payload_len;
183        self.ensure(total).await?;
184
185        // Now decode the whole message
186        let mut frame = &self.buf[..total];
187        let msg = ControlMessage::decode(&mut frame)?;
188        self.buf.advance(total);
189        Ok(msg)
190    }
191
192    /// Read a subgroup stream header.
193    pub async fn read_subgroup_header(&mut self) -> Result<SubgroupHeader, ConnectionError> {
194        // SubgroupHeader: track_alias(v) + group(v) + subgroup(v) + priority(1)
195        // Need at least 4 bytes, but varints can be up to 8 each
196        self.ensure(4).await?;
197        loop {
198            let mut cursor = &self.buf[..];
199            match SubgroupHeader::decode(&mut cursor) {
200                Ok(header) => {
201                    let consumed = self.buf.len() - cursor.remaining();
202                    self.buf.advance(consumed);
203                    return Ok(header);
204                }
205                Err(_) => {
206                    if !self.fill().await? {
207                        return Err(ConnectionError::UnexpectedEnd);
208                    }
209                }
210            }
211        }
212    }
213
214    /// Read a fetch response header.
215    pub async fn read_fetch_header(&mut self) -> Result<FetchHeader, ConnectionError> {
216        self.ensure(4).await?;
217        loop {
218            let mut cursor = &self.buf[..];
219            match FetchHeader::decode(&mut cursor) {
220                Ok(header) => {
221                    let consumed = self.buf.len() - cursor.remaining();
222                    self.buf.advance(consumed);
223                    return Ok(header);
224                }
225                Err(_) => {
226                    if !self.fill().await? {
227                        return Err(ConnectionError::UnexpectedEnd);
228                    }
229                }
230            }
231        }
232    }
233
234    /// Read an object header.
235    pub async fn read_object_header(&mut self) -> Result<ObjectHeader, ConnectionError> {
236        self.ensure(2).await?;
237        loop {
238            let mut cursor = &self.buf[..];
239            match ObjectHeader::decode(&mut cursor) {
240                Ok(header) => {
241                    let consumed = self.buf.len() - cursor.remaining();
242                    self.buf.advance(consumed);
243                    return Ok(header);
244                }
245                Err(_) => {
246                    if !self.fill().await? {
247                        return Err(ConnectionError::UnexpectedEnd);
248                    }
249                }
250            }
251        }
252    }
253
254    /// Read exactly `n` bytes of object payload.
255    pub async fn read_payload(&mut self, n: usize) -> Result<Vec<u8>, ConnectionError> {
256        self.ensure(n).await?;
257        let data = self.buf[..n].to_vec();
258        self.buf.advance(n);
259        Ok(data)
260    }
261}
262
/// A live MoQT connection over QUIC, combining the endpoint state machine
/// with actual network I/O.
pub struct Connection {
    // Underlying QUIC connection used for streams and datagrams.
    quic: quinn::Connection,
    // MoQT endpoint state machine; builds outgoing requests and receives
    // dispatched incoming control messages.
    endpoint: Endpoint,
    // Send half of the bidirectional control stream (set by `connect`/`accept`).
    control_send: Option<FramedSendStream>,
    // Receive half of the bidirectional control stream (set by `connect`/`accept`).
    control_recv: Option<FramedRecvStream>,
}
271
272impl Connection {
273    /// Connect to a MoQT server as a client.
274    ///
275    /// Opens a QUIC connection, performs the bidirectional control stream
276    /// setup handshake (CLIENT_SETUP / SERVER_SETUP), and returns a
277    /// ready-to-use connection.
278    pub async fn connect(addr: &str, config: ClientConfig) -> Result<Self, ConnectionError> {
279        let server_addr = addr.parse().map_err(|e: std::net::AddrParseError| {
280            ConnectionError::InvalidAddress(e.to_string())
281        })?;
282
283        let mut endpoint = Endpoint::new(Role::Client);
284
285        // Build TLS config
286        let mut tls_config = if config.skip_cert_verification {
287            rustls::ClientConfig::builder()
288                .dangerous()
289                .with_custom_certificate_verifier(Arc::new(SkipVerification))
290                .with_no_client_auth()
291        } else {
292            let mut roots = rustls::RootCertStore::empty();
293            roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
294            rustls::ClientConfig::builder().with_root_certificates(roots).with_no_client_auth()
295        };
296
297        tls_config.alpn_protocols = vec![MOQT_ALPN.to_vec()];
298
299        let quic_config: quinn::crypto::rustls::QuicClientConfig =
300            tls_config.try_into().map_err(|e| ConnectionError::TlsConfig(format!("{e}")))?;
301        let client_config = quinn::ClientConfig::new(Arc::new(quic_config));
302
303        let mut quinn_endpoint = quinn::Endpoint::client("0.0.0.0:0".parse().unwrap())
304            .map_err(|e| ConnectionError::InvalidAddress(e.to_string()))?;
305        quinn_endpoint.set_default_client_config(client_config);
306
307        let server_name = addr.split(':').next().unwrap_or("localhost").to_string();
308
309        let quic = quinn_endpoint.connect(server_addr, &server_name)?.await?;
310
311        // Open bidirectional control stream
312        let (send, recv) = quic.open_bi().await?;
313        let mut control_send = FramedSendStream::new(send);
314        let mut control_recv = FramedRecvStream::new(recv);
315
316        // Perform setup handshake
317        endpoint.connect()?;
318        let setup_msg = endpoint.send_client_setup(config.supported_versions)?;
319        control_send.write_control(&setup_msg).await?;
320
321        let server_setup = control_recv.read_control().await?;
322        if let ControlMessage::ServerSetup(ref ss) = server_setup {
323            endpoint.receive_server_setup(ss)?;
324        } else {
325            return Err(ConnectionError::Endpoint(EndpointError::NotActive));
326        }
327
328        Ok(Self {
329            quic,
330            endpoint,
331            control_send: Some(control_send),
332            control_recv: Some(control_recv),
333        })
334    }
335
336    /// Accept a MoQT connection as a server (given an already-accepted QUIC connection).
337    pub async fn accept(
338        quic: quinn::Connection,
339        selected_version: VarInt,
340    ) -> Result<Self, ConnectionError> {
341        let mut endpoint = Endpoint::new(Role::Server);
342        endpoint.connect()?;
343
344        // Accept the bidirectional control stream from the client
345        let (send, recv) = quic.accept_bi().await?;
346        let mut control_send = FramedSendStream::new(send);
347        let mut control_recv = FramedRecvStream::new(recv);
348
349        // Read CLIENT_SETUP
350        let client_setup_msg = control_recv.read_control().await?;
351        if let ControlMessage::ClientSetup(ref cs) = client_setup_msg {
352            let server_setup = endpoint.receive_client_setup_and_respond(cs, selected_version)?;
353            control_send.write_control(&server_setup).await?;
354        } else {
355            return Err(ConnectionError::Endpoint(EndpointError::NotActive));
356        }
357
358        Ok(Self {
359            quic,
360            endpoint,
361            control_send: Some(control_send),
362            control_recv: Some(control_recv),
363        })
364    }
365
366    // ── Control message I/O ─────────────────────────────────
367
368    /// Send a control message on the control stream.
369    pub async fn send_control(&mut self, msg: &ControlMessage) -> Result<(), ConnectionError> {
370        let send = self.control_send.as_mut().ok_or(ConnectionError::NoControlStream)?;
371        send.write_control(msg).await
372    }
373
374    /// Read the next control message from the control stream.
375    pub async fn recv_control(&mut self) -> Result<ControlMessage, ConnectionError> {
376        let recv = self.control_recv.as_mut().ok_or(ConnectionError::NoControlStream)?;
377        recv.read_control().await
378    }
379
380    /// Read and dispatch the next incoming control message through the endpoint
381    /// state machine. Returns the decoded message for inspection.
382    pub async fn recv_and_dispatch(&mut self) -> Result<ControlMessage, ConnectionError> {
383        let msg = self.recv_control().await?;
384        self.endpoint.receive_message(msg.clone())?;
385        Ok(msg)
386    }
387
388    // ── Subscribe flow ──────────────────────────────────────
389
390    /// Send a SUBSCRIBE and return the allocated request ID.
391    pub async fn subscribe(
392        &mut self,
393        track_namespace: TrackNamespace,
394        track_name: Vec<u8>,
395        subscriber_priority: u8,
396        group_order: GroupOrder,
397        filter_type: FilterType,
398    ) -> Result<VarInt, ConnectionError> {
399        let (req_id, msg) = self.endpoint.subscribe(
400            track_namespace,
401            track_name,
402            subscriber_priority,
403            group_order,
404            filter_type,
405        )?;
406        self.send_control(&msg).await?;
407        Ok(req_id)
408    }
409
410    /// Send an UNSUBSCRIBE for the given request ID.
411    pub async fn unsubscribe(&mut self, request_id: VarInt) -> Result<(), ConnectionError> {
412        let msg = self.endpoint.unsubscribe(request_id)?;
413        self.send_control(&msg).await
414    }
415
416    // ── Fetch flow ──────────────────────────────────────────
417
418    /// Send a FETCH and return the allocated request ID.
419    pub async fn fetch(
420        &mut self,
421        track_namespace: TrackNamespace,
422        track_name: Vec<u8>,
423        start_group: VarInt,
424        start_object: VarInt,
425    ) -> Result<VarInt, ConnectionError> {
426        let (req_id, msg) =
427            self.endpoint.fetch(track_namespace, track_name, start_group, start_object)?;
428        self.send_control(&msg).await?;
429        Ok(req_id)
430    }
431
432    /// Send a FETCH_CANCEL for the given request ID.
433    pub async fn fetch_cancel(&mut self, request_id: VarInt) -> Result<(), ConnectionError> {
434        let msg = self.endpoint.fetch_cancel(request_id)?;
435        self.send_control(&msg).await
436    }
437
438    // ── Namespace flows ─────────────────────────────────────
439
440    /// Send a SUBSCRIBE_NAMESPACE and return the request ID.
441    pub async fn subscribe_namespace(
442        &mut self,
443        track_namespace: TrackNamespace,
444    ) -> Result<VarInt, ConnectionError> {
445        let (req_id, msg) = self.endpoint.subscribe_namespace(track_namespace)?;
446        self.send_control(&msg).await?;
447        Ok(req_id)
448    }
449
450    /// Send a PUBLISH_NAMESPACE and return the request ID.
451    pub async fn publish_namespace(
452        &mut self,
453        track_namespace: TrackNamespace,
454    ) -> Result<VarInt, ConnectionError> {
455        let (req_id, msg) = self.endpoint.publish_namespace(track_namespace)?;
456        self.send_control(&msg).await?;
457        Ok(req_id)
458    }
459
460    // ── Data streams ────────────────────────────────────────
461
462    /// Open a new unidirectional stream for sending subgroup data.
463    pub async fn open_subgroup_stream(
464        &self,
465        header: &SubgroupHeader,
466    ) -> Result<FramedSendStream, ConnectionError> {
467        let send = self.quic.open_uni().await?;
468        let mut framed = FramedSendStream::new(send);
469        framed.write_subgroup_header(header).await?;
470        Ok(framed)
471    }
472
473    /// Accept an incoming unidirectional data stream and read its subgroup header.
474    pub async fn accept_subgroup_stream(
475        &self,
476    ) -> Result<(SubgroupHeader, FramedRecvStream), ConnectionError> {
477        let recv = self.quic.accept_uni().await?;
478        let mut framed = FramedRecvStream::new(recv);
479        let header = framed.read_subgroup_header().await?;
480        Ok((header, framed))
481    }
482
483    /// Send an object via datagram.
484    pub fn send_datagram(
485        &self,
486        header: &DatagramHeader,
487        payload: &[u8],
488    ) -> Result<(), ConnectionError> {
489        let mut buf = Vec::new();
490        header.encode(&mut buf);
491        buf.extend_from_slice(payload);
492        self.quic.send_datagram(bytes::Bytes::from(buf))?;
493        Ok(())
494    }
495
496    /// Receive a datagram and decode its header.
497    pub async fn recv_datagram(&self) -> Result<(DatagramHeader, Vec<u8>), ConnectionError> {
498        let data = self.quic.read_datagram().await?;
499        let mut cursor = &data[..];
500        let header = DatagramHeader::decode(&mut cursor)?;
501        let payload = cursor.to_vec();
502        Ok((header, payload))
503    }
504
505    // ── Accessors ───────────────────────────────────────────
506
507    /// Access the underlying endpoint state machine.
508    pub fn endpoint(&self) -> &Endpoint {
509        &self.endpoint
510    }
511
512    /// Mutable access to the endpoint state machine.
513    pub fn endpoint_mut(&mut self) -> &mut Endpoint {
514        &mut self.endpoint
515    }
516
517    /// Get the negotiated MoQT version.
518    pub fn negotiated_version(&self) -> Option<VarInt> {
519        self.endpoint.negotiated_version()
520    }
521
522    /// Close the connection.
523    pub fn close(&self, code: u32, reason: &[u8]) {
524        self.quic.close(quinn::VarInt::from_u32(code), reason);
525    }
526}
527
/// Return the total encoded size, in bytes, of a QUIC-style varint given its
/// first byte.
///
/// The two most significant bits select the size: `00` → 1, `01` → 2,
/// `10` → 4, `11` → 8.
fn varint_len(first_byte: u8) -> usize {
    match first_byte >> 6 {
        0b00 => 1,
        0b01 => 2,
        0b10 => 4,
        _ => 8,
    }
}
532
/// TLS certificate verifier that skips all verification (for testing only).
///
/// Installed by `Connection::connect` when `ClientConfig::skip_cert_verification`
/// is `true`. It accepts every certificate and every handshake signature, so it
/// provides no protection against an impersonating server.
#[derive(Debug)]
struct SkipVerification;
536
537impl rustls::client::danger::ServerCertVerifier for SkipVerification {
538    fn verify_server_cert(
539        &self,
540        _end_entity: &rustls::pki_types::CertificateDer<'_>,
541        _intermediates: &[rustls::pki_types::CertificateDer<'_>],
542        _server_name: &rustls::pki_types::ServerName<'_>,
543        _ocsp_response: &[u8],
544        _now: rustls::pki_types::UnixTime,
545    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
546        Ok(rustls::client::danger::ServerCertVerified::assertion())
547    }
548
549    fn verify_tls12_signature(
550        &self,
551        _message: &[u8],
552        _cert: &rustls::pki_types::CertificateDer<'_>,
553        _dcs: &rustls::DigitallySignedStruct,
554    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
555        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
556    }
557
558    fn verify_tls13_signature(
559        &self,
560        _message: &[u8],
561        _cert: &rustls::pki_types::CertificateDer<'_>,
562        _dcs: &rustls::DigitallySignedStruct,
563    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
564        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
565    }
566
567    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
568        vec![
569            rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
570            rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
571            rustls::SignatureScheme::ED25519,
572            rustls::SignatureScheme::RSA_PSS_SHA256,
573            rustls::SignatureScheme::RSA_PSS_SHA384,
574            rustls::SignatureScheme::RSA_PSS_SHA512,
575        ]
576    }
577}
578
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that every first byte in `first_bytes` yields `expected` length.
    fn assert_len(first_bytes: &[u8], expected: usize) {
        for &byte in first_bytes {
            assert_eq!(varint_len(byte), expected, "first byte {byte:#04x}");
        }
    }

    #[test]
    fn varint_len_single_byte() {
        // 0b00xxxxxx -> 1 byte
        assert_len(&[0x00, 0x3F], 1);
    }

    #[test]
    fn varint_len_two_bytes() {
        // 0b01xxxxxx -> 2 bytes
        assert_len(&[0x40, 0x7F], 2);
    }

    #[test]
    fn varint_len_four_bytes() {
        // 0b10xxxxxx -> 4 bytes
        assert_len(&[0x80, 0xBF], 4);
    }

    #[test]
    fn varint_len_eight_bytes() {
        // 0b11xxxxxx -> 8 bytes
        assert_len(&[0xC0, 0xFF], 8);
    }

    #[test]
    fn default_client_config() {
        let ClientConfig { supported_versions, skip_cert_verification } = ClientConfig::default();
        assert_eq!(supported_versions.len(), 1);
        assert!(!skip_cert_verification);
    }

    #[test]
    fn moqt_alpn_value() {
        assert_eq!(MOQT_ALPN, &b"moq-00"[..]);
    }
}
618    #[test]
619    fn moqt_alpn_value() {
620        assert_eq!(MOQT_ALPN, b"moq-00");
621    }
622}