1use std::sync::atomic::{AtomicU64, Ordering};
4use std::sync::Arc;
5
6use tokio_util::sync::CancellationToken;
7
8use crate::error::ProxyError;
9use crate::event::{ProxyEvent, SessionId};
10use crate::hook::{NoOpHook, ProxyHook};
11use crate::listener::{Listener, ListenerConfig};
12use crate::observer::ProxyObserver;
13use crate::session::{ProxySession, ProxySessionConfig};
14
/// Selects which kind of client-facing listener [`TransparentProxy::run`] starts.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ListenerMode {
    /// Accept raw QUIC connections through [`Listener`]; the ALPN returned by
    /// `accept` is forwarded to each spawned session.
    Quic,
    /// Accept WebTransport connections through `crate::listener::WtListener`.
    ///
    /// Requires the `webtransport` Cargo feature. Without it,
    /// [`TransparentProxy::run`] returns a `ProxyError::Listener` error for this mode.
    WebTransport,
}
23
/// Configuration consumed by [`TransparentProxy`].
pub struct ProxyConfig {
    /// Client-facing listener settings.
    ///
    /// `bind_addr`, `cert_chain` and `key_der` are used for both listener modes;
    /// `alpn` is only passed to the listener in [`ListenerMode::Quic`].
    pub listener: ListenerConfig,
    /// Template for per-session settings; its fields are copied into a fresh
    /// [`ProxySessionConfig`] for every accepted connection.
    pub session: ProxySessionConfig,
    /// Which listener kind [`TransparentProxy::run`] starts.
    pub listener_mode: ListenerMode,
}
33
/// Proxy front-end: accepts client connections and runs each one in its own
/// spawned [`ProxySession`] task.
pub struct TransparentProxy {
    /// Listener and session settings supplied at construction.
    config: ProxyConfig,
    /// Event sink. Receives a `SessionStarted` event per accepted connection
    /// (when `wants_events()` is true) and is also shared with every session.
    observer: Arc<dyn ProxyObserver>,
    /// Hook shared with every session; `new` installs [`NoOpHook`].
    hook: Arc<dyn ProxyHook>,
    /// Root cancellation token. Cancelling it stops the accept loop; each
    /// session is given a child token derived from it.
    cancel: CancellationToken,
    /// Counter used to allocate session ids; the first id handed out is 1.
    next_session_id: AtomicU64,
}
46
47impl TransparentProxy {
48 pub fn new(config: ProxyConfig, observer: Arc<dyn ProxyObserver>) -> Self {
50 Self {
51 config,
52 observer,
53 hook: Arc::new(NoOpHook),
54 cancel: CancellationToken::new(),
55 next_session_id: AtomicU64::new(1),
56 }
57 }
58
59 pub fn with_hook(
61 config: ProxyConfig,
62 observer: Arc<dyn ProxyObserver>,
63 hook: Arc<dyn ProxyHook>,
64 ) -> Self {
65 Self {
66 config,
67 observer,
68 hook,
69 cancel: CancellationToken::new(),
70 next_session_id: AtomicU64::new(1),
71 }
72 }
73
74 pub fn cancel_token(&self) -> CancellationToken {
76 self.cancel.clone()
77 }
78
79 pub async fn run(&self) -> Result<(), ProxyError> {
82 match self.config.listener_mode {
83 ListenerMode::Quic => self.run_quic().await,
84 #[cfg(feature = "webtransport")]
85 ListenerMode::WebTransport => self.run_webtransport().await,
86 #[cfg(not(feature = "webtransport"))]
87 ListenerMode::WebTransport => {
88 Err(ProxyError::Listener("webtransport feature not enabled".to_string()))
89 }
90 }
91 }
92
93 async fn run_quic(&self) -> Result<(), ProxyError> {
95 let listener = Listener::bind(ListenerConfig {
96 bind_addr: self.config.listener.bind_addr,
97 cert_chain: self.config.listener.cert_chain.clone(),
98 key_der: self.config.listener.key_der.clone_key(),
99 alpn: self.config.listener.alpn.clone(),
100 })?;
101
102 loop {
103 tokio::select! {
104 result = listener.accept() => {
105 let (conn, alpn) = result?;
106 let session_id = self.next_session_id();
107 let client_addr = conn.remote_address();
108 self.emit_session_started(session_id, client_addr);
109
110 let session = self.new_session(session_id, alpn);
111 tokio::spawn(async move {
112 let _ = session.run(conn).await;
113 });
114 }
115 _ = self.cancel.cancelled() => {
116 listener.close();
117 return Ok(());
118 }
119 }
120 }
121 }
122
123 #[cfg(feature = "webtransport")]
125 async fn run_webtransport(&self) -> Result<(), ProxyError> {
126 use crate::listener::{WtListener, WtListenerConfig};
127
128 let listener = WtListener::bind(WtListenerConfig {
129 bind_addr: self.config.listener.bind_addr,
130 cert_chain: self.config.listener.cert_chain.clone(),
131 key_der: self.config.listener.key_der.clone_key(),
132 })?;
133
134 loop {
135 tokio::select! {
136 result = listener.accept() => {
137 let conn = result?;
138 let session_id = self.next_session_id();
139 let client_addr = conn.remote_address();
140 self.emit_session_started(session_id, client_addr);
141
142 let session = self.new_session(session_id, Vec::new());
145 tokio::spawn(async move {
146 let _ = session.run_webtransport(conn).await;
147 });
148 }
149 _ = self.cancel.cancelled() => {
150 listener.close();
151 return Ok(());
152 }
153 }
154 }
155 }
156
157 fn next_session_id(&self) -> SessionId {
160 SessionId(self.next_session_id.fetch_add(1, Ordering::Relaxed))
161 }
162
163 fn emit_session_started(&self, session_id: SessionId, client_addr: std::net::SocketAddr) {
164 if self.observer.wants_events() {
165 self.observer.on_event(&ProxyEvent::SessionStarted { session_id, client_addr });
166 }
167 }
168
169 fn new_session(&self, session_id: SessionId, client_alpn: Vec<u8>) -> ProxySession {
170 ProxySession::new(
171 session_id,
172 ProxySessionConfig {
173 draft: self.config.session.draft,
174 upstream_transport: self.config.session.upstream_transport.clone(),
175 upstream_addr: self.config.session.upstream_addr.clone(),
176 skip_upstream_cert_verify: self.config.session.skip_upstream_cert_verify,
177 upstream_ca_certs: self.config.session.upstream_ca_certs.clone(),
178 upstream_connect_timeout_secs: self.config.session.upstream_connect_timeout_secs,
179 },
180 client_alpn,
181 Arc::clone(&self.observer),
182 Arc::clone(&self.hook),
183 self.cancel.child_token(),
184 )
185 }
186}