1use std::sync::Arc;
4
5use bytes::{Bytes, BytesMut};
6use tokio::task::JoinSet;
7use tokio_util::sync::CancellationToken;
8
9use moqtap_client::transport::quic::QuicTransport;
10use moqtap_client::transport::{RecvStream, SendStream, Transport};
11use moqtap_codec::dispatch::AnyDatagramHeader;
12use moqtap_codec::varint::VarInt;
13use moqtap_codec::version::DraftVersion;
14
15use crate::error::ProxyError;
16use crate::event::{ProxyEvent, ProxySide, SessionId};
17use crate::hook::ProxyHook;
18use crate::observer::ProxyObserver;
19use crate::parser::control::{ControlStreamParser, ParseResult};
20use crate::parser::data::{DataStreamParser, DataStreamType};
21
/// Transport the proxy uses for its upstream (relay-facing) connection.
#[derive(Clone, Debug)]
pub enum UpstreamTransportType {
    /// Raw QUIC, dialed at `ProxySessionConfig::upstream_addr`.
    Quic,
    /// WebTransport session.
    WebTransport {
        /// URL handed to the WebTransport client when dialing the relay.
        url: String,
    },
}
33
34pub struct ProxySessionConfig {
36 pub draft: DraftVersion,
38 pub upstream_transport: UpstreamTransportType,
40 pub upstream_addr: String,
42 pub skip_upstream_cert_verify: bool,
44 pub upstream_ca_certs: Vec<Vec<u8>>,
46 pub upstream_connect_timeout_secs: u64,
48}
49
50impl ProxySessionConfig {
51 pub fn upstream_alpn(&self, client_alpn: &[u8]) -> Vec<Vec<u8>> {
58 match &self.upstream_transport {
59 UpstreamTransportType::Quic => {
60 if client_alpn.is_empty() {
61 vec![self.draft.quic_alpn().to_vec()]
62 } else {
63 vec![client_alpn.to_vec()]
64 }
65 }
66 UpstreamTransportType::WebTransport { .. } => vec![b"h3".to_vec()],
67 }
68 }
69}
70
71impl Default for ProxySessionConfig {
72 fn default() -> Self {
73 Self {
74 draft: DraftVersion::Draft14,
75 upstream_transport: UpstreamTransportType::Quic,
76 upstream_addr: String::new(),
77 skip_upstream_cert_verify: false,
78 upstream_ca_certs: Vec::new(),
79 upstream_connect_timeout_secs: 0,
80 }
81 }
82}
83
84pub struct ProxySession {
87 session_id: SessionId,
88 config: ProxySessionConfig,
89 client_alpn: Vec<u8>,
93 observer: Arc<dyn ProxyObserver>,
94 hook: Arc<dyn ProxyHook>,
95 cancel: CancellationToken,
96}
97
98impl ProxySession {
99 pub fn new(
104 session_id: SessionId,
105 config: ProxySessionConfig,
106 client_alpn: Vec<u8>,
107 observer: Arc<dyn ProxyObserver>,
108 hook: Arc<dyn ProxyHook>,
109 cancel: CancellationToken,
110 ) -> Self {
111 Self { session_id, config, client_alpn, observer, hook, cancel }
112 }
113
114 pub async fn run(self, client_conn: quinn::Connection) -> Result<(), ProxyError> {
116 let client = Transport::Quic(QuicTransport::new(client_conn));
117 self.run_with_transport(client).await
118 }
119
120 #[cfg(feature = "webtransport")]
122 pub async fn run_webtransport(
123 self,
124 client_conn: wtransport::Connection,
125 ) -> Result<(), ProxyError> {
126 use moqtap_client::transport::webtransport::WebTransportTransport;
127 let client = Transport::WebTransport(WebTransportTransport::new(client_conn));
128 self.run_with_transport(client).await
129 }
130
131 fn initial_draft(&self) -> DraftVersion {
137 DraftVersion::from_alpn(&self.client_alpn).unwrap_or(self.config.draft)
138 }
139
140 fn draft_is_fixed(&self) -> bool {
143 DraftVersion::from_alpn(&self.client_alpn).is_some()
144 }
145
146 async fn run_with_transport(self, client: Transport) -> Result<(), ProxyError> {
152 let relay = self.connect_upstream().await?;
154
155 let client = Arc::new(client);
156 let relay = Arc::new(relay);
157
158 let mut tasks: JoinSet<Result<(), ProxyError>> = JoinSet::new();
159
160 let initial_draft = self.initial_draft();
161 let draft_is_fixed = self.draft_is_fixed();
162
163 let observer_enabled = self.observer.wants_events();
164 let control_mutation = self.hook.wants_control_mutation();
165 let base_ctx = ForwardCtx {
166 session_id: self.session_id,
167 draft: initial_draft,
168 draft_is_fixed,
169 observer: Arc::clone(&self.observer),
170 hook: Arc::clone(&self.hook),
171 cancel: self.cancel.clone(),
172 observer_enabled,
173 control_mutation,
174 };
175
176 {
178 let client = Arc::clone(&client);
179 let relay = Arc::clone(&relay);
180 let ctx = base_ctx.clone();
181 tasks.spawn(async move { forward_control_stream(&client, &relay, &ctx).await });
182 }
183
184 {
186 let client = Arc::clone(&client);
187 let relay = Arc::clone(&relay);
188 let ctx = base_ctx.clone();
189 tasks.spawn(async move {
190 forward_uni_streams(&client, &relay, ProxySide::ClientToProxy, &ctx).await
191 });
192 }
193
194 {
196 let client = Arc::clone(&client);
197 let relay = Arc::clone(&relay);
198 let ctx = base_ctx.clone();
199 tasks.spawn(async move {
200 forward_uni_streams(&relay, &client, ProxySide::RelayToProxy, &ctx).await
201 });
202 }
203
204 {
206 let client = Arc::clone(&client);
207 let relay = Arc::clone(&relay);
208 let ctx = base_ctx.clone();
209 tasks.spawn(async move {
210 forward_datagrams(&client, &relay, ProxySide::ClientToProxy, &ctx).await
211 });
212 }
213
214 {
216 let client = Arc::clone(&client);
217 let relay = Arc::clone(&relay);
218 let ctx = base_ctx.clone();
219 tasks.spawn(async move {
220 forward_datagrams(&relay, &client, ProxySide::RelayToProxy, &ctx).await
221 });
222 }
223
224 let first_result = tasks.join_next().await;
226
227 self.cancel.cancel();
229 tasks.shutdown().await;
230
231 let reason = match &first_result {
233 Some(Ok(Ok(()))) => "completed".to_string(),
234 Some(Ok(Err(e))) => format!("{e}"),
235 Some(Err(e)) => format!("task panic: {e}"),
236 None => "no tasks".to_string(),
237 };
238 if self.observer.wants_events() {
239 self.observer
240 .on_event(&ProxyEvent::SessionEnded { session_id: self.session_id, reason });
241 }
242
243 client.close(0, b"proxy session ended");
245 relay.close(0, b"proxy session ended");
246
247 match first_result {
248 Some(Ok(Ok(()))) | None => Ok(()),
249 Some(Ok(Err(e))) => Err(e),
250 Some(Err(e)) => Err(ProxyError::SessionClosed(format!("task panic: {e}"))),
251 }
252 }
253
254 async fn connect_upstream(&self) -> Result<Transport, ProxyError> {
256 let timeout_secs = self.config.upstream_connect_timeout_secs;
257 if timeout_secs > 0 {
258 tokio::time::timeout(
259 std::time::Duration::from_secs(timeout_secs),
260 self.connect_upstream_inner(),
261 )
262 .await
263 .map_err(|_| {
264 ProxyError::UpstreamConnect(format!("connection timed out after {timeout_secs}s"))
265 })?
266 } else {
267 self.connect_upstream_inner().await
268 }
269 }
270
271 async fn connect_upstream_inner(&self) -> Result<Transport, ProxyError> {
272 match &self.config.upstream_transport {
273 UpstreamTransportType::Quic => self.connect_upstream_quic().await,
274 #[cfg(feature = "webtransport")]
275 UpstreamTransportType::WebTransport { url } => {
276 let url = url.clone();
277 self.connect_upstream_webtransport(&url).await
278 }
279 #[cfg(not(feature = "webtransport"))]
280 UpstreamTransportType::WebTransport { .. } => {
281 Err(ProxyError::UpstreamConnect("webtransport feature not enabled".to_string()))
282 }
283 }
284 }
285
286 async fn connect_upstream_quic(&self) -> Result<Transport, ProxyError> {
288 let server_addr =
289 self.config.upstream_addr.parse().map_err(|e: std::net::AddrParseError| {
290 ProxyError::UpstreamConnect(e.to_string())
291 })?;
292
293 let mut tls_config = self.build_upstream_tls_config()?;
294 tls_config.alpn_protocols = self.config.upstream_alpn(&self.client_alpn);
295
296 let quic_config: quinn::crypto::rustls::QuicClientConfig =
297 tls_config.try_into().map_err(|e| ProxyError::TlsConfig(format!("{e}")))?;
298 let client_config = quinn::ClientConfig::new(Arc::new(quic_config));
299
300 let mut endpoint = quinn::Endpoint::client("0.0.0.0:0".parse().unwrap())
301 .map_err(|e| ProxyError::UpstreamConnect(e.to_string()))?;
302 endpoint.set_default_client_config(client_config);
303
304 let server_name =
305 self.config.upstream_addr.split(':').next().unwrap_or("localhost").to_string();
306
307 let conn = endpoint
308 .connect(server_addr, &server_name)
309 .map_err(|e| ProxyError::UpstreamConnect(e.to_string()))?
310 .await
311 .map_err(|e| ProxyError::UpstreamConnect(e.to_string()))?;
312
313 Ok(Transport::Quic(QuicTransport::new(conn)))
314 }
315
316 #[cfg(feature = "webtransport")]
318 async fn connect_upstream_webtransport(&self, url: &str) -> Result<Transport, ProxyError> {
319 use moqtap_client::transport::webtransport::WebTransportTransport;
320
321 let wt_config = if self.config.skip_upstream_cert_verify {
322 wtransport::ClientConfig::builder()
323 .with_bind_default()
324 .with_no_cert_validation()
325 .build()
326 } else {
327 wtransport::ClientConfig::builder().with_bind_default().with_native_certs().build()
328 };
329
330 let endpoint = wtransport::Endpoint::client(wt_config)
331 .map_err(|e| ProxyError::UpstreamConnect(e.to_string()))?;
332
333 let connection =
334 endpoint.connect(url).await.map_err(|e| ProxyError::UpstreamConnect(e.to_string()))?;
335
336 Ok(Transport::WebTransport(WebTransportTransport::new(connection)))
337 }
338
339 fn build_upstream_tls_config(&self) -> Result<rustls::ClientConfig, ProxyError> {
341 if self.config.skip_upstream_cert_verify {
342 Ok(rustls::ClientConfig::builder()
343 .dangerous()
344 .with_custom_certificate_verifier(Arc::new(SkipVerification))
345 .with_no_client_auth())
346 } else {
347 let mut roots = rustls::RootCertStore::empty();
348 roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
349 for der in &self.config.upstream_ca_certs {
350 roots
351 .add(rustls::pki_types::CertificateDer::from(der.clone()))
352 .map_err(|e| ProxyError::TlsConfig(format!("bad CA cert: {e}")))?;
353 }
354 Ok(rustls::ClientConfig::builder().with_root_certificates(roots).with_no_client_auth())
355 }
356 }
357}
358
359#[derive(Clone)]
363struct ForwardCtx {
364 session_id: SessionId,
365 draft: DraftVersion,
369 draft_is_fixed: bool,
372 observer: Arc<dyn ProxyObserver>,
373 hook: Arc<dyn ProxyHook>,
374 cancel: CancellationToken,
375 observer_enabled: bool,
379 control_mutation: bool,
384}
385
386impl ForwardCtx {
387 fn emit(&self, event: impl FnOnce() -> ProxyEvent) {
393 if self.observer_enabled {
394 self.observer.on_event(&event());
395 }
396 }
397}
398
399async fn forward_control_stream(
401 client: &Transport,
402 relay: &Transport,
403 ctx: &ForwardCtx,
404) -> Result<(), ProxyError> {
405 let (client_send, client_recv) = client.accept_bi().await?;
407 ctx.emit(|| ProxyEvent::BiStreamOpened {
408 session_id: ctx.session_id,
409 side: ProxySide::ClientToProxy,
410 });
411
412 let (relay_send, relay_recv) = relay.open_bi().await?;
414 ctx.emit(|| ProxyEvent::BiStreamOpened {
415 session_id: ctx.session_id,
416 side: ProxySide::ProxyToRelay,
417 });
418
419 let ctx1 = ForwardCtx { ..ctx.clone() };
421 let ctx2 = ForwardCtx { ..ctx.clone() };
422
423 let client_to_relay = tokio::spawn(async move {
424 pipe_control(client_recv, relay_send, ProxySide::ClientToProxy, &ctx1).await
425 });
426
427 let relay_to_client = tokio::spawn(async move {
428 pipe_control(relay_recv, client_send, ProxySide::RelayToProxy, &ctx2).await
429 });
430
431 tokio::select! {
432 r = client_to_relay => r.map_err(|e| ProxyError::SessionClosed(e.to_string()))?,
433 r = relay_to_client => r.map_err(|e| ProxyError::SessionClosed(e.to_string()))?,
434 _ = ctx.cancel.cancelled() => Ok(()),
435 }
436}
437
438async fn pipe_control(
450 recv: RecvStream,
451 send: SendStream,
452 side: ProxySide,
453 ctx: &ForwardCtx,
454) -> Result<(), ProxyError> {
455 if ctx.control_mutation {
456 pipe_control_mutating(recv, send, side, ctx).await
457 } else {
458 pipe_control_passthrough(recv, send, side, ctx).await
459 }
460}
461
462async fn pipe_control_passthrough(
468 mut recv: RecvStream,
469 mut send: SendStream,
470 side: ProxySide,
471 ctx: &ForwardCtx,
472) -> Result<(), ProxyError> {
473 let mut buf = [0u8; 8192];
474
475 let mut parser: Option<ControlStreamParser> =
476 if ctx.draft_is_fixed { Some(ControlStreamParser::new(ctx.draft)) } else { None };
477
478 const DETECT_BUF_MAX: usize = 64 * 1024;
479 let mut detect_buf = BytesMut::new();
480
481 loop {
482 tokio::select! {
483 result = recv.read(&mut buf) => {
484 match result? {
485 Some(n) => {
486 let data = &buf[..n];
487
488 send.write_all(data).await?;
490
491 if ctx.observer_enabled {
495 if let Some(ref mut p) = parser {
496 emit_parsed_frames(p, data, side, ctx);
497 } else {
498 detect_buf.extend_from_slice(data);
499 if let Some(detected) = detect_draft_from_setup(
500 &detect_buf,
501 side,
502 ctx.draft,
503 ) {
504 let mut p = ControlStreamParser::new(detected);
505 let buffered = detect_buf.split().freeze();
506 emit_parsed_frames(&mut p, &buffered, side, ctx);
507 parser = Some(p);
508 } else if detect_buf.len() >= DETECT_BUF_MAX {
509 ctx.emit(|| ProxyEvent::ParseError {
510 session_id: ctx.session_id,
511 side,
512 error: format!(
513 "control draft detection gave up after {} bytes; \
514 falling back to {}",
515 detect_buf.len(),
516 ctx.draft,
517 ),
518 });
519 parser = Some(ControlStreamParser::new(ctx.draft));
520 detect_buf = BytesMut::new();
521 }
522 }
523 }
524 }
525 None => {
526 ctx.emit(|| ProxyEvent::StreamClosed {
527 session_id: ctx.session_id,
528 side,
529 });
530 let _ = send.finish();
531 return Ok(());
532 }
533 }
534 }
535 _ = ctx.cancel.cancelled() => {
536 return Ok(());
537 }
538 }
539 }
540}
541
542async fn pipe_control_mutating(
551 mut recv: RecvStream,
552 mut send: SendStream,
553 side: ProxySide,
554 ctx: &ForwardCtx,
555) -> Result<(), ProxyError> {
556 let mut buf = [0u8; 8192];
557
558 let mut parser: Option<ControlStreamParser> =
561 if ctx.draft_is_fixed { Some(ControlStreamParser::new_capturing(ctx.draft)) } else { None };
562
563 const DETECT_BUF_MAX: usize = 64 * 1024;
564 let mut detect_buf = BytesMut::new();
565
566 loop {
567 tokio::select! {
568 result = recv.read(&mut buf) => {
569 match result? {
570 Some(n) => {
571 let data = &buf[..n];
572
573 match parser.as_mut() {
574 Some(p) => {
575 forward_mutated_frames(p, data, &mut send, side, ctx).await?;
576 }
577 None => {
578 detect_buf.extend_from_slice(data);
579 let new_parser = if let Some(detected) = detect_draft_from_setup(
580 &detect_buf,
581 side,
582 ctx.draft,
583 ) {
584 Some(ControlStreamParser::new_capturing(detected))
585 } else if detect_buf.len() >= DETECT_BUF_MAX {
586 ctx.emit(|| ProxyEvent::ParseError {
587 session_id: ctx.session_id,
588 side,
589 error: format!(
590 "control draft detection gave up after {} bytes; \
591 falling back to {}",
592 detect_buf.len(),
593 ctx.draft,
594 ),
595 });
596 Some(ControlStreamParser::new_capturing(ctx.draft))
597 } else {
598 None
600 };
601
602 if let Some(mut p) = new_parser {
603 let buffered = detect_buf.split().freeze();
604 forward_mutated_frames(
605 &mut p,
606 &buffered,
607 &mut send,
608 side,
609 ctx,
610 )
611 .await?;
612 parser = Some(p);
613 }
614 }
615 }
616 }
617 None => {
618 ctx.emit(|| ProxyEvent::StreamClosed {
619 session_id: ctx.session_id,
620 side,
621 });
622 let _ = send.finish();
623 return Ok(());
624 }
625 }
626 }
627 _ = ctx.cancel.cancelled() => {
628 return Ok(());
629 }
630 }
631 }
632}
633
634async fn forward_mutated_frames(
637 parser: &mut ControlStreamParser,
638 data: &[u8],
639 send: &mut SendStream,
640 side: ProxySide,
641 ctx: &ForwardCtx,
642) -> Result<(), ProxyError> {
643 let result = parser.feed(data);
644 if let ParseResult::Messages(frames) = result {
645 for frame in frames {
646 let raw = frame.raw_bytes.as_deref().expect("capturing parser must populate raw_bytes");
647 let replacement =
648 ctx.hook.on_control_message(ctx.session_id, side, &frame.message, raw);
649 let out: &[u8] = match replacement.as_deref() {
650 Some(bytes) => bytes,
651 None => raw,
652 };
653 send.write_all(out).await?;
654
655 if ctx.observer_enabled {
656 let event = if frame.message.is_setup() {
657 ProxyEvent::SetupMessage {
658 session_id: ctx.session_id,
659 side,
660 message: frame.message.clone(),
661 }
662 } else {
663 ProxyEvent::ControlMessage {
664 session_id: ctx.session_id,
665 side,
666 message: frame.message.clone(),
667 }
668 };
669 ctx.observer.on_event(&event);
670 }
671 }
672 }
673 Ok(())
674}
675
676fn emit_parsed_frames(
687 parser: &mut ControlStreamParser,
688 data: &[u8],
689 side: ProxySide,
690 ctx: &ForwardCtx,
691) {
692 if !ctx.observer_enabled {
693 return;
694 }
695 match parser.feed(data) {
696 ParseResult::Messages(frames) => {
697 for frame in &frames {
698 let event = if frame.message.is_setup() {
699 ProxyEvent::SetupMessage {
700 session_id: ctx.session_id,
701 side,
702 message: frame.message.clone(),
703 }
704 } else {
705 ProxyEvent::ControlMessage {
706 session_id: ctx.session_id,
707 side,
708 message: frame.message.clone(),
709 }
710 };
711 ctx.observer.on_event(&event);
712 }
713 }
714 ParseResult::NeedMore => {}
715 }
716}
717
718async fn forward_uni_streams(
720 source: &Transport,
721 dest: &Transport,
722 side: ProxySide,
723 ctx: &ForwardCtx,
724) -> Result<(), ProxyError> {
725 loop {
726 tokio::select! {
727 result = source.accept_uni() => {
728 let recv = result?;
729 ctx.emit(|| ProxyEvent::UniStreamOpened {
730 session_id: ctx.session_id,
731 side,
732 });
733
734 let send = dest.open_uni().await?;
735 let ctx = ctx.clone();
736
737 tokio::spawn(async move {
738 if let Err(e) = pipe_data(
739 recv, send, side, &ctx,
740 )
741 .await
742 {
743 ctx.emit(|| ProxyEvent::ParseError {
744 session_id: ctx.session_id,
745 side,
746 error: format!("uni stream pipe: {e}"),
747 });
748 }
749 });
750 }
751 _ = ctx.cancel.cancelled() => {
752 return Ok(());
753 }
754 }
755 }
756}
757
758fn detect_stream_type(first_byte: u8) -> DataStreamType {
767 match first_byte {
770 0x05 => DataStreamType::Fetch,
771 _ => DataStreamType::Subgroup,
773 }
774}
775
776async fn pipe_data(
778 mut recv: RecvStream,
779 mut send: SendStream,
780 side: ProxySide,
781 ctx: &ForwardCtx,
782) -> Result<(), ProxyError> {
783 let mut buf = [0u8; 8192];
784 let mut parser: Option<DataStreamParser> = None;
785
786 loop {
787 tokio::select! {
788 result = recv.read(&mut buf) => {
789 match result? {
790 Some(n) => {
791 let data = &buf[..n];
792
793 if ctx.observer_enabled {
796 let feed_data = if parser.is_none() && !data.is_empty() {
802 let stream_type = detect_stream_type(data[0]);
803 parser = Some(DataStreamParser::new(
804 stream_type, ctx.draft,
805 ));
806 let skip = varint_len(data[0]);
809 &data[skip.min(data.len())..]
810 } else {
811 data
812 };
813
814 if let Some(ref mut p) = parser {
815 let results = p.feed(feed_data);
816 for result in &results {
817 match result {
818 crate::parser::data::DataParseResult::Header(header) => {
819 ctx.observer.on_event(
820 &ProxyEvent::DataStreamHeader {
821 session_id: ctx.session_id,
822 side,
823 header: header.clone(),
824 },
825 );
826 }
827 crate::parser::data::DataParseResult::Object(header) => {
828 ctx.observer.on_event(
829 &ProxyEvent::ObjectHeader {
830 session_id: ctx.session_id,
831 side,
832 header: header.clone(),
833 },
834 );
835 }
836 crate::parser::data::DataParseResult::Error(e) => {
837 ctx.observer.on_event(
838 &ProxyEvent::ParseError {
839 session_id: ctx.session_id,
840 side,
841 error: e.clone(),
842 },
843 );
844 }
845 crate::parser::data::DataParseResult::NeedMore => {}
846 }
847 }
848 }
849 }
850
851 send.write_all(data).await?;
854 }
855 None => {
856 ctx.emit(|| ProxyEvent::StreamClosed {
857 session_id: ctx.session_id,
858 side,
859 });
860 let _ = send.finish();
861 return Ok(());
862 }
863 }
864 }
865 _ = ctx.cancel.cancelled() => {
866 return Ok(());
867 }
868 }
869 }
870}
871
872async fn forward_datagrams(
874 source: &Transport,
875 dest: &Transport,
876 side: ProxySide,
877 ctx: &ForwardCtx,
878) -> Result<(), ProxyError> {
879 loop {
880 tokio::select! {
881 result = source.recv_datagram() => {
882 let data = result?;
883
884 if ctx.observer_enabled {
887 let mut cursor = &data[..];
888 if let Ok(header) = AnyDatagramHeader::decode(ctx.draft, &mut cursor) {
889 let payload_len = cursor.len();
890 ctx.observer.on_event(&ProxyEvent::Datagram {
891 session_id: ctx.session_id,
892 side,
893 header: header.clone(),
894 payload_len,
895 });
896
897 if let Some(replacement) = ctx.hook.on_datagram(
898 ctx.session_id, side, &header, &data,
899 ) {
900 dest.send_datagram(Bytes::from(replacement))?;
901 } else {
902 dest.send_datagram(data)?;
903 }
904 } else {
905 dest.send_datagram(data)?;
906 }
907 } else {
908 dest.send_datagram(data)?;
909 }
910 }
911 _ = ctx.cancel.cancelled() => {
912 return Ok(());
913 }
914 }
915 }
916}
917
/// Total encoded length of a QUIC-style varint, read from the two-bit prefix
/// of its first byte.
///
/// `00` means 1 byte, `01` means 2, `10` means 4, and `11` means 8.
fn varint_len(first_byte: u8) -> usize {
    match first_byte >> 6 {
        0b00 => 1,
        0b01 => 2,
        0b10 => 4,
        _ => 8,
    }
}
922
923fn detect_draft_from_setup(
943 buf: &[u8],
944 side: ProxySide,
945 fallback: DraftVersion,
946) -> Option<DraftVersion> {
947 let _ = fallback;
948 if buf.is_empty() {
949 return None;
950 }
951
952 let type_len = varint_len(buf[0]);
956 if buf.len() < type_len {
957 return None;
958 }
959 let mut cur = &buf[..type_len];
960 let type_id = VarInt::decode(&mut cur).ok()?.into_inner();
961
962 let (is_client_setup, is_server_setup, uses_u16_length) = match type_id {
968 0x40 => (true, false, false),
969 0x41 => (false, true, false),
970 0x20 => (true, false, true),
971 0x21 => (false, true, true),
972 _ => return None,
973 };
974
975 match side {
979 ProxySide::ClientToProxy | ProxySide::ProxyToRelay if !is_client_setup => return None,
980 ProxySide::RelayToProxy | ProxySide::ProxyToClient if !is_server_setup => return None,
981 _ => {}
982 }
983
984 let (payload_start, payload_len) = if uses_u16_length {
985 if buf.len() < type_len + 2 {
986 return None;
987 }
988 let len = ((buf[type_len] as usize) << 8) | (buf[type_len + 1] as usize);
989 (type_len + 2, len)
990 } else {
991 if buf.len() <= type_len {
992 return None;
993 }
994 let vl = varint_len(buf[type_len]);
995 if buf.len() < type_len + vl {
996 return None;
997 }
998 let mut cur = &buf[type_len..type_len + vl];
999 let v = VarInt::decode(&mut cur).ok()?;
1000 (type_len + vl, v.into_inner() as usize)
1001 };
1002
1003 if buf.len() < payload_start + payload_len {
1004 return None;
1005 }
1006 let payload = &buf[payload_start..payload_start + payload_len];
1007
1008 if is_client_setup {
1009 let mut cur = payload;
1013 let count = VarInt::decode(&mut cur).ok()?.into_inner() as usize;
1014 let mut best: Option<DraftVersion> = None;
1015 for _ in 0..count {
1016 let v = VarInt::decode(&mut cur).ok()?.into_inner();
1017 if let Some(d) = version_varint_to_draft(v) {
1018 if (7..=14).contains(&d.number()) {
1019 best = Some(match best {
1020 Some(b) if b.number() >= d.number() => b,
1021 _ => d,
1022 });
1023 }
1024 }
1025 }
1026 best
1027 } else {
1028 let mut cur = payload;
1031 let v = VarInt::decode(&mut cur).ok()?.into_inner();
1032 let d = version_varint_to_draft(v)?;
1033 if (7..=14).contains(&d.number()) {
1034 Some(d)
1035 } else {
1036 None
1037 }
1038 }
1039}
1040
1041fn version_varint_to_draft(v: u64) -> Option<DraftVersion> {
1044 const BASE: u64 = 0xff000000;
1045 if !(BASE..=BASE + 255).contains(&v) {
1046 return None;
1047 }
1048 DraftVersion::from_number((v - BASE) as u8)
1049}
1050
/// TLS certificate verifier that accepts any upstream certificate and any
/// handshake signature.
///
/// Installed only when `skip_upstream_cert_verify` is set. It removes all
/// authentication of the relay and is intended for testing.
#[derive(Debug)]
struct SkipVerification;
1054
1055impl rustls::client::danger::ServerCertVerifier for SkipVerification {
1056 fn verify_server_cert(
1057 &self,
1058 _end_entity: &rustls::pki_types::CertificateDer<'_>,
1059 _intermediates: &[rustls::pki_types::CertificateDer<'_>],
1060 _server_name: &rustls::pki_types::ServerName<'_>,
1061 _ocsp_response: &[u8],
1062 _now: rustls::pki_types::UnixTime,
1063 ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
1064 Ok(rustls::client::danger::ServerCertVerified::assertion())
1065 }
1066
1067 fn verify_tls12_signature(
1068 &self,
1069 _message: &[u8],
1070 _cert: &rustls::pki_types::CertificateDer<'_>,
1071 _dcs: &rustls::DigitallySignedStruct,
1072 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
1073 Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
1074 }
1075
1076 fn verify_tls13_signature(
1077 &self,
1078 _message: &[u8],
1079 _cert: &rustls::pki_types::CertificateDer<'_>,
1080 _dcs: &rustls::DigitallySignedStruct,
1081 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
1082 Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
1083 }
1084
1085 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
1086 vec![
1087 rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
1088 rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
1089 rustls::SignatureScheme::ED25519,
1090 rustls::SignatureScheme::RSA_PSS_SHA256,
1091 rustls::SignatureScheme::RSA_PSS_SHA384,
1092 rustls::SignatureScheme::RSA_PSS_SHA512,
1093 ]
1094 }
1095}
1096
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes a CLIENT_SETUP with the draft-07 codec (varint-length framing),
    /// offering each draft number in `drafts`.
    fn encode_client_setup_d07(drafts: &[u8]) -> Vec<u8> {
        use moqtap_codec::draft07::message::{ClientSetup, ControlMessage};
        let mut supported = Vec::new();
        for &n in drafts {
            supported.push(VarInt::from_usize(0xff000000 + n as usize));
        }
        let msg = ControlMessage::ClientSetup(ClientSetup {
            supported_versions: supported,
            parameters: Vec::new(),
        });
        let mut buf = Vec::new();
        msg.encode(&mut buf).unwrap();
        buf
    }

    /// Encodes a CLIENT_SETUP with the draft-14 codec (u16-length framing),
    /// offering each draft number in `drafts`.
    fn encode_client_setup_d14(drafts: &[u8]) -> Vec<u8> {
        use moqtap_codec::draft14::message::{ClientSetup, ControlMessage};
        let mut supported = Vec::new();
        for &n in drafts {
            supported.push(VarInt::from_usize(0xff000000 + n as usize));
        }
        let msg = ControlMessage::ClientSetup(ClientSetup {
            supported_versions: supported,
            parameters: Vec::new(),
        });
        let mut buf = Vec::new();
        msg.encode(&mut buf).unwrap();
        buf
    }

    /// Encodes a draft-07-framed SERVER_SETUP selecting `draft`.
    fn encode_server_setup_d07(draft: u8) -> Vec<u8> {
        use moqtap_codec::draft07::message::{ControlMessage, ServerSetup};
        let msg = ControlMessage::ServerSetup(ServerSetup {
            selected_version: VarInt::from_usize(0xff000000 + draft as usize),
            parameters: Vec::new(),
        });
        let mut buf = Vec::new();
        msg.encode(&mut buf).unwrap();
        buf
    }

    /// Encodes a draft-14-framed SERVER_SETUP selecting `draft`.
    fn encode_server_setup_d14(draft: u8) -> Vec<u8> {
        use moqtap_codec::draft14::message::{ControlMessage, ServerSetup};
        let msg = ControlMessage::ServerSetup(ServerSetup {
            selected_version: VarInt::from_usize(0xff000000 + draft as usize),
            parameters: Vec::new(),
        });
        let mut buf = Vec::new();
        msg.encode(&mut buf).unwrap();
        buf
    }

    #[test]
    fn detect_picks_highest_draft_from_07_10_varint_framing() {
        let bytes = encode_client_setup_d07(&[7, 9]);
        let d = detect_draft_from_setup(&bytes, ProxySide::ClientToProxy, DraftVersion::Draft14);
        assert_eq!(d, Some(DraftVersion::Draft09));
    }

    #[test]
    fn detect_picks_highest_draft_from_11_14_u16_framing() {
        let bytes = encode_client_setup_d14(&[11, 13, 14]);
        let d = detect_draft_from_setup(&bytes, ProxySide::ClientToProxy, DraftVersion::Draft11);
        assert_eq!(d, Some(DraftVersion::Draft14));
    }

    #[test]
    fn detect_from_server_setup_varint_framing() {
        let bytes = encode_server_setup_d07(10);
        let d = detect_draft_from_setup(&bytes, ProxySide::RelayToProxy, DraftVersion::Draft07);
        assert_eq!(d, Some(DraftVersion::Draft10));
    }

    #[test]
    fn detect_from_server_setup_u16_framing() {
        let bytes = encode_server_setup_d14(14);
        let d = detect_draft_from_setup(&bytes, ProxySide::RelayToProxy, DraftVersion::Draft11);
        assert_eq!(d, Some(DraftVersion::Draft14));
    }

    #[test]
    fn detect_returns_none_on_partial_bytes() {
        let bytes = encode_client_setup_d14(&[14]);
        assert_eq!(
            detect_draft_from_setup(&bytes[..1], ProxySide::ClientToProxy, DraftVersion::Draft14),
            None
        );
    }

    #[test]
    fn detect_returns_none_for_unrelated_first_byte() {
        let bytes = [0x10u8, 0x00, 0x00];
        assert_eq!(
            detect_draft_from_setup(&bytes, ProxySide::ClientToProxy, DraftVersion::Draft14),
            None
        );
    }

    #[test]
    fn detect_ignores_15_plus_versions_in_moq_00_setup() {
        let bytes = encode_client_setup_d14(&[15]);
        assert_eq!(
            detect_draft_from_setup(&bytes, ProxySide::ClientToProxy, DraftVersion::Draft14),
            None
        );
    }

    #[test]
    fn detect_setup_wrong_direction_returns_none() {
        let bytes = encode_client_setup_d14(&[14]);
        assert_eq!(
            detect_draft_from_setup(&bytes, ProxySide::RelayToProxy, DraftVersion::Draft14),
            None
        );
    }

    /// Every proper prefix of a SETUP must be reported as "not yet", in both
    /// framings and both directions; the complete message must be detected.
    #[test]
    fn detect_waits_for_complete_setup_in_both_framings() {
        let cases = [
            (encode_client_setup_d07(&[9]), ProxySide::ClientToProxy, DraftVersion::Draft09),
            (encode_server_setup_d07(10), ProxySide::RelayToProxy, DraftVersion::Draft10),
            (encode_client_setup_d14(&[14]), ProxySide::ClientToProxy, DraftVersion::Draft14),
            (encode_server_setup_d14(14), ProxySide::RelayToProxy, DraftVersion::Draft14),
        ];
        for (bytes, side, expected) in cases {
            for cut in 0..bytes.len() {
                assert_eq!(
                    detect_draft_from_setup(&bytes[..cut], side, expected),
                    None,
                    "prefix of {cut} bytes"
                );
            }
            assert_eq!(detect_draft_from_setup(&bytes, side, expected), Some(expected));
        }
    }

    /// Bytes of later messages after the SETUP must not disturb detection.
    #[test]
    fn detect_ignores_bytes_after_the_setup() {
        let mut bytes = encode_client_setup_d14(&[13, 14]);
        bytes.extend_from_slice(&[0x10, 0x00, 0x00]);
        assert_eq!(
            detect_draft_from_setup(&bytes, ProxySide::ProxyToRelay, DraftVersion::Draft11),
            Some(DraftVersion::Draft14)
        );
    }

    #[test]
    fn varint_len_follows_two_bit_prefix() {
        assert_eq!(varint_len(0x00), 1);
        assert_eq!(varint_len(0x3f), 1);
        assert_eq!(varint_len(0x40), 2);
        assert_eq!(varint_len(0x80), 4);
        assert_eq!(varint_len(0xc0), 8);
        assert_eq!(varint_len(0xff), 8);
    }

    #[test]
    fn version_mapping_requires_draft_prefix() {
        assert_eq!(version_varint_to_draft(0xff00_000e), Some(DraftVersion::Draft14));
        assert_eq!(version_varint_to_draft(0xfeff_ffff), None);
        assert_eq!(version_varint_to_draft(0xff00_0100), None);
        assert_eq!(version_varint_to_draft(14), None);
    }
}