1use std::sync::Arc;
2
3use bytes::{Buf, Bytes, BytesMut};
4
5use crate::draft16::endpoint::{Endpoint, EndpointError};
6use crate::draft16::event::{ClientEvent, Direction, StreamKind};
7use crate::draft16::observer::ConnectionObserver;
8use crate::draft16::session::request_id::Role;
9use crate::transport::quic::QuicTransport;
10use crate::transport::{RecvStream, SendStream, Transport, TransportError};
11use moqtap_codec::dispatch::{
12 AnyControlMessage, AnyDatagramHeader, AnyFetchHeader, AnySubgroupHeader,
13};
14use moqtap_codec::draft16::data_stream::{FetchHeader, SubgroupObject, SubgroupObjectReader};
15use moqtap_codec::draft16::message::ControlMessage;
16use moqtap_codec::error::CodecError;
17use moqtap_codec::kvp::KeyValuePair;
18use moqtap_codec::types::*;
19use moqtap_codec::varint::VarInt;
20use moqtap_codec::version::DraftVersion;
21
/// ALPN identifier string `moq-00`.
///
/// [`ClientConfig::alpn`] does not use this constant: raw QUIC connections advertise
/// the draft-specific identifier returned by `DraftVersion::quic_alpn` (e.g. `moqt-16`
/// for draft 16), and WebTransport connections advertise `h3`.
///
/// NOTE(review): nothing else in this file reads this constant; confirm whether
/// external callers still depend on it.
pub const MOQT_ALPN: &[u8] = b"moq-00";
24
/// Errors produced while establishing or using a [`Connection`] or its framed streams.
#[derive(Debug, thiserror::Error)]
pub enum ConnectionError {
    /// The session state machine ([`Endpoint`]) rejected an operation or message.
    #[error("endpoint error: {0}")]
    Endpoint(#[from] EndpointError),
    /// Encoding or decoding a control message, data-stream header, or object failed.
    #[error("codec error: {0}")]
    Codec(#[from] CodecError),
    /// The underlying QUIC / WebTransport layer reported an error.
    #[error("transport error: {0}")]
    Transport(#[from] TransportError),
    /// A variable-length integer could not be processed.
    #[error("varint error: {0}")]
    VarInt(#[from] moqtap_codec::varint::VarIntError),
    /// The control-stream half required by the operation is not available.
    #[error("control stream not open")]
    NoControlStream,
    /// The peer finished the stream before a complete message, header, or object arrived.
    #[error("unexpected end of stream")]
    UnexpectedEnd,
    /// The stream has been finished.
    ///
    /// NOTE(review): not constructed anywhere in this file; confirm other modules use it.
    #[error("stream finished")]
    StreamFinished,
    /// The server address could not be parsed, or the local QUIC endpoint could not be bound.
    #[error("invalid server address: {0}")]
    InvalidAddress(String),
    /// The TLS configuration could not be built: a CA certificate was rejected, or the
    /// rustls config could not be converted into a QUIC client config.
    #[error("TLS config error: {0}")]
    TlsConfig(String),
    /// A data-stream operation was attempted in the wrong state, e.g. reading or writing
    /// a subgroup object before the subgroup header.
    #[error("data stream state error: {0}")]
    DataStreamState(&'static str),
}
59
/// Transport used to carry the MoQT session.
#[derive(Debug, Clone)]
pub enum TransportType {
    /// Raw QUIC to a literal `ip:port` address, advertising the draft-specific ALPN.
    Quic,
    /// WebTransport (ALPN `h3`). Requires the `webtransport` cargo feature; without it,
    /// connecting fails with a transport error.
    WebTransport {
        /// URL of the WebTransport endpoint to connect to.
        url: String,
    },
}
71
/// Options for [`Connection::connect`].
pub struct ClientConfig {
    /// MoQT draft version; selects control-message framing and the QUIC ALPN identifier.
    pub draft: DraftVersion,
    /// Transport used to reach the server.
    pub transport: TransportType,
    /// Disables TLS server-certificate verification entirely. Insecure; only suitable
    /// for testing (e.g. against self-signed servers).
    pub skip_cert_verification: bool,
    /// Additional DER-encoded CA certificates trusted alongside the bundled webpki roots.
    /// Only used by raw QUIC, and ignored when `skip_cert_verification` is set.
    pub ca_certs: Vec<Vec<u8>>,
    /// Parameters sent in the CLIENT_SETUP message.
    pub setup_parameters: Vec<KeyValuePair>,
}
87
88impl ClientConfig {
89 pub fn alpn(&self) -> Vec<Vec<u8>> {
91 match &self.transport {
92 TransportType::Quic => vec![self.draft.quic_alpn().to_vec()],
93 TransportType::WebTransport { .. } => vec![b"h3".to_vec()],
94 }
95 }
96}
97
/// Send half of a stream that writes encoded MoQT control messages, data-stream
/// headers, and subgroup objects.
pub struct FramedSendStream {
    /// Underlying transport send stream.
    inner: SendStream,
    /// Draft version this stream was created for.
    draft: DraftVersion,
    /// Object encoder state; set once a draft-16 subgroup header has been written.
    subgroup_io: Option<SubgroupObjectReader>,
}
105
106impl FramedSendStream {
107 pub fn new(inner: SendStream, draft: DraftVersion) -> Self {
109 Self { inner, draft, subgroup_io: None }
110 }
111
112 pub fn stream_id(&self) -> u64 {
114 self.inner.stream_id()
115 }
116
117 pub async fn write_control(
120 &mut self,
121 msg: &AnyControlMessage,
122 ) -> Result<Vec<u8>, ConnectionError> {
123 let mut buf = Vec::new();
124 msg.encode(&mut buf)?;
125 self.inner.write_all(&buf).await?;
126 Ok(buf)
127 }
128
129 pub async fn write_subgroup_header(
133 &mut self,
134 header: &AnySubgroupHeader,
135 ) -> Result<(), ConnectionError> {
136 let mut buf = Vec::new();
137 header.encode(&mut buf);
138 self.inner.write_all(&buf).await?;
139 if let AnySubgroupHeader::Draft16(ref d16) = header {
140 self.subgroup_io = Some(SubgroupObjectReader::new(d16));
141 }
142 Ok(())
143 }
144
145 pub async fn write_fetch_header(
147 &mut self,
148 header: &AnyFetchHeader,
149 ) -> Result<(), ConnectionError> {
150 let mut buf = Vec::new();
151 header.encode(&mut buf);
152 self.inner.write_all(&buf).await?;
153 Ok(())
154 }
155
156 pub async fn write_subgroup_object(
160 &mut self,
161 object: &SubgroupObject,
162 ) -> Result<(), ConnectionError> {
163 let writer = self
164 .subgroup_io
165 .as_mut()
166 .ok_or(ConnectionError::DataStreamState("subgroup header not written yet"))?;
167 let mut buf = Vec::new();
168 writer.write_object(object, &mut buf)?;
169 self.inner.write_all(&buf).await?;
170 Ok(())
171 }
172
173 pub async fn finish(&mut self) -> Result<(), ConnectionError> {
175 self.inner.finish()?;
176 Ok(())
177 }
178
179 pub fn draft(&self) -> DraftVersion {
181 self.draft
182 }
183}
184
/// Receive half of a stream that buffers incoming bytes and decodes MoQT control
/// messages, data-stream headers, and subgroup objects from them.
pub struct FramedRecvStream {
    /// Underlying transport receive stream.
    inner: RecvStream,
    /// Bytes read from `inner` that no decoder has consumed yet.
    buf: BytesMut,
    /// Draft version used to decode everything on this stream.
    draft: DraftVersion,
    /// Object decoder state; set once a draft-16 subgroup header has been read.
    subgroup_io: Option<SubgroupObjectReader>,
}
193
194impl FramedRecvStream {
195 pub fn new(inner: RecvStream, draft: DraftVersion) -> Self {
197 Self { inner, buf: BytesMut::with_capacity(4096), draft, subgroup_io: None }
198 }
199
200 pub fn stream_id(&self) -> u64 {
202 self.inner.stream_id()
203 }
204
205 async fn fill(&mut self) -> Result<bool, ConnectionError> {
207 let mut tmp = [0u8; 4096];
208 match self.inner.read(&mut tmp).await {
209 Ok(Some(n)) => {
210 self.buf.extend_from_slice(&tmp[..n]);
211 Ok(true)
212 }
213 Ok(None) => Ok(false),
214 Err(e) => Err(ConnectionError::Transport(e)),
215 }
216 }
217
218 async fn ensure(&mut self, n: usize) -> Result<(), ConnectionError> {
220 while self.buf.len() < n {
221 if !self.fill().await? {
222 return Err(ConnectionError::UnexpectedEnd);
223 }
224 }
225 Ok(())
226 }
227
228 pub async fn read_control(
234 &mut self,
235 capture_raw: bool,
236 ) -> Result<(AnyControlMessage, Option<Vec<u8>>), ConnectionError> {
237 self.ensure(1).await?;
239 let type_len = varint_len(self.buf[0]);
240 self.ensure(type_len).await?;
241
242 let mut cursor = &self.buf[..type_len];
243 let _type_id = VarInt::decode(&mut cursor)?;
244
245 let (payload_len, len_field_size) = if self.draft.uses_fixed_length_framing() {
247 self.ensure(type_len + 2).await?;
248 let hi = self.buf[type_len] as usize;
249 let lo = self.buf[type_len + 1] as usize;
250 ((hi << 8) | lo, 2)
251 } else {
252 self.ensure(type_len + 1).await?;
253 let payload_len_start = type_len;
254 let payload_len_varint_len = varint_len(self.buf[payload_len_start]);
255 self.ensure(type_len + payload_len_varint_len).await?;
256 let mut cursor = &self.buf[payload_len_start..type_len + payload_len_varint_len];
257 let payload_len = VarInt::decode(&mut cursor)?.into_inner() as usize;
258 (payload_len, payload_len_varint_len)
259 };
260
261 let total = type_len + len_field_size + payload_len;
263 self.ensure(total).await?;
264
265 let raw = capture_raw.then(|| self.buf[..total].to_vec());
267
268 let mut frame = &self.buf[..total];
270 let msg = AnyControlMessage::decode(self.draft, &mut frame)?;
271 self.buf.advance(total);
272 Ok((msg, raw))
273 }
274
275 pub async fn read_subgroup_header(&mut self) -> Result<AnySubgroupHeader, ConnectionError> {
278 self.ensure(1).await?;
279 loop {
280 let mut cursor = &self.buf[..];
281 match AnySubgroupHeader::decode(self.draft, &mut cursor) {
282 Ok(header) => {
283 let consumed = self.buf.len() - cursor.remaining();
284 self.buf.advance(consumed);
285 if let AnySubgroupHeader::Draft16(ref d16) = header {
286 self.subgroup_io = Some(SubgroupObjectReader::new(d16));
287 }
288 return Ok(header);
289 }
290 Err(CodecError::UnexpectedEnd) => {
291 if !self.fill().await? {
292 return Err(ConnectionError::UnexpectedEnd);
293 }
294 }
295 Err(e) => return Err(ConnectionError::Codec(e)),
296 }
297 }
298 }
299
300 pub async fn read_fetch_header(&mut self) -> Result<AnyFetchHeader, ConnectionError> {
302 self.ensure(1).await?;
303 loop {
304 let mut cursor = &self.buf[..];
305 match AnyFetchHeader::decode(self.draft, &mut cursor) {
306 Ok(header) => {
307 let consumed = self.buf.len() - cursor.remaining();
308 self.buf.advance(consumed);
309 return Ok(header);
310 }
311 Err(CodecError::UnexpectedEnd) => {
312 if !self.fill().await? {
313 return Err(ConnectionError::UnexpectedEnd);
314 }
315 }
316 Err(e) => return Err(ConnectionError::Codec(e)),
317 }
318 }
319 }
320
321 pub async fn read_subgroup_object(&mut self) -> Result<SubgroupObject, ConnectionError> {
325 if self.subgroup_io.is_none() {
326 return Err(ConnectionError::DataStreamState("subgroup header not read yet"));
327 }
328 loop {
329 let reader = self.subgroup_io.as_mut().unwrap();
330 let mut probe = reader.clone();
331 let mut cursor = &self.buf[..];
332 match probe.read_object(&mut cursor) {
333 Ok(obj) => {
334 let consumed = self.buf.len() - cursor.remaining();
335 self.buf.advance(consumed);
336 *reader = probe;
337 return Ok(obj);
338 }
339 Err(CodecError::UnexpectedEnd) => {
340 if !self.fill().await? {
341 return Err(ConnectionError::UnexpectedEnd);
342 }
343 }
344 Err(e) => return Err(ConnectionError::Codec(e)),
345 }
346 }
347 }
348
349 pub async fn read_fetch_stream_header(&mut self) -> Result<FetchHeader, ConnectionError> {
351 loop {
352 let mut cursor = &self.buf[..];
353 match FetchHeader::decode(&mut cursor) {
354 Ok(hdr) => {
355 let consumed = self.buf.len() - cursor.remaining();
356 self.buf.advance(consumed);
357 return Ok(hdr);
358 }
359 Err(CodecError::UnexpectedEnd) => {
360 if !self.fill().await? {
361 return Err(ConnectionError::UnexpectedEnd);
362 }
363 }
364 Err(e) => return Err(ConnectionError::Codec(e)),
365 }
366 }
367 }
368
369 pub fn draft(&self) -> DraftVersion {
371 self.draft
372 }
373}
374
/// A client-side MoQT session: transport, session state machine, and control stream.
pub struct Connection {
    /// Underlying QUIC or WebTransport connection.
    transport: Transport,
    /// Session state machine that produces outgoing control messages and consumes
    /// incoming ones.
    endpoint: Endpoint,
    /// Draft version used for all framing on this connection.
    draft: DraftVersion,
    /// Send half of the control stream; `None` yields [`ConnectionError::NoControlStream`].
    control_send: Option<FramedSendStream>,
    /// Receive half of the control stream; `None` yields [`ConnectionError::NoControlStream`].
    control_recv: Option<FramedRecvStream>,
    /// Optional sink for [`ClientEvent`]s describing traffic on this connection.
    observer: Option<Box<dyn ConnectionObserver>>,
}
385
386impl Connection {
387 pub async fn connect(addr: &str, config: ClientConfig) -> Result<Self, ConnectionError> {
394 let draft = config.draft;
395 let transport = match &config.transport {
396 TransportType::Quic => Self::connect_quic(addr, &config).await?,
397 TransportType::WebTransport { url } => {
398 let url = url.clone();
399 Self::connect_webtransport(&url, &config).await?
400 }
401 };
402
403 let (send, recv) = transport.open_bi().await?;
405 let mut control_send = FramedSendStream::new(send, draft);
406 let mut control_recv = FramedRecvStream::new(recv, draft);
407
408 let mut endpoint = Endpoint::new(Role::Client);
410 endpoint.connect()?;
411 let setup_msg = endpoint.send_client_setup(config.setup_parameters.clone())?;
412 let any_setup = AnyControlMessage::Draft16(setup_msg);
413 let _raw_setup = control_send.write_control(&any_setup).await?;
414
415 let (server_setup, _raw_server_setup) = control_recv.read_control(false).await?;
416 match &server_setup {
418 AnyControlMessage::Draft16(ControlMessage::ServerSetup(ref ss)) => {
419 endpoint.receive_server_setup(ss)?;
420 }
421 _ => {
422 return Err(ConnectionError::Endpoint(EndpointError::NotActive));
423 }
424 }
425
426 let conn = Self {
427 transport,
428 endpoint,
429 draft,
430 control_send: Some(control_send),
431 control_recv: Some(control_recv),
432 observer: None,
433 };
434
435 conn.emit(ClientEvent::SetupComplete { negotiated_version: 0xff000000 + 15 });
437
438 Ok(conn)
439 }
440
441 async fn connect_quic(addr: &str, config: &ClientConfig) -> Result<Transport, ConnectionError> {
443 let server_addr = addr.parse().map_err(|e: std::net::AddrParseError| {
444 ConnectionError::InvalidAddress(e.to_string())
445 })?;
446
447 let mut tls_config = if config.skip_cert_verification {
449 rustls::ClientConfig::builder()
450 .dangerous()
451 .with_custom_certificate_verifier(Arc::new(SkipVerification))
452 .with_no_client_auth()
453 } else {
454 let mut roots = rustls::RootCertStore::empty();
455 roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
456 for der in &config.ca_certs {
457 roots
458 .add(rustls::pki_types::CertificateDer::from(der.clone()))
459 .map_err(|e| ConnectionError::TlsConfig(format!("bad CA cert: {e}")))?;
460 }
461 rustls::ClientConfig::builder().with_root_certificates(roots).with_no_client_auth()
462 };
463
464 tls_config.alpn_protocols = config.alpn();
465
466 let quic_config: quinn::crypto::rustls::QuicClientConfig =
467 tls_config.try_into().map_err(|e| ConnectionError::TlsConfig(format!("{e}")))?;
468 let client_config = quinn::ClientConfig::new(Arc::new(quic_config));
469
470 let mut quinn_endpoint = quinn::Endpoint::client("0.0.0.0:0".parse().unwrap())
471 .map_err(|e| ConnectionError::InvalidAddress(e.to_string()))?;
472 quinn_endpoint.set_default_client_config(client_config);
473
474 let server_name = addr.split(':').next().unwrap_or("localhost").to_string();
475
476 let quic = quinn_endpoint
477 .connect(server_addr, &server_name)
478 .map_err(TransportError::from)?
479 .await
480 .map_err(TransportError::from)?;
481
482 Ok(Transport::Quic(QuicTransport::new(quic)))
483 }
484
485 #[cfg(feature = "webtransport")]
487 async fn connect_webtransport(
488 url: &str,
489 config: &ClientConfig,
490 ) -> Result<Transport, ConnectionError> {
491 use crate::transport::webtransport::WebTransportTransport;
492
493 let wt_config = if config.skip_cert_verification {
494 wtransport::ClientConfig::builder()
495 .with_bind_default()
496 .with_no_cert_validation()
497 .build()
498 } else {
499 wtransport::ClientConfig::builder().with_bind_default().with_native_certs().build()
500 };
501
502 let endpoint = wtransport::Endpoint::client(wt_config)
503 .map_err(|e| ConnectionError::Transport(TransportError::Connect(e.to_string())))?;
504
505 let connection = endpoint
506 .connect(url)
507 .await
508 .map_err(|e| ConnectionError::Transport(TransportError::Connect(e.to_string())))?;
509
510 Ok(Transport::WebTransport(WebTransportTransport::new(connection)))
511 }
512
513 #[cfg(not(feature = "webtransport"))]
515 async fn connect_webtransport(
516 _url: &str,
517 _config: &ClientConfig,
518 ) -> Result<Transport, ConnectionError> {
519 Err(ConnectionError::Transport(TransportError::Connect(
520 "webtransport feature not enabled".into(),
521 )))
522 }
523
524 pub fn set_observer(&mut self, observer: Box<dyn ConnectionObserver>) {
528 self.observer = Some(observer);
529 }
530
531 pub fn clear_observer(&mut self) {
533 self.observer = None;
534 }
535
536 fn emit(&self, event: ClientEvent) {
538 if let Some(ref obs) = self.observer {
539 obs.on_event_owned(event);
540 }
541 }
542
543 pub async fn send_control(&mut self, msg: &ControlMessage) -> Result<(), ConnectionError> {
550 let any = AnyControlMessage::Draft16(msg.clone());
551 let send = self.control_send.as_mut().ok_or(ConnectionError::NoControlStream)?;
552 let raw = send.write_control(&any).await?;
553 self.emit(ClientEvent::ControlMessage {
554 direction: Direction::Send,
555 message: any,
556 raw: Some(raw),
557 });
558 Ok(())
559 }
560
561 pub async fn recv_control(&mut self) -> Result<ControlMessage, ConnectionError> {
566 let recv = self.control_recv.as_mut().ok_or(ConnectionError::NoControlStream)?;
567 let capture_raw = self.observer.is_some();
568 let (any, raw) = recv.read_control(capture_raw).await?;
569 if capture_raw {
570 self.emit(ClientEvent::ControlMessage {
571 direction: Direction::Receive,
572 message: any.clone(),
573 raw,
574 });
575 }
576 match any {
578 AnyControlMessage::Draft16(msg) => Ok(msg),
579 _ => Err(ConnectionError::Codec(CodecError::UnknownMessageType(0))),
580 }
581 }
582
583 pub async fn recv_and_dispatch(&mut self) -> Result<ControlMessage, ConnectionError> {
586 let msg = self.recv_control().await?;
587 self.endpoint.receive_message(msg.clone())?;
588
589 if let ControlMessage::GoAway(ref ga) = msg {
591 self.emit(ClientEvent::Draining { new_session_uri: ga.new_session_uri.clone() });
592 }
593
594 Ok(msg)
595 }
596
597 pub async fn subscribe(
601 &mut self,
602 track_namespace: TrackNamespace,
603 track_name: Vec<u8>,
604 parameters: Vec<KeyValuePair>,
605 ) -> Result<VarInt, ConnectionError> {
606 let (req_id, msg) = self.endpoint.subscribe(track_namespace, track_name, parameters)?;
607 self.send_control(&msg).await?;
608 Ok(req_id)
609 }
610
611 pub async fn unsubscribe(&mut self, request_id: VarInt) -> Result<(), ConnectionError> {
613 let msg = self.endpoint.unsubscribe(request_id)?;
614 self.send_control(&msg).await
615 }
616
617 pub async fn fetch(
621 &mut self,
622 track_namespace: TrackNamespace,
623 track_name: Vec<u8>,
624 start_group: VarInt,
625 start_object: VarInt,
626 end_group: VarInt,
627 end_object: VarInt,
628 ) -> Result<VarInt, ConnectionError> {
629 let (req_id, msg) = self.endpoint.fetch(
630 track_namespace,
631 track_name,
632 start_group,
633 start_object,
634 end_group,
635 end_object,
636 )?;
637 self.send_control(&msg).await?;
638 Ok(req_id)
639 }
640
641 pub async fn joining_fetch(
643 &mut self,
644 joining_request_id: VarInt,
645 joining_start: VarInt,
646 ) -> Result<VarInt, ConnectionError> {
647 let (req_id, msg) = self.endpoint.joining_fetch(joining_request_id, joining_start)?;
648 self.send_control(&msg).await?;
649 Ok(req_id)
650 }
651
652 pub async fn fetch_cancel(&mut self, request_id: VarInt) -> Result<(), ConnectionError> {
654 let msg = self.endpoint.fetch_cancel(request_id)?;
655 self.send_control(&msg).await
656 }
657
658 pub async fn subscribe_namespace(
662 &mut self,
663 namespace_prefix: TrackNamespace,
664 subscribe_options: VarInt,
665 parameters: Vec<KeyValuePair>,
666 ) -> Result<VarInt, ConnectionError> {
667 let (req_id, msg) =
668 self.endpoint.subscribe_namespace(namespace_prefix, subscribe_options, parameters)?;
669 self.send_control(&msg).await?;
670 Ok(req_id)
671 }
672
673 pub async fn publish_namespace(
675 &mut self,
676 track_namespace: TrackNamespace,
677 parameters: Vec<KeyValuePair>,
678 ) -> Result<VarInt, ConnectionError> {
679 let (req_id, msg) = self.endpoint.publish_namespace(track_namespace, parameters)?;
680 self.send_control(&msg).await?;
681 Ok(req_id)
682 }
683
684 pub async fn track_status(
688 &mut self,
689 track_namespace: TrackNamespace,
690 track_name: Vec<u8>,
691 parameters: Vec<KeyValuePair>,
692 ) -> Result<VarInt, ConnectionError> {
693 let (req_id, msg) = self.endpoint.track_status(track_namespace, track_name, parameters)?;
694 self.send_control(&msg).await?;
695 Ok(req_id)
696 }
697
698 pub async fn publish(
702 &mut self,
703 track_namespace: TrackNamespace,
704 track_name: Vec<u8>,
705 track_alias: VarInt,
706 track_extensions: Vec<KeyValuePair>,
707 parameters: Vec<KeyValuePair>,
708 ) -> Result<VarInt, ConnectionError> {
709 let (req_id, msg) = self.endpoint.publish(
710 track_namespace,
711 track_name,
712 track_alias,
713 track_extensions,
714 parameters,
715 )?;
716 self.send_control(&msg).await?;
717 Ok(req_id)
718 }
719
720 pub async fn publish_done(
722 &mut self,
723 request_id: VarInt,
724 status_code: VarInt,
725 stream_count: VarInt,
726 reason_phrase: Vec<u8>,
727 ) -> Result<(), ConnectionError> {
728 let msg = self.endpoint.send_publish_done(
729 request_id,
730 status_code,
731 stream_count,
732 reason_phrase,
733 )?;
734 self.send_control(&msg).await
735 }
736
737 pub async fn open_subgroup_stream(
741 &self,
742 header: &AnySubgroupHeader,
743 ) -> Result<FramedSendStream, ConnectionError> {
744 let send = self.transport.open_uni().await?;
745 let mut framed = FramedSendStream::new(send, self.draft);
746 let sid = framed.stream_id();
747 framed.write_subgroup_header(header).await?;
748 self.emit(ClientEvent::StreamOpened {
749 direction: Direction::Send,
750 stream_kind: StreamKind::Subgroup,
751 stream_id: sid,
752 });
753 self.emit(ClientEvent::DataStreamHeader {
754 stream_id: sid,
755 direction: Direction::Send,
756 header: header.clone(),
757 });
758 Ok(framed)
759 }
760
761 pub async fn accept_subgroup_stream(
764 &self,
765 ) -> Result<(AnySubgroupHeader, FramedRecvStream), ConnectionError> {
766 let recv = self.transport.accept_uni().await?;
767 let mut framed = FramedRecvStream::new(recv, self.draft);
768 let sid = framed.stream_id();
769 let header = framed.read_subgroup_header().await?;
770 self.emit(ClientEvent::StreamOpened {
771 direction: Direction::Receive,
772 stream_kind: StreamKind::Subgroup,
773 stream_id: sid,
774 });
775 self.emit(ClientEvent::DataStreamHeader {
776 stream_id: sid,
777 direction: Direction::Receive,
778 header: header.clone(),
779 });
780 Ok((header, framed))
781 }
782
783 pub fn send_datagram(
785 &self,
786 header: &AnyDatagramHeader,
787 payload: &[u8],
788 ) -> Result<(), ConnectionError> {
789 let mut buf = Vec::new();
790 header.encode(&mut buf);
791 buf.extend_from_slice(payload);
792 self.emit(ClientEvent::DatagramReceived {
793 direction: Direction::Send,
794 header: header.clone(),
795 payload_len: payload.len(),
796 });
797 self.transport.send_datagram(bytes::Bytes::from(buf))?;
798 Ok(())
799 }
800
801 pub async fn recv_datagram(&self) -> Result<(AnyDatagramHeader, Bytes), ConnectionError> {
803 let data = self.transport.recv_datagram().await?;
804 let mut cursor = &data[..];
805 let header = AnyDatagramHeader::decode(self.draft, &mut cursor)?;
806 let consumed = data.len() - cursor.len();
807 let payload = data.slice(consumed..);
808 self.emit(ClientEvent::DatagramReceived {
809 direction: Direction::Receive,
810 header: header.clone(),
811 payload_len: payload.len(),
812 });
813 Ok((header, payload))
814 }
815
816 pub fn endpoint(&self) -> &Endpoint {
820 &self.endpoint
821 }
822
823 pub fn endpoint_mut(&mut self) -> &mut Endpoint {
825 &mut self.endpoint
826 }
827
828 pub fn draft(&self) -> DraftVersion {
830 self.draft
831 }
832
833 pub fn close(&self, code: u32, reason: &[u8]) {
835 self.emit(ClientEvent::Closed { code, reason: reason.to_vec() });
836 self.transport.close(code, reason);
837 }
838}
839
/// Total encoded length, in bytes, of a variable-length integer given its first byte.
///
/// The two most significant bits select the width (RFC 9000 §16 encoding):
/// `00` → 1, `01` → 2, `10` → 4, `11` → 8.
fn varint_len(first_byte: u8) -> usize {
    match first_byte >> 6 {
        0b00 => 1,
        0b01 => 2,
        0b10 => 4,
        _ => 8,
    }
}
844
845#[derive(Debug)]
847struct SkipVerification;
848
849impl rustls::client::danger::ServerCertVerifier for SkipVerification {
850 fn verify_server_cert(
851 &self,
852 _end_entity: &rustls::pki_types::CertificateDer<'_>,
853 _intermediates: &[rustls::pki_types::CertificateDer<'_>],
854 _server_name: &rustls::pki_types::ServerName<'_>,
855 _ocsp_response: &[u8],
856 _now: rustls::pki_types::UnixTime,
857 ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
858 Ok(rustls::client::danger::ServerCertVerified::assertion())
859 }
860
861 fn verify_tls12_signature(
862 &self,
863 _message: &[u8],
864 _cert: &rustls::pki_types::CertificateDer<'_>,
865 _dcs: &rustls::DigitallySignedStruct,
866 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
867 Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
868 }
869
870 fn verify_tls13_signature(
871 &self,
872 _message: &[u8],
873 _cert: &rustls::pki_types::CertificateDer<'_>,
874 _dcs: &rustls::DigitallySignedStruct,
875 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
876 Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
877 }
878
879 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
880 vec![
881 rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
882 rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
883 rustls::SignatureScheme::ED25519,
884 rustls::SignatureScheme::RSA_PSS_SHA256,
885 rustls::SignatureScheme::RSA_PSS_SHA384,
886 rustls::SignatureScheme::RSA_PSS_SHA512,
887 ]
888 }
889}
890
#[cfg(test)]
mod tests {
    //! Unit tests for the pure helpers and configuration types of this module.

    use super::*;

    /// Builds a draft-16 [`ClientConfig`] for `transport`, with every other option defaulted.
    fn draft16_config(transport: TransportType) -> ClientConfig {
        ClientConfig {
            draft: DraftVersion::Draft16,
            transport,
            skip_cert_verification: false,
            ca_certs: Vec::new(),
            setup_parameters: Vec::new(),
        }
    }

    /// Asserts that every byte in `first_bytes` announces a varint of `expected` bytes.
    fn assert_varint_len(first_bytes: &[u8], expected: usize) {
        for &byte in first_bytes {
            assert_eq!(varint_len(byte), expected, "first byte {byte:#04x}");
        }
    }

    #[test]
    fn varint_len_single_byte() {
        assert_varint_len(&[0x00, 0x3F], 1);
    }

    #[test]
    fn varint_len_two_bytes() {
        assert_varint_len(&[0x40, 0x7F], 2);
    }

    #[test]
    fn varint_len_four_bytes() {
        assert_varint_len(&[0x80, 0xBF], 4);
    }

    #[test]
    fn varint_len_eight_bytes() {
        assert_varint_len(&[0xC0, 0xFF], 8);
    }

    #[test]
    fn client_config_alpn_quic_draft16() {
        let config = draft16_config(TransportType::Quic);
        assert_eq!(config.alpn(), vec![b"moqt-16".to_vec()]);
    }

    #[test]
    fn client_config_alpn_webtransport() {
        let config =
            draft16_config(TransportType::WebTransport { url: "https://example.com".to_string() });
        assert_eq!(config.alpn(), vec![b"h3".to_vec()]);
    }

    #[test]
    fn moqt_alpn_value() {
        assert_eq!(MOQT_ALPN, b"moq-00");
    }

    #[test]
    fn transport_type_debug() {
        let cases = [
            (TransportType::Quic, "Quic"),
            (
                TransportType::WebTransport { url: "https://example.com".to_string() },
                "WebTransport",
            ),
        ];
        for (transport, expected) in cases {
            assert!(format!("{transport:?}").contains(expected));
        }
    }
}