Skip to main content

moqtap_client/
endpoint.rs

1use std::collections::HashMap;
2
3use moqtap_codec::message::{
4    self, ClientSetup, ControlMessage, Fetch, FetchCancel, GoAway, MaxRequestIdMsg, PublishDone,
5    PublishNamespaceCancel, PublishNamespaceDone, PublishNamespaceError, PublishNamespaceMsg,
6    PublishNamespaceOk, ServerSetup, Subscribe, SubscribeError, SubscribeNamespace,
7    SubscribeNamespaceError, SubscribeNamespaceOk, SubscribeOk, Unsubscribe, UnsubscribeNamespace,
8};
9use moqtap_codec::types::*;
10use moqtap_codec::varint::VarInt;
11use moqtap_session::request_id::{RequestIdAllocator, RequestIdError, Role};
12use moqtap_session::setup::{self, SetupError};
13use moqtap_session::state::{SessionError, SessionState, SessionStateMachine};
14
15use crate::fetch::{FetchError, FetchStateMachine};
16use crate::namespace::{
17    NamespaceError, PublishNamespaceStateMachine, SubscribeNamespaceStateMachine,
18};
19use crate::subscription::{SubscriptionError, SubscriptionStateMachine};
20
21/// Errors that can occur during endpoint operations.
22#[derive(Debug, thiserror::Error)]
23pub enum EndpointError {
24    #[error("session error: {0}")]
25    Session(#[from] SessionError),
26    #[error("request ID error: {0}")]
27    RequestId(#[from] RequestIdError),
28    #[error("subscription error: {0}")]
29    Subscription(#[from] SubscriptionError),
30    #[error("fetch error: {0}")]
31    Fetch(#[from] FetchError),
32    #[error("namespace error: {0}")]
33    Namespace(#[from] NamespaceError),
34    #[error("setup error: {0}")]
35    Setup(#[from] SetupError),
36    #[error("unknown request ID: {0}")]
37    UnknownRequest(u64),
38    #[error("session not active")]
39    NotActive,
40    #[error("session is draining, no new requests allowed")]
41    Draining,
42}
43
44/// Unified MoQT endpoint wrapping session lifecycle, request ID allocation,
45/// and all per-request state machines (subscriptions, fetches, namespaces).
46pub struct Endpoint {
47    role: Role,
48    session: SessionStateMachine,
49    request_ids: RequestIdAllocator,
50    subscriptions: HashMap<u64, SubscriptionStateMachine>,
51    fetches: HashMap<u64, FetchStateMachine>,
52    subscribe_namespaces: HashMap<u64, SubscribeNamespaceStateMachine>,
53    publish_namespaces: HashMap<u64, PublishNamespaceStateMachine>,
54    negotiated_version: Option<VarInt>,
55    offered_versions: Vec<VarInt>,
56    goaway_uri: Option<Vec<u8>>,
57}
58
59impl Endpoint {
60    /// Create a new endpoint with the given role.
61    pub fn new(role: Role) -> Self {
62        Self {
63            role,
64            session: SessionStateMachine::new(),
65            request_ids: RequestIdAllocator::new(role),
66            subscriptions: HashMap::new(),
67            fetches: HashMap::new(),
68            subscribe_namespaces: HashMap::new(),
69            publish_namespaces: HashMap::new(),
70            negotiated_version: None,
71            offered_versions: Vec::new(),
72            goaway_uri: None,
73        }
74    }
75
76    // ── Accessors ──────────────────────────────────────────────
77
78    pub fn role(&self) -> Role {
79        self.role
80    }
81
82    pub fn session_state(&self) -> SessionState {
83        self.session.state()
84    }
85
86    pub fn negotiated_version(&self) -> Option<VarInt> {
87        self.negotiated_version
88    }
89
90    pub fn goaway_uri(&self) -> Option<&[u8]> {
91        self.goaway_uri.as_deref()
92    }
93
94    pub fn is_blocked(&self) -> bool {
95        self.request_ids.is_blocked()
96    }
97
98    pub fn active_subscription_count(&self) -> usize {
99        self.subscriptions.len()
100    }
101
102    pub fn active_fetch_count(&self) -> usize {
103        self.fetches.len()
104    }
105
106    pub fn active_subscribe_namespace_count(&self) -> usize {
107        self.subscribe_namespaces.len()
108    }
109
110    pub fn active_publish_namespace_count(&self) -> usize {
111        self.publish_namespaces.len()
112    }
113
114    // ── Session lifecycle ──────────────────────────────────────
115
116    /// Transition from Connecting to SetupExchange.
117    pub fn connect(&mut self) -> Result<(), EndpointError> {
118        self.session.on_connect()?;
119        Ok(())
120    }
121
122    /// Close the session (Active or Draining -> Closed).
123    pub fn close(&mut self) -> Result<(), EndpointError> {
124        self.session.on_close()?;
125        Ok(())
126    }
127
128    // ── Client setup ───────────────────────────────────────────
129
130    /// Generate a CLIENT_SETUP message (client-side).
131    pub fn send_client_setup(
132        &mut self,
133        versions: Vec<VarInt>,
134    ) -> Result<ControlMessage, EndpointError> {
135        self.offered_versions = versions.clone();
136        let msg = ClientSetup { supported_versions: versions, parameters: vec![] };
137        setup::validate_client_setup(&msg)?;
138        Ok(ControlMessage::ClientSetup(msg))
139    }
140
141    /// Process a SERVER_SETUP message (client-side). Transitions to Active.
142    pub fn receive_server_setup(&mut self, msg: &ServerSetup) -> Result<(), EndpointError> {
143        setup::validate_server_setup(msg)?;
144        let version = setup::negotiate_version(&self.offered_versions, msg.selected_version)?;
145        self.negotiated_version = Some(version);
146        self.session.on_setup_complete()?;
147        Ok(())
148    }
149
150    // ── Server setup ───────────────────────────────────────────
151
152    /// Process CLIENT_SETUP and generate SERVER_SETUP (server-side).
153    pub fn receive_client_setup_and_respond(
154        &mut self,
155        client_setup: &ClientSetup,
156        selected_version: VarInt,
157    ) -> Result<ControlMessage, EndpointError> {
158        setup::validate_client_setup(client_setup)?;
159        let version = setup::negotiate_version(&client_setup.supported_versions, selected_version)?;
160        self.negotiated_version = Some(version);
161        self.session.on_setup_complete()?;
162        let msg = ServerSetup { selected_version: version, parameters: vec![] };
163        Ok(ControlMessage::ServerSetup(msg))
164    }
165
166    // ── MAX_REQUEST_ID ─────────────────────────────────────────
167
168    /// Process an incoming MAX_REQUEST_ID message.
169    pub fn receive_max_request_id(&mut self, msg: &MaxRequestIdMsg) -> Result<(), EndpointError> {
170        self.request_ids.update_max(msg.request_id.into_inner())?;
171        Ok(())
172    }
173
174    /// Generate a MAX_REQUEST_ID message (typically server-side).
175    pub fn send_max_request_id(&mut self, max_id: VarInt) -> Result<ControlMessage, EndpointError> {
176        Ok(ControlMessage::MaxRequestId(MaxRequestIdMsg { request_id: max_id }))
177    }
178
179    // ── GoAway ─────────────────────────────────────────────────
180
181    /// Process an incoming GOAWAY message. Transitions to Draining.
182    pub fn receive_goaway(&mut self, msg: &GoAway) -> Result<(), EndpointError> {
183        self.session.on_goaway()?;
184        self.goaway_uri = Some(msg.new_session_uri.clone());
185        Ok(())
186    }
187
188    // ── Subscribe flow ─────────────────────────────────────────
189
190    fn require_active_or_err(&self) -> Result<(), EndpointError> {
191        match self.session.state() {
192            SessionState::Active => Ok(()),
193            SessionState::Draining => Err(EndpointError::Draining),
194            _ => Err(EndpointError::NotActive),
195        }
196    }
197
198    /// Send a SUBSCRIBE message. Allocates a request ID and creates a
199    /// subscription state machine.
200    pub fn subscribe(
201        &mut self,
202        track_namespace: TrackNamespace,
203        track_name: Vec<u8>,
204        subscriber_priority: u8,
205        group_order: GroupOrder,
206        filter_type: FilterType,
207    ) -> Result<(VarInt, ControlMessage), EndpointError> {
208        self.require_active_or_err()?;
209        let req_id = self.request_ids.allocate()?;
210
211        let mut sm = SubscriptionStateMachine::new();
212        sm.on_subscribe_sent()?;
213        self.subscriptions.insert(req_id.into_inner(), sm);
214
215        let msg = ControlMessage::Subscribe(Subscribe {
216            request_id: req_id,
217            track_namespace,
218            track_name,
219            subscriber_priority,
220            group_order,
221            forward: Forward::Forward,
222            filter_type,
223            start_location: None,
224            end_group: None,
225            parameters: vec![],
226        });
227        Ok((req_id, msg))
228    }
229
230    /// Process an incoming SUBSCRIBE_OK.
231    pub fn receive_subscribe_ok(&mut self, msg: &SubscribeOk) -> Result<(), EndpointError> {
232        let id = msg.request_id.into_inner();
233        let sm = self.subscriptions.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
234        sm.on_subscribe_ok()?;
235        Ok(())
236    }
237
238    /// Process an incoming SUBSCRIBE_ERROR.
239    pub fn receive_subscribe_error(&mut self, msg: &SubscribeError) -> Result<(), EndpointError> {
240        let id = msg.request_id.into_inner();
241        let sm = self.subscriptions.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
242        sm.on_subscribe_error()?;
243        Ok(())
244    }
245
246    /// Send an UNSUBSCRIBE message for an active subscription.
247    pub fn unsubscribe(&mut self, request_id: VarInt) -> Result<ControlMessage, EndpointError> {
248        let id = request_id.into_inner();
249        let sm = self.subscriptions.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
250        sm.on_unsubscribe()?;
251        Ok(ControlMessage::Unsubscribe(Unsubscribe { request_id }))
252    }
253
254    /// Process an incoming PUBLISH_DONE.
255    pub fn receive_publish_done(&mut self, msg: &PublishDone) -> Result<(), EndpointError> {
256        let id = msg.request_id.into_inner();
257        let sm = self.subscriptions.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
258        sm.on_publish_done()?;
259        Ok(())
260    }
261
262    // ── Fetch flow ─────────────────────────────────────────────
263
264    /// Send a FETCH message. Allocates a request ID and creates a fetch state machine.
265    pub fn fetch(
266        &mut self,
267        track_namespace: TrackNamespace,
268        track_name: Vec<u8>,
269        start_group: VarInt,
270        start_object: VarInt,
271    ) -> Result<(VarInt, ControlMessage), EndpointError> {
272        self.require_active_or_err()?;
273        let req_id = self.request_ids.allocate()?;
274
275        let mut sm = FetchStateMachine::new();
276        sm.on_fetch_sent()?;
277        self.fetches.insert(req_id.into_inner(), sm);
278
279        let msg = ControlMessage::Fetch(Fetch {
280            request_id: req_id,
281            track_namespace,
282            track_name,
283            start_group,
284            start_object,
285            end_group: None,
286            priority: None,
287            parameters: vec![],
288        });
289        Ok((req_id, msg))
290    }
291
292    /// Process an incoming FETCH_OK.
293    pub fn receive_fetch_ok(&mut self, msg: &message::FetchOk) -> Result<(), EndpointError> {
294        let id = msg.request_id.into_inner();
295        let sm = self.fetches.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
296        sm.on_fetch_ok()?;
297        Ok(())
298    }
299
300    /// Process an incoming FETCH_ERROR.
301    pub fn receive_fetch_error(&mut self, msg: &message::FetchError) -> Result<(), EndpointError> {
302        let id = msg.request_id.into_inner();
303        let sm = self.fetches.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
304        sm.on_fetch_error()?;
305        Ok(())
306    }
307
308    /// Send a FETCH_CANCEL message.
309    pub fn fetch_cancel(&mut self, request_id: VarInt) -> Result<ControlMessage, EndpointError> {
310        let id = request_id.into_inner();
311        let sm = self.fetches.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
312        sm.on_fetch_cancel()?;
313        Ok(ControlMessage::FetchCancel(FetchCancel { request_id }))
314    }
315
316    /// Notify that a fetch data stream received FIN.
317    pub fn on_fetch_stream_fin(&mut self, request_id: VarInt) -> Result<(), EndpointError> {
318        let id = request_id.into_inner();
319        let sm = self.fetches.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
320        sm.on_stream_fin()?;
321        Ok(())
322    }
323
324    /// Notify that a fetch data stream was reset.
325    pub fn on_fetch_stream_reset(&mut self, request_id: VarInt) -> Result<(), EndpointError> {
326        let id = request_id.into_inner();
327        let sm = self.fetches.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
328        sm.on_stream_reset()?;
329        Ok(())
330    }
331
332    // ── Subscribe Namespace flow ───────────────────────────────
333
334    /// Send a SUBSCRIBE_NAMESPACE message.
335    pub fn subscribe_namespace(
336        &mut self,
337        track_namespace: TrackNamespace,
338    ) -> Result<(VarInt, ControlMessage), EndpointError> {
339        self.require_active_or_err()?;
340        let req_id = self.request_ids.allocate()?;
341
342        let mut sm = SubscribeNamespaceStateMachine::new();
343        sm.on_subscribe_namespace_sent()?;
344        self.subscribe_namespaces.insert(req_id.into_inner(), sm);
345
346        let msg = ControlMessage::SubscribeNamespace(SubscribeNamespace {
347            request_id: req_id,
348            track_namespace,
349            parameters: vec![],
350        });
351        Ok((req_id, msg))
352    }
353
354    /// Process an incoming SUBSCRIBE_NAMESPACE_OK.
355    pub fn receive_subscribe_namespace_ok(
356        &mut self,
357        msg: &SubscribeNamespaceOk,
358    ) -> Result<(), EndpointError> {
359        let id = msg.request_id.into_inner();
360        let sm = self.subscribe_namespaces.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
361        sm.on_subscribe_namespace_ok()?;
362        Ok(())
363    }
364
365    /// Process an incoming SUBSCRIBE_NAMESPACE_ERROR.
366    pub fn receive_subscribe_namespace_error(
367        &mut self,
368        msg: &SubscribeNamespaceError,
369    ) -> Result<(), EndpointError> {
370        let id = msg.request_id.into_inner();
371        let sm = self.subscribe_namespaces.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
372        sm.on_subscribe_namespace_error()?;
373        Ok(())
374    }
375
376    /// Send an UNSUBSCRIBE_NAMESPACE message.
377    pub fn unsubscribe_namespace(
378        &mut self,
379        request_id: VarInt,
380        track_namespace: TrackNamespace,
381    ) -> Result<ControlMessage, EndpointError> {
382        let id = request_id.into_inner();
383        let sm = self.subscribe_namespaces.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
384        sm.on_unsubscribe_namespace()?;
385        Ok(ControlMessage::UnsubscribeNamespace(UnsubscribeNamespace {
386            request_id,
387            track_namespace,
388        }))
389    }
390
391    // ── Publish Namespace flow ─────────────────────────────────
392
393    /// Send a PUBLISH_NAMESPACE message.
394    pub fn publish_namespace(
395        &mut self,
396        track_namespace: TrackNamespace,
397    ) -> Result<(VarInt, ControlMessage), EndpointError> {
398        self.require_active_or_err()?;
399        let req_id = self.request_ids.allocate()?;
400
401        let mut sm = PublishNamespaceStateMachine::new();
402        sm.on_publish_namespace_sent()?;
403        self.publish_namespaces.insert(req_id.into_inner(), sm);
404
405        let msg = ControlMessage::PublishNamespace(PublishNamespaceMsg {
406            request_id: req_id,
407            track_namespace,
408            parameters: vec![],
409        });
410        Ok((req_id, msg))
411    }
412
413    /// Process an incoming PUBLISH_NAMESPACE_OK.
414    pub fn receive_publish_namespace_ok(
415        &mut self,
416        msg: &PublishNamespaceOk,
417    ) -> Result<(), EndpointError> {
418        let id = msg.request_id.into_inner();
419        let sm = self.publish_namespaces.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
420        sm.on_publish_namespace_ok()?;
421        Ok(())
422    }
423
424    /// Process an incoming PUBLISH_NAMESPACE_ERROR.
425    pub fn receive_publish_namespace_error(
426        &mut self,
427        msg: &PublishNamespaceError,
428    ) -> Result<(), EndpointError> {
429        let id = msg.request_id.into_inner();
430        let sm = self.publish_namespaces.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
431        sm.on_publish_namespace_error()?;
432        Ok(())
433    }
434
435    /// Process an incoming PUBLISH_NAMESPACE_DONE.
436    pub fn receive_publish_namespace_done(
437        &mut self,
438        msg: &PublishNamespaceDone,
439    ) -> Result<(), EndpointError> {
440        let id = msg.request_id.into_inner();
441        let sm = self.publish_namespaces.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
442        sm.on_publish_namespace_done()?;
443        Ok(())
444    }
445
446    /// Send a PUBLISH_NAMESPACE_CANCEL message.
447    pub fn publish_namespace_cancel(
448        &mut self,
449        request_id: VarInt,
450        reason_phrase: Vec<u8>,
451    ) -> Result<ControlMessage, EndpointError> {
452        let id = request_id.into_inner();
453        let sm = self.publish_namespaces.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
454        sm.on_publish_namespace_cancel()?;
455        Ok(ControlMessage::PublishNamespaceCancel(PublishNamespaceCancel {
456            request_id,
457            reason_phrase,
458        }))
459    }
460
461    // ── Unified message dispatch ───────────────────────────────
462
463    /// Dispatch an incoming control message to the appropriate handler.
464    pub fn receive_message(&mut self, msg: ControlMessage) -> Result<(), EndpointError> {
465        match msg {
466            ControlMessage::GoAway(ref m) => self.receive_goaway(m),
467            ControlMessage::MaxRequestId(ref m) => self.receive_max_request_id(m),
468            ControlMessage::SubscribeOk(ref m) => self.receive_subscribe_ok(m),
469            ControlMessage::SubscribeError(ref m) => self.receive_subscribe_error(m),
470            ControlMessage::PublishDone(ref m) => self.receive_publish_done(m),
471            ControlMessage::FetchOk(ref m) => self.receive_fetch_ok(m),
472            ControlMessage::FetchError(ref m) => self.receive_fetch_error(m),
473            ControlMessage::SubscribeNamespaceOk(ref m) => self.receive_subscribe_namespace_ok(m),
474            ControlMessage::SubscribeNamespaceError(ref m) => {
475                self.receive_subscribe_namespace_error(m)
476            }
477            ControlMessage::PublishNamespaceOk(ref m) => self.receive_publish_namespace_ok(m),
478            ControlMessage::PublishNamespaceError(ref m) => self.receive_publish_namespace_error(m),
479            ControlMessage::PublishNamespaceDone(ref m) => self.receive_publish_namespace_done(m),
480            _ => Ok(()),
481        }
482    }
483}