1use std::sync::Arc;
2
3use bytes::{Buf, BytesMut};
4
5use moqtap_codec::data_stream::{DatagramHeader, FetchHeader, ObjectHeader, SubgroupHeader};
6use moqtap_codec::message::{CodecError, ControlMessage};
7use moqtap_codec::types::*;
8use moqtap_codec::varint::VarInt;
9use moqtap_session::request_id::Role;
10
11use crate::endpoint::{Endpoint, EndpointError};
12
/// ALPN protocol identifier advertised in the TLS handshake when connecting over raw QUIC.
pub const MOQT_ALPN: &[u8] = b"moq-00";
15
16#[derive(Debug, thiserror::Error)]
18pub enum ConnectionError {
19 #[error("endpoint error: {0}")]
20 Endpoint(#[from] EndpointError),
21 #[error("codec error: {0}")]
22 Codec(#[from] CodecError),
23 #[error("quinn connection error: {0}")]
24 Connection(#[from] quinn::ConnectionError),
25 #[error("quinn write error: {0}")]
26 Write(#[from] quinn::WriteError),
27 #[error("quinn read error: {0}")]
28 Read(#[from] quinn::ReadExactError),
29 #[error("quinn connect error: {0}")]
30 Connect(#[from] quinn::ConnectError),
31 #[error("closed stream: {0}")]
32 ClosedStream(#[from] quinn::ClosedStream),
33 #[error("send datagram error: {0}")]
34 SendDatagram(#[from] quinn::SendDatagramError),
35 #[error("varint error: {0}")]
36 VarInt(#[from] moqtap_codec::varint::VarIntError),
37 #[error("control stream not open")]
38 NoControlStream,
39 #[error("unexpected end of stream")]
40 UnexpectedEnd,
41 #[error("stream finished")]
42 StreamFinished,
43 #[error("invalid server address: {0}")]
44 InvalidAddress(String),
45 #[error("TLS config error: {0}")]
46 TlsConfig(String),
47}
48
49pub struct ClientConfig {
51 pub supported_versions: Vec<VarInt>,
53 pub skip_cert_verification: bool,
55}
56
57impl Default for ClientConfig {
58 fn default() -> Self {
59 Self {
60 supported_versions: vec![VarInt::from_u64(0xff000000 + 14).unwrap()], skip_cert_verification: false,
62 }
63 }
64}
65
66pub struct FramedSendStream {
68 inner: quinn::SendStream,
69}
70
71impl FramedSendStream {
72 pub fn new(inner: quinn::SendStream) -> Self {
73 Self { inner }
74 }
75
76 pub async fn write_control(&mut self, msg: &ControlMessage) -> Result<(), ConnectionError> {
78 let mut buf = Vec::new();
79 msg.encode(&mut buf)?;
80 self.inner.write_all(&buf).await?;
81 Ok(())
82 }
83
84 pub async fn write_subgroup_header(
86 &mut self,
87 header: &SubgroupHeader,
88 ) -> Result<(), ConnectionError> {
89 let mut buf = Vec::new();
90 header.encode(&mut buf);
91 self.inner.write_all(&buf).await?;
92 Ok(())
93 }
94
95 pub async fn write_fetch_header(
97 &mut self,
98 header: &FetchHeader,
99 ) -> Result<(), ConnectionError> {
100 let mut buf = Vec::new();
101 header.encode(&mut buf);
102 self.inner.write_all(&buf).await?;
103 Ok(())
104 }
105
106 pub async fn write_object(
108 &mut self,
109 header: &ObjectHeader,
110 payload: &[u8],
111 ) -> Result<(), ConnectionError> {
112 let mut buf = Vec::new();
113 header.encode(&mut buf);
114 self.inner.write_all(&buf).await?;
115 if !payload.is_empty() {
116 self.inner.write_all(payload).await?;
117 }
118 Ok(())
119 }
120
121 pub async fn finish(&mut self) -> Result<(), ConnectionError> {
123 self.inner.finish()?;
124 Ok(())
125 }
126}
127
128pub struct FramedRecvStream {
130 inner: quinn::RecvStream,
131 buf: BytesMut,
132}
133
134impl FramedRecvStream {
135 pub fn new(inner: quinn::RecvStream) -> Self {
136 Self { inner, buf: BytesMut::with_capacity(4096) }
137 }
138
139 async fn fill(&mut self) -> Result<bool, ConnectionError> {
141 let mut tmp = [0u8; 4096];
142 match self.inner.read(&mut tmp).await {
143 Ok(Some(n)) => {
144 self.buf.extend_from_slice(&tmp[..n]);
145 Ok(true)
146 }
147 Ok(None) => Ok(false),
148 Err(e) => Err(ConnectionError::Read(quinn::ReadExactError::ReadError(e))),
149 }
150 }
151
152 async fn ensure(&mut self, n: usize) -> Result<(), ConnectionError> {
154 while self.buf.len() < n {
155 if !self.fill().await? {
156 return Err(ConnectionError::UnexpectedEnd);
157 }
158 }
159 Ok(())
160 }
161
162 pub async fn read_control(&mut self) -> Result<ControlMessage, ConnectionError> {
164 self.ensure(1).await?;
166 let type_len = varint_len(self.buf[0]);
167 self.ensure(type_len).await?;
168
169 let mut cursor = &self.buf[..type_len];
170 let _type_id = VarInt::decode(&mut cursor)?;
171
172 self.ensure(type_len + 1).await?;
174 let payload_len_start = type_len;
175 let payload_len_varint_len = varint_len(self.buf[payload_len_start]);
176 self.ensure(type_len + payload_len_varint_len).await?;
177
178 let mut cursor = &self.buf[payload_len_start..type_len + payload_len_varint_len];
179 let payload_len = VarInt::decode(&mut cursor)?.into_inner() as usize;
180
181 let total = type_len + payload_len_varint_len + payload_len;
183 self.ensure(total).await?;
184
185 let mut frame = &self.buf[..total];
187 let msg = ControlMessage::decode(&mut frame)?;
188 self.buf.advance(total);
189 Ok(msg)
190 }
191
192 pub async fn read_subgroup_header(&mut self) -> Result<SubgroupHeader, ConnectionError> {
194 self.ensure(4).await?;
197 loop {
198 let mut cursor = &self.buf[..];
199 match SubgroupHeader::decode(&mut cursor) {
200 Ok(header) => {
201 let consumed = self.buf.len() - cursor.remaining();
202 self.buf.advance(consumed);
203 return Ok(header);
204 }
205 Err(_) => {
206 if !self.fill().await? {
207 return Err(ConnectionError::UnexpectedEnd);
208 }
209 }
210 }
211 }
212 }
213
214 pub async fn read_fetch_header(&mut self) -> Result<FetchHeader, ConnectionError> {
216 self.ensure(4).await?;
217 loop {
218 let mut cursor = &self.buf[..];
219 match FetchHeader::decode(&mut cursor) {
220 Ok(header) => {
221 let consumed = self.buf.len() - cursor.remaining();
222 self.buf.advance(consumed);
223 return Ok(header);
224 }
225 Err(_) => {
226 if !self.fill().await? {
227 return Err(ConnectionError::UnexpectedEnd);
228 }
229 }
230 }
231 }
232 }
233
234 pub async fn read_object_header(&mut self) -> Result<ObjectHeader, ConnectionError> {
236 self.ensure(2).await?;
237 loop {
238 let mut cursor = &self.buf[..];
239 match ObjectHeader::decode(&mut cursor) {
240 Ok(header) => {
241 let consumed = self.buf.len() - cursor.remaining();
242 self.buf.advance(consumed);
243 return Ok(header);
244 }
245 Err(_) => {
246 if !self.fill().await? {
247 return Err(ConnectionError::UnexpectedEnd);
248 }
249 }
250 }
251 }
252 }
253
254 pub async fn read_payload(&mut self, n: usize) -> Result<Vec<u8>, ConnectionError> {
256 self.ensure(n).await?;
257 let data = self.buf[..n].to_vec();
258 self.buf.advance(n);
259 Ok(data)
260 }
261}
262
263pub struct Connection {
266 quic: quinn::Connection,
267 endpoint: Endpoint,
268 control_send: Option<FramedSendStream>,
269 control_recv: Option<FramedRecvStream>,
270}
271
272impl Connection {
273 pub async fn connect(addr: &str, config: ClientConfig) -> Result<Self, ConnectionError> {
279 let server_addr = addr.parse().map_err(|e: std::net::AddrParseError| {
280 ConnectionError::InvalidAddress(e.to_string())
281 })?;
282
283 let mut endpoint = Endpoint::new(Role::Client);
284
285 let mut tls_config = if config.skip_cert_verification {
287 rustls::ClientConfig::builder()
288 .dangerous()
289 .with_custom_certificate_verifier(Arc::new(SkipVerification))
290 .with_no_client_auth()
291 } else {
292 let mut roots = rustls::RootCertStore::empty();
293 roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
294 rustls::ClientConfig::builder().with_root_certificates(roots).with_no_client_auth()
295 };
296
297 tls_config.alpn_protocols = vec![MOQT_ALPN.to_vec()];
298
299 let quic_config: quinn::crypto::rustls::QuicClientConfig =
300 tls_config.try_into().map_err(|e| ConnectionError::TlsConfig(format!("{e}")))?;
301 let client_config = quinn::ClientConfig::new(Arc::new(quic_config));
302
303 let mut quinn_endpoint = quinn::Endpoint::client("0.0.0.0:0".parse().unwrap())
304 .map_err(|e| ConnectionError::InvalidAddress(e.to_string()))?;
305 quinn_endpoint.set_default_client_config(client_config);
306
307 let server_name = addr.split(':').next().unwrap_or("localhost").to_string();
308
309 let quic = quinn_endpoint.connect(server_addr, &server_name)?.await?;
310
311 let (send, recv) = quic.open_bi().await?;
313 let mut control_send = FramedSendStream::new(send);
314 let mut control_recv = FramedRecvStream::new(recv);
315
316 endpoint.connect()?;
318 let setup_msg = endpoint.send_client_setup(config.supported_versions)?;
319 control_send.write_control(&setup_msg).await?;
320
321 let server_setup = control_recv.read_control().await?;
322 if let ControlMessage::ServerSetup(ref ss) = server_setup {
323 endpoint.receive_server_setup(ss)?;
324 } else {
325 return Err(ConnectionError::Endpoint(EndpointError::NotActive));
326 }
327
328 Ok(Self {
329 quic,
330 endpoint,
331 control_send: Some(control_send),
332 control_recv: Some(control_recv),
333 })
334 }
335
336 pub async fn accept(
338 quic: quinn::Connection,
339 selected_version: VarInt,
340 ) -> Result<Self, ConnectionError> {
341 let mut endpoint = Endpoint::new(Role::Server);
342 endpoint.connect()?;
343
344 let (send, recv) = quic.accept_bi().await?;
346 let mut control_send = FramedSendStream::new(send);
347 let mut control_recv = FramedRecvStream::new(recv);
348
349 let client_setup_msg = control_recv.read_control().await?;
351 if let ControlMessage::ClientSetup(ref cs) = client_setup_msg {
352 let server_setup = endpoint.receive_client_setup_and_respond(cs, selected_version)?;
353 control_send.write_control(&server_setup).await?;
354 } else {
355 return Err(ConnectionError::Endpoint(EndpointError::NotActive));
356 }
357
358 Ok(Self {
359 quic,
360 endpoint,
361 control_send: Some(control_send),
362 control_recv: Some(control_recv),
363 })
364 }
365
366 pub async fn send_control(&mut self, msg: &ControlMessage) -> Result<(), ConnectionError> {
370 let send = self.control_send.as_mut().ok_or(ConnectionError::NoControlStream)?;
371 send.write_control(msg).await
372 }
373
374 pub async fn recv_control(&mut self) -> Result<ControlMessage, ConnectionError> {
376 let recv = self.control_recv.as_mut().ok_or(ConnectionError::NoControlStream)?;
377 recv.read_control().await
378 }
379
380 pub async fn recv_and_dispatch(&mut self) -> Result<ControlMessage, ConnectionError> {
383 let msg = self.recv_control().await?;
384 self.endpoint.receive_message(msg.clone())?;
385 Ok(msg)
386 }
387
388 pub async fn subscribe(
392 &mut self,
393 track_namespace: TrackNamespace,
394 track_name: Vec<u8>,
395 subscriber_priority: u8,
396 group_order: GroupOrder,
397 filter_type: FilterType,
398 ) -> Result<VarInt, ConnectionError> {
399 let (req_id, msg) = self.endpoint.subscribe(
400 track_namespace,
401 track_name,
402 subscriber_priority,
403 group_order,
404 filter_type,
405 )?;
406 self.send_control(&msg).await?;
407 Ok(req_id)
408 }
409
410 pub async fn unsubscribe(&mut self, request_id: VarInt) -> Result<(), ConnectionError> {
412 let msg = self.endpoint.unsubscribe(request_id)?;
413 self.send_control(&msg).await
414 }
415
416 pub async fn fetch(
420 &mut self,
421 track_namespace: TrackNamespace,
422 track_name: Vec<u8>,
423 start_group: VarInt,
424 start_object: VarInt,
425 ) -> Result<VarInt, ConnectionError> {
426 let (req_id, msg) =
427 self.endpoint.fetch(track_namespace, track_name, start_group, start_object)?;
428 self.send_control(&msg).await?;
429 Ok(req_id)
430 }
431
432 pub async fn fetch_cancel(&mut self, request_id: VarInt) -> Result<(), ConnectionError> {
434 let msg = self.endpoint.fetch_cancel(request_id)?;
435 self.send_control(&msg).await
436 }
437
438 pub async fn subscribe_namespace(
442 &mut self,
443 track_namespace: TrackNamespace,
444 ) -> Result<VarInt, ConnectionError> {
445 let (req_id, msg) = self.endpoint.subscribe_namespace(track_namespace)?;
446 self.send_control(&msg).await?;
447 Ok(req_id)
448 }
449
450 pub async fn publish_namespace(
452 &mut self,
453 track_namespace: TrackNamespace,
454 ) -> Result<VarInt, ConnectionError> {
455 let (req_id, msg) = self.endpoint.publish_namespace(track_namespace)?;
456 self.send_control(&msg).await?;
457 Ok(req_id)
458 }
459
460 pub async fn open_subgroup_stream(
464 &self,
465 header: &SubgroupHeader,
466 ) -> Result<FramedSendStream, ConnectionError> {
467 let send = self.quic.open_uni().await?;
468 let mut framed = FramedSendStream::new(send);
469 framed.write_subgroup_header(header).await?;
470 Ok(framed)
471 }
472
473 pub async fn accept_subgroup_stream(
475 &self,
476 ) -> Result<(SubgroupHeader, FramedRecvStream), ConnectionError> {
477 let recv = self.quic.accept_uni().await?;
478 let mut framed = FramedRecvStream::new(recv);
479 let header = framed.read_subgroup_header().await?;
480 Ok((header, framed))
481 }
482
483 pub fn send_datagram(
485 &self,
486 header: &DatagramHeader,
487 payload: &[u8],
488 ) -> Result<(), ConnectionError> {
489 let mut buf = Vec::new();
490 header.encode(&mut buf);
491 buf.extend_from_slice(payload);
492 self.quic.send_datagram(bytes::Bytes::from(buf))?;
493 Ok(())
494 }
495
496 pub async fn recv_datagram(&self) -> Result<(DatagramHeader, Vec<u8>), ConnectionError> {
498 let data = self.quic.read_datagram().await?;
499 let mut cursor = &data[..];
500 let header = DatagramHeader::decode(&mut cursor)?;
501 let payload = cursor.to_vec();
502 Ok((header, payload))
503 }
504
505 pub fn endpoint(&self) -> &Endpoint {
509 &self.endpoint
510 }
511
512 pub fn endpoint_mut(&mut self) -> &mut Endpoint {
514 &mut self.endpoint
515 }
516
517 pub fn negotiated_version(&self) -> Option<VarInt> {
519 self.endpoint.negotiated_version()
520 }
521
522 pub fn close(&self, code: u32, reason: &[u8]) {
524 self.quic.close(quinn::VarInt::from_u32(code), reason);
525 }
526}
527
/// Returns the total encoded length in bytes (1, 2, 4 or 8) of a QUIC variable-length
/// integer, as selected by the two most significant bits of its first byte.
fn varint_len(first_byte: u8) -> usize {
    match first_byte >> 6 {
        0b00 => 1,
        0b01 => 2,
        0b10 => 4,
        _ => 8,
    }
}
532
533#[derive(Debug)]
535struct SkipVerification;
536
537impl rustls::client::danger::ServerCertVerifier for SkipVerification {
538 fn verify_server_cert(
539 &self,
540 _end_entity: &rustls::pki_types::CertificateDer<'_>,
541 _intermediates: &[rustls::pki_types::CertificateDer<'_>],
542 _server_name: &rustls::pki_types::ServerName<'_>,
543 _ocsp_response: &[u8],
544 _now: rustls::pki_types::UnixTime,
545 ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
546 Ok(rustls::client::danger::ServerCertVerified::assertion())
547 }
548
549 fn verify_tls12_signature(
550 &self,
551 _message: &[u8],
552 _cert: &rustls::pki_types::CertificateDer<'_>,
553 _dcs: &rustls::DigitallySignedStruct,
554 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
555 Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
556 }
557
558 fn verify_tls13_signature(
559 &self,
560 _message: &[u8],
561 _cert: &rustls::pki_types::CertificateDer<'_>,
562 _dcs: &rustls::DigitallySignedStruct,
563 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
564 Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
565 }
566
567 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
568 vec![
569 rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
570 rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
571 rustls::SignatureScheme::ED25519,
572 rustls::SignatureScheme::RSA_PSS_SHA256,
573 rustls::SignatureScheme::RSA_PSS_SHA384,
574 rustls::SignatureScheme::RSA_PSS_SHA512,
575 ]
576 }
577}
578
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that every byte in `first_bytes` maps to a varint of `expected` bytes.
    fn check_varint_len(first_bytes: &[u8], expected: usize) {
        for &byte in first_bytes {
            assert_eq!(varint_len(byte), expected, "first byte {byte:#04x}");
        }
    }

    #[test]
    fn varint_len_single_byte() {
        // Prefix 0b00.
        check_varint_len(&[0x00, 0x3F], 1);
    }

    #[test]
    fn varint_len_two_bytes() {
        // Prefix 0b01.
        check_varint_len(&[0x40, 0x7F], 2);
    }

    #[test]
    fn varint_len_four_bytes() {
        // Prefix 0b10.
        check_varint_len(&[0x80, 0xBF], 4);
    }

    #[test]
    fn varint_len_eight_bytes() {
        // Prefix 0b11.
        check_varint_len(&[0xC0, 0xFF], 8);
    }

    #[test]
    fn default_client_config() {
        let defaults = ClientConfig::default();
        assert_eq!(defaults.supported_versions.len(), 1);
        assert!(!defaults.skip_cert_verification);
    }

    #[test]
    fn moqt_alpn_value() {
        assert_eq!(MOQT_ALPN, b"moq-00".as_slice());
    }
}