1use std::sync::Arc;
2
3use bytes::{Buf, Bytes, BytesMut};
4
5use crate::draft14::endpoint::{Endpoint, EndpointError};
6use crate::draft14::event::{ClientEvent, Direction, StreamKind};
7use crate::draft14::observer::ConnectionObserver;
8use crate::draft14::session::request_id::Role;
9use crate::transport::quic::QuicTransport;
10use crate::transport::{RecvStream, SendStream, Transport, TransportError};
11use moqtap_codec::dispatch::{
12 AnyControlMessage, AnyDatagramHeader, AnyFetchHeader, AnySubgroupHeader,
13};
14use moqtap_codec::draft14::data_stream::{FetchObject, SubgroupObject, SubgroupObjectReader};
15use moqtap_codec::draft14::message::ControlMessage;
16use moqtap_codec::error::CodecError;
17use moqtap_codec::types::*;
18use moqtap_codec::varint::VarInt;
19use moqtap_codec::version::DraftVersion;
20
/// ALPN protocol identifier `moq-00`.
pub const MOQT_ALPN: &[u8] = "moq-00".as_bytes();
23
24#[derive(Debug, thiserror::Error)]
26pub enum ConnectionError {
27 #[error("endpoint error: {0}")]
29 Endpoint(#[from] EndpointError),
30 #[error("codec error: {0}")]
32 Codec(#[from] CodecError),
33 #[error("transport error: {0}")]
35 Transport(#[from] TransportError),
36 #[error("varint error: {0}")]
38 VarInt(#[from] moqtap_codec::varint::VarIntError),
39 #[error("control stream not open")]
41 NoControlStream,
42 #[error("unexpected end of stream")]
44 UnexpectedEnd,
45 #[error("stream finished")]
47 StreamFinished,
48 #[error("invalid server address: {0}")]
50 InvalidAddress(String),
51 #[error("TLS config error: {0}")]
53 TlsConfig(String),
54 #[error("data stream state error: {0}")]
56 DataStreamState(&'static str),
57}
58
/// Transport used to reach the MoQT server.
#[derive(Debug, Clone)]
pub enum TransportType {
    /// Direct QUIC connection to a socket address.
    Quic,
    /// WebTransport session opened at `url`.
    WebTransport { url: String },
}
70
71pub struct ClientConfig {
75 pub draft: DraftVersion,
77 pub additional_versions: Vec<DraftVersion>,
80 pub transport: TransportType,
82 pub skip_cert_verification: bool,
84 pub ca_certs: Vec<Vec<u8>>,
86 pub setup_parameters: Vec<moqtap_codec::kvp::KeyValuePair>,
88}
89
90impl ClientConfig {
91 pub fn supported_versions(&self) -> Vec<VarInt> {
94 let mut versions = vec![self.draft.version_varint()];
95 for v in &self.additional_versions {
96 let varint = v.version_varint();
97 if !versions.contains(&varint) {
98 versions.push(varint);
99 }
100 }
101 versions
102 }
103
104 pub fn alpn(&self) -> Vec<Vec<u8>> {
106 match &self.transport {
107 TransportType::Quic => vec![self.draft.quic_alpn().to_vec()],
108 TransportType::WebTransport { .. } => vec![b"h3".to_vec()],
109 }
110 }
111}
112
113pub struct FramedSendStream {
115 inner: SendStream,
116 draft: DraftVersion,
117 subgroup_io: Option<SubgroupObjectReader>,
122}
123
124impl FramedSendStream {
125 pub fn new(inner: SendStream, draft: DraftVersion) -> Self {
127 Self { inner, draft, subgroup_io: None }
128 }
129
130 pub fn stream_id(&self) -> u64 {
132 self.inner.stream_id()
133 }
134
135 pub async fn write_control(
138 &mut self,
139 msg: &AnyControlMessage,
140 ) -> Result<Vec<u8>, ConnectionError> {
141 let mut buf = Vec::new();
142 msg.encode(&mut buf)?;
143 self.inner.write_all(&buf).await?;
144 Ok(buf)
145 }
146
147 pub async fn write_subgroup_header(
150 &mut self,
151 header: &AnySubgroupHeader,
152 ) -> Result<(), ConnectionError> {
153 let mut buf = Vec::new();
154 header.encode(&mut buf);
155 self.inner.write_all(&buf).await?;
156 if let AnySubgroupHeader::Draft14(ref d14) = header {
157 self.subgroup_io = Some(SubgroupObjectReader::new(d14));
158 }
159 Ok(())
160 }
161
162 pub async fn write_fetch_header(
164 &mut self,
165 header: &AnyFetchHeader,
166 ) -> Result<(), ConnectionError> {
167 let mut buf = Vec::new();
168 header.encode(&mut buf);
169 self.inner.write_all(&buf).await?;
170 Ok(())
171 }
172
173 pub async fn write_subgroup_object(
178 &mut self,
179 object: &SubgroupObject,
180 ) -> Result<(), ConnectionError> {
181 let writer = self
182 .subgroup_io
183 .as_mut()
184 .ok_or(ConnectionError::DataStreamState("subgroup header not written yet"))?;
185 let mut buf = Vec::new();
186 writer.write_object(object, &mut buf)?;
187 self.inner.write_all(&buf).await?;
188 Ok(())
189 }
190
191 pub async fn write_fetch_object(
195 &mut self,
196 object: &FetchObject,
197 ) -> Result<(), ConnectionError> {
198 let mut buf = Vec::new();
199 object.encode(&mut buf);
200 self.inner.write_all(&buf).await?;
201 Ok(())
202 }
203
204 pub async fn finish(&mut self) -> Result<(), ConnectionError> {
206 self.inner.finish()?;
207 Ok(())
208 }
209
210 pub fn draft(&self) -> DraftVersion {
212 self.draft
213 }
214}
215
216pub struct FramedRecvStream {
218 inner: RecvStream,
219 buf: BytesMut,
220 draft: DraftVersion,
221 subgroup_io: Option<SubgroupObjectReader>,
226}
227
228impl FramedRecvStream {
229 pub fn new(inner: RecvStream, draft: DraftVersion) -> Self {
231 Self { inner, buf: BytesMut::with_capacity(4096), draft, subgroup_io: None }
232 }
233
234 pub fn stream_id(&self) -> u64 {
236 self.inner.stream_id()
237 }
238
239 async fn fill(&mut self) -> Result<bool, ConnectionError> {
241 let mut tmp = [0u8; 4096];
242 match self.inner.read(&mut tmp).await {
243 Ok(Some(n)) => {
244 self.buf.extend_from_slice(&tmp[..n]);
245 Ok(true)
246 }
247 Ok(None) => Ok(false),
248 Err(e) => Err(ConnectionError::Transport(e)),
249 }
250 }
251
252 async fn ensure(&mut self, n: usize) -> Result<(), ConnectionError> {
254 while self.buf.len() < n {
255 if !self.fill().await? {
256 return Err(ConnectionError::UnexpectedEnd);
257 }
258 }
259 Ok(())
260 }
261
262 pub async fn read_control(
268 &mut self,
269 capture_raw: bool,
270 ) -> Result<(AnyControlMessage, Option<Vec<u8>>), ConnectionError> {
271 self.ensure(1).await?;
273 let type_len = varint_len(self.buf[0]);
274 self.ensure(type_len).await?;
275
276 let mut cursor = &self.buf[..type_len];
277 let _type_id = VarInt::decode(&mut cursor)?;
278
279 let (payload_len, len_field_size) = if self.draft.uses_fixed_length_framing() {
281 self.ensure(type_len + 2).await?;
282 let hi = self.buf[type_len] as usize;
283 let lo = self.buf[type_len + 1] as usize;
284 ((hi << 8) | lo, 2)
285 } else {
286 self.ensure(type_len + 1).await?;
287 let payload_len_start = type_len;
288 let payload_len_varint_len = varint_len(self.buf[payload_len_start]);
289 self.ensure(type_len + payload_len_varint_len).await?;
290 let mut cursor = &self.buf[payload_len_start..type_len + payload_len_varint_len];
291 let payload_len = VarInt::decode(&mut cursor)?.into_inner() as usize;
292 (payload_len, payload_len_varint_len)
293 };
294
295 let total = type_len + len_field_size + payload_len;
297 self.ensure(total).await?;
298
299 let raw = capture_raw.then(|| self.buf[..total].to_vec());
301
302 let mut frame = &self.buf[..total];
304 let msg = AnyControlMessage::decode(self.draft, &mut frame)?;
305 self.buf.advance(total);
306 Ok((msg, raw))
307 }
308
309 pub async fn read_subgroup_header(&mut self) -> Result<AnySubgroupHeader, ConnectionError> {
313 self.ensure(1).await?;
314 loop {
315 let mut cursor = &self.buf[..];
316 match AnySubgroupHeader::decode(self.draft, &mut cursor) {
317 Ok(header) => {
318 let consumed = self.buf.len() - cursor.remaining();
319 self.buf.advance(consumed);
320 if let AnySubgroupHeader::Draft14(ref d14) = header {
321 self.subgroup_io = Some(SubgroupObjectReader::new(d14));
322 }
323 return Ok(header);
324 }
325 Err(CodecError::UnexpectedEnd) => {
326 if !self.fill().await? {
327 return Err(ConnectionError::UnexpectedEnd);
328 }
329 }
330 Err(e) => return Err(ConnectionError::Codec(e)),
331 }
332 }
333 }
334
335 pub async fn read_fetch_header(&mut self) -> Result<AnyFetchHeader, ConnectionError> {
337 self.ensure(1).await?;
338 loop {
339 let mut cursor = &self.buf[..];
340 match AnyFetchHeader::decode(self.draft, &mut cursor) {
341 Ok(header) => {
342 let consumed = self.buf.len() - cursor.remaining();
343 self.buf.advance(consumed);
344 return Ok(header);
345 }
346 Err(CodecError::UnexpectedEnd) => {
347 if !self.fill().await? {
348 return Err(ConnectionError::UnexpectedEnd);
349 }
350 }
351 Err(e) => return Err(ConnectionError::Codec(e)),
352 }
353 }
354 }
355
356 pub async fn read_subgroup_object(&mut self) -> Result<SubgroupObject, ConnectionError> {
362 if self.subgroup_io.is_none() {
363 return Err(ConnectionError::DataStreamState("subgroup header not read yet"));
364 }
365 loop {
366 let reader = self.subgroup_io.as_mut().unwrap();
367 let mut probe = reader.clone();
368 let mut cursor = &self.buf[..];
369 match probe.read_object(&mut cursor) {
370 Ok(obj) => {
371 let consumed = self.buf.len() - cursor.remaining();
372 self.buf.advance(consumed);
373 *reader = probe;
375 return Ok(obj);
376 }
377 Err(CodecError::UnexpectedEnd) => {
378 if !self.fill().await? {
379 return Err(ConnectionError::UnexpectedEnd);
380 }
381 }
382 Err(e) => return Err(ConnectionError::Codec(e)),
383 }
384 }
385 }
386
387 pub async fn read_fetch_object(&mut self) -> Result<FetchObject, ConnectionError> {
391 loop {
392 let mut cursor = &self.buf[..];
393 match FetchObject::decode(&mut cursor) {
394 Ok(obj) => {
395 let consumed = self.buf.len() - cursor.remaining();
396 self.buf.advance(consumed);
397 return Ok(obj);
398 }
399 Err(CodecError::UnexpectedEnd) => {
400 if !self.fill().await? {
401 return Err(ConnectionError::UnexpectedEnd);
402 }
403 }
404 Err(e) => return Err(ConnectionError::Codec(e)),
405 }
406 }
407 }
408
409 pub fn draft(&self) -> DraftVersion {
411 self.draft
412 }
413}
414
415pub struct Connection {
418 transport: Transport,
419 endpoint: Endpoint,
420 draft: DraftVersion,
421 control_send: Option<FramedSendStream>,
422 control_recv: Option<FramedRecvStream>,
423 observer: Option<Box<dyn ConnectionObserver>>,
424}
425
426impl Connection {
427 pub async fn connect(addr: &str, config: ClientConfig) -> Result<Self, ConnectionError> {
433 let draft = config.draft;
434 let transport = match &config.transport {
435 TransportType::Quic => Self::connect_quic(addr, &config).await?,
436 TransportType::WebTransport { url } => {
437 let url = url.clone();
438 Self::connect_webtransport(&url, &config).await?
439 }
440 };
441
442 let (send, recv) = transport.open_bi().await?;
444 let mut control_send = FramedSendStream::new(send, draft);
445 let mut control_recv = FramedRecvStream::new(recv, draft);
446
447 let mut endpoint = Endpoint::new(Role::Client);
449 endpoint.connect()?;
450 let setup_msg = endpoint
451 .send_client_setup(config.supported_versions(), config.setup_parameters.clone())?;
452 let any_setup = AnyControlMessage::Draft14(setup_msg);
453 let _raw_setup = control_send.write_control(&any_setup).await?;
454
455 let (server_setup, _raw_server_setup) = control_recv.read_control(false).await?;
456 match &server_setup {
458 AnyControlMessage::Draft14(ControlMessage::ServerSetup(ref ss)) => {
459 endpoint.receive_server_setup(ss)?;
460 }
461 _ => {
462 return Err(ConnectionError::Endpoint(EndpointError::NotActive));
463 }
464 }
465
466 let conn = Self {
467 transport,
468 endpoint,
469 draft,
470 control_send: Some(control_send),
471 control_recv: Some(control_recv),
472 observer: None,
473 };
474
475 if let Some(v) = conn.endpoint.negotiated_version() {
477 conn.emit(ClientEvent::SetupComplete { negotiated_version: v.into_inner() });
478 }
479
480 Ok(conn)
481 }
482
483 async fn connect_quic(addr: &str, config: &ClientConfig) -> Result<Transport, ConnectionError> {
485 let server_addr = addr.parse().map_err(|e: std::net::AddrParseError| {
486 ConnectionError::InvalidAddress(e.to_string())
487 })?;
488
489 let mut tls_config = if config.skip_cert_verification {
491 rustls::ClientConfig::builder()
492 .dangerous()
493 .with_custom_certificate_verifier(Arc::new(SkipVerification))
494 .with_no_client_auth()
495 } else {
496 let mut roots = rustls::RootCertStore::empty();
497 roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
498 for der in &config.ca_certs {
499 roots
500 .add(rustls::pki_types::CertificateDer::from(der.clone()))
501 .map_err(|e| ConnectionError::TlsConfig(format!("bad CA cert: {e}")))?;
502 }
503 rustls::ClientConfig::builder().with_root_certificates(roots).with_no_client_auth()
504 };
505
506 tls_config.alpn_protocols = config.alpn();
507
508 let quic_config: quinn::crypto::rustls::QuicClientConfig =
509 tls_config.try_into().map_err(|e| ConnectionError::TlsConfig(format!("{e}")))?;
510 let client_config = quinn::ClientConfig::new(Arc::new(quic_config));
511
512 let mut quinn_endpoint = quinn::Endpoint::client("0.0.0.0:0".parse().unwrap())
513 .map_err(|e| ConnectionError::InvalidAddress(e.to_string()))?;
514 quinn_endpoint.set_default_client_config(client_config);
515
516 let server_name = addr.split(':').next().unwrap_or("localhost").to_string();
517
518 let quic = quinn_endpoint
519 .connect(server_addr, &server_name)
520 .map_err(TransportError::from)?
521 .await
522 .map_err(TransportError::from)?;
523
524 Ok(Transport::Quic(QuicTransport::new(quic)))
525 }
526
527 #[cfg(feature = "webtransport")]
529 async fn connect_webtransport(
530 url: &str,
531 config: &ClientConfig,
532 ) -> Result<Transport, ConnectionError> {
533 use crate::transport::webtransport::WebTransportTransport;
534
535 let wt_config = if config.skip_cert_verification {
536 wtransport::ClientConfig::builder()
537 .with_bind_default()
538 .with_no_cert_validation()
539 .build()
540 } else {
541 wtransport::ClientConfig::builder().with_bind_default().with_native_certs().build()
542 };
543
544 let endpoint = wtransport::Endpoint::client(wt_config)
545 .map_err(|e| ConnectionError::Transport(TransportError::Connect(e.to_string())))?;
546
547 let connection = endpoint
548 .connect(url)
549 .await
550 .map_err(|e| ConnectionError::Transport(TransportError::Connect(e.to_string())))?;
551
552 Ok(Transport::WebTransport(WebTransportTransport::new(connection)))
553 }
554
555 #[cfg(not(feature = "webtransport"))]
557 async fn connect_webtransport(
558 _url: &str,
559 _config: &ClientConfig,
560 ) -> Result<Transport, ConnectionError> {
561 Err(ConnectionError::Transport(TransportError::Connect(
562 "webtransport feature not enabled".into(),
563 )))
564 }
565
566 pub fn set_observer(&mut self, observer: Box<dyn ConnectionObserver>) {
570 self.observer = Some(observer);
571 }
572
573 pub fn clear_observer(&mut self) {
575 self.observer = None;
576 }
577
578 fn emit(&self, event: ClientEvent) {
580 if let Some(ref obs) = self.observer {
581 obs.on_event_owned(event);
582 }
583 }
584
585 pub async fn send_control(&mut self, msg: &ControlMessage) -> Result<(), ConnectionError> {
591 let any = AnyControlMessage::Draft14(msg.clone());
592 let send = self.control_send.as_mut().ok_or(ConnectionError::NoControlStream)?;
593 let raw = send.write_control(&any).await?;
594 self.emit(ClientEvent::ControlMessage {
595 direction: Direction::Send,
596 message: any,
597 raw: Some(raw),
598 });
599 Ok(())
600 }
601
602 pub async fn recv_control(&mut self) -> Result<ControlMessage, ConnectionError> {
607 let recv = self.control_recv.as_mut().ok_or(ConnectionError::NoControlStream)?;
608 let capture_raw = self.observer.is_some();
609 let (any, raw) = recv.read_control(capture_raw).await?;
610 if capture_raw {
611 self.emit(ClientEvent::ControlMessage {
612 direction: Direction::Receive,
613 message: any.clone(),
614 raw,
615 });
616 }
617 match any {
619 AnyControlMessage::Draft14(msg) => Ok(msg),
620 _ => Err(ConnectionError::Codec(CodecError::UnknownMessageType(0))),
621 }
622 }
623
624 pub async fn recv_and_dispatch(&mut self) -> Result<ControlMessage, ConnectionError> {
627 let msg = self.recv_control().await?;
628 self.endpoint.receive_message(msg.clone())?;
629
630 if let ControlMessage::GoAway(ref ga) = msg {
632 self.emit(ClientEvent::Draining { new_session_uri: ga.new_session_uri.clone() });
633 }
634
635 Ok(msg)
636 }
637
638 pub async fn subscribe(
642 &mut self,
643 track_namespace: TrackNamespace,
644 track_name: Vec<u8>,
645 subscriber_priority: u8,
646 group_order: GroupOrder,
647 filter_type: FilterType,
648 ) -> Result<VarInt, ConnectionError> {
649 let (req_id, msg) = self.endpoint.subscribe(
650 track_namespace,
651 track_name,
652 subscriber_priority,
653 group_order,
654 filter_type,
655 )?;
656 self.send_control(&msg).await?;
657 Ok(req_id)
658 }
659
660 pub async fn unsubscribe(&mut self, request_id: VarInt) -> Result<(), ConnectionError> {
662 let msg = self.endpoint.unsubscribe(request_id)?;
663 self.send_control(&msg).await
664 }
665
666 pub async fn fetch(
670 &mut self,
671 track_namespace: TrackNamespace,
672 track_name: Vec<u8>,
673 start_group: VarInt,
674 start_object: VarInt,
675 ) -> Result<VarInt, ConnectionError> {
676 let (req_id, msg) =
677 self.endpoint.fetch(track_namespace, track_name, start_group, start_object)?;
678 self.send_control(&msg).await?;
679 Ok(req_id)
680 }
681
682 pub async fn fetch_cancel(&mut self, request_id: VarInt) -> Result<(), ConnectionError> {
684 let msg = self.endpoint.fetch_cancel(request_id)?;
685 self.send_control(&msg).await
686 }
687
688 pub async fn subscribe_namespace(
692 &mut self,
693 track_namespace: TrackNamespace,
694 ) -> Result<VarInt, ConnectionError> {
695 let (req_id, msg) = self.endpoint.subscribe_namespace(track_namespace)?;
696 self.send_control(&msg).await?;
697 Ok(req_id)
698 }
699
700 pub async fn publish_namespace(
702 &mut self,
703 track_namespace: TrackNamespace,
704 ) -> Result<VarInt, ConnectionError> {
705 let (req_id, msg) = self.endpoint.publish_namespace(track_namespace)?;
706 self.send_control(&msg).await?;
707 Ok(req_id)
708 }
709
710 pub async fn track_status(
714 &mut self,
715 track_namespace: TrackNamespace,
716 track_name: Vec<u8>,
717 ) -> Result<VarInt, ConnectionError> {
718 let (req_id, msg) = self.endpoint.track_status(track_namespace, track_name)?;
719 self.send_control(&msg).await?;
720 Ok(req_id)
721 }
722
723 pub async fn publish(
727 &mut self,
728 track_namespace: TrackNamespace,
729 track_name: Vec<u8>,
730 forward: Forward,
731 ) -> Result<VarInt, ConnectionError> {
732 let (req_id, msg) = self.endpoint.publish(track_namespace, track_name, forward)?;
733 self.send_control(&msg).await?;
734 Ok(req_id)
735 }
736
737 pub async fn publish_done(
739 &mut self,
740 request_id: VarInt,
741 status_code: VarInt,
742 reason_phrase: Vec<u8>,
743 ) -> Result<(), ConnectionError> {
744 let msg = self.endpoint.send_publish_done(request_id, status_code, reason_phrase)?;
745 self.send_control(&msg).await
746 }
747
748 pub async fn open_subgroup_stream(
752 &self,
753 header: &AnySubgroupHeader,
754 ) -> Result<FramedSendStream, ConnectionError> {
755 let send = self.transport.open_uni().await?;
756 let mut framed = FramedSendStream::new(send, self.draft);
757 let sid = framed.stream_id();
758 framed.write_subgroup_header(header).await?;
759 self.emit(ClientEvent::StreamOpened {
760 direction: Direction::Send,
761 stream_kind: StreamKind::Subgroup,
762 stream_id: sid,
763 });
764 self.emit(ClientEvent::DataStreamHeader {
765 stream_id: sid,
766 direction: Direction::Send,
767 header: header.clone(),
768 });
769 Ok(framed)
770 }
771
772 pub async fn accept_subgroup_stream(
775 &self,
776 ) -> Result<(AnySubgroupHeader, FramedRecvStream), ConnectionError> {
777 let recv = self.transport.accept_uni().await?;
778 let mut framed = FramedRecvStream::new(recv, self.draft);
779 let sid = framed.stream_id();
780 let header = framed.read_subgroup_header().await?;
781 self.emit(ClientEvent::StreamOpened {
782 direction: Direction::Receive,
783 stream_kind: StreamKind::Subgroup,
784 stream_id: sid,
785 });
786 self.emit(ClientEvent::DataStreamHeader {
787 stream_id: sid,
788 direction: Direction::Receive,
789 header: header.clone(),
790 });
791 Ok((header, framed))
792 }
793
794 pub fn send_datagram(
796 &self,
797 header: &AnyDatagramHeader,
798 payload: &[u8],
799 ) -> Result<(), ConnectionError> {
800 let mut buf = Vec::new();
801 header.encode(&mut buf);
802 buf.extend_from_slice(payload);
803 self.emit(ClientEvent::DatagramReceived {
804 direction: Direction::Send,
805 header: header.clone(),
806 payload_len: payload.len(),
807 });
808 self.transport.send_datagram(bytes::Bytes::from(buf))?;
809 Ok(())
810 }
811
812 pub async fn recv_datagram(&self) -> Result<(AnyDatagramHeader, Bytes), ConnectionError> {
814 let data = self.transport.recv_datagram().await?;
815 let mut cursor = &data[..];
816 let header = AnyDatagramHeader::decode(self.draft, &mut cursor)?;
817 let consumed = data.len() - cursor.len();
818 let payload = data.slice(consumed..);
819 self.emit(ClientEvent::DatagramReceived {
820 direction: Direction::Receive,
821 header: header.clone(),
822 payload_len: payload.len(),
823 });
824 Ok((header, payload))
825 }
826
827 pub fn endpoint(&self) -> &Endpoint {
831 &self.endpoint
832 }
833
834 pub fn endpoint_mut(&mut self) -> &mut Endpoint {
836 &mut self.endpoint
837 }
838
839 pub fn negotiated_version(&self) -> Option<VarInt> {
841 self.endpoint.negotiated_version()
842 }
843
844 pub fn draft(&self) -> DraftVersion {
846 self.draft
847 }
848
849 pub fn close(&self, code: u32, reason: &[u8]) {
851 self.emit(ClientEvent::Closed { code, reason: reason.to_vec() });
852 self.transport.close(code, reason);
853 }
854}
855
/// Total encoded length in bytes of a variable-length integer, given its first byte.
///
/// The two most significant bits select the width:
/// `0b00` gives 1, `0b01` gives 2, `0b10` gives 4, and `0b11` gives 8.
fn varint_len(first_byte: u8) -> usize {
    match first_byte >> 6 {
        0b00 => 1,
        0b01 => 2,
        0b10 => 4,
        _ => 8,
    }
}
860
861#[derive(Debug)]
863struct SkipVerification;
864
865impl rustls::client::danger::ServerCertVerifier for SkipVerification {
866 fn verify_server_cert(
867 &self,
868 _end_entity: &rustls::pki_types::CertificateDer<'_>,
869 _intermediates: &[rustls::pki_types::CertificateDer<'_>],
870 _server_name: &rustls::pki_types::ServerName<'_>,
871 _ocsp_response: &[u8],
872 _now: rustls::pki_types::UnixTime,
873 ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
874 Ok(rustls::client::danger::ServerCertVerified::assertion())
875 }
876
877 fn verify_tls12_signature(
878 &self,
879 _message: &[u8],
880 _cert: &rustls::pki_types::CertificateDer<'_>,
881 _dcs: &rustls::DigitallySignedStruct,
882 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
883 Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
884 }
885
886 fn verify_tls13_signature(
887 &self,
888 _message: &[u8],
889 _cert: &rustls::pki_types::CertificateDer<'_>,
890 _dcs: &rustls::DigitallySignedStruct,
891 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
892 Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
893 }
894
895 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
896 vec![
897 rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
898 rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
899 rustls::SignatureScheme::ED25519,
900 rustls::SignatureScheme::RSA_PSS_SHA256,
901 rustls::SignatureScheme::RSA_PSS_SHA384,
902 rustls::SignatureScheme::RSA_PSS_SHA512,
903 ]
904 }
905}
906
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal config: no extra versions, CA certs, or setup parameters,
    /// with certificate verification left enabled.
    fn config_with(draft: DraftVersion, transport: TransportType) -> ClientConfig {
        ClientConfig {
            draft,
            additional_versions: Vec::new(),
            transport,
            skip_cert_verification: false,
            ca_certs: Vec::new(),
            setup_parameters: Vec::new(),
        }
    }

    /// Asserts that `varint_len` returns `expected` for every byte in `first_bytes`.
    fn assert_varint_len(first_bytes: &[u8], expected: usize) {
        for &b in first_bytes {
            assert_eq!(varint_len(b), expected, "first byte {b:#04x}");
        }
    }

    #[test]
    fn varint_len_single_byte() {
        assert_varint_len(&[0x00, 0x3F], 1);
    }

    #[test]
    fn varint_len_two_bytes() {
        assert_varint_len(&[0x40, 0x7F], 2);
    }

    #[test]
    fn varint_len_four_bytes() {
        assert_varint_len(&[0x80, 0xBF], 4);
    }

    #[test]
    fn varint_len_eight_bytes() {
        assert_varint_len(&[0xC0, 0xFF], 8);
    }

    #[test]
    fn client_config_supported_versions_draft14() {
        let versions = config_with(DraftVersion::Draft14, TransportType::Quic).supported_versions();
        assert_eq!(versions.len(), 1);
        assert_eq!(versions[0].into_inner(), 0xff000000 + 14);
    }

    #[test]
    fn client_config_supported_versions_draft07() {
        let versions = config_with(DraftVersion::Draft07, TransportType::Quic).supported_versions();
        assert_eq!(versions.len(), 1);
        assert_eq!(versions[0].into_inner(), 0xff000000 + 7);
    }

    #[test]
    fn client_config_alpn_quic() {
        let config = config_with(DraftVersion::Draft14, TransportType::Quic);
        assert_eq!(config.alpn(), vec![b"moq-00".to_vec()]);
    }

    #[test]
    fn client_config_alpn_webtransport() {
        let transport = TransportType::WebTransport { url: "https://example.com".to_string() };
        let config = config_with(DraftVersion::Draft14, transport);
        assert_eq!(config.alpn(), vec![b"h3".to_vec()]);
    }

    #[test]
    fn moqt_alpn_value() {
        assert_eq!(MOQT_ALPN, b"moq-00");
    }

    #[test]
    fn transport_type_debug() {
        let quic = TransportType::Quic;
        assert!(format!("{quic:?}").contains("Quic"));

        let wt = TransportType::WebTransport { url: "https://example.com".to_string() };
        assert!(format!("{wt:?}").contains("WebTransport"));
    }
}
1008}