Skip to main content

moqtap_client/draft07/
endpoint.rs

1use std::collections::HashMap;
2
3use crate::draft07::fetch::{FetchError, FetchStateMachine};
4use crate::draft07::namespace::{
5    AnnounceStateMachine, NamespaceError, SubscribeAnnouncesStateMachine,
6};
7use crate::draft07::session::setup::{self, SetupError};
8use crate::draft07::session::state::{SessionError, SessionState, SessionStateMachine};
9use crate::draft07::session::subscribe_id::{SubscribeIdAllocator, SubscribeIdError};
10use crate::draft07::subscription::{SubscriptionError, SubscriptionStateMachine};
11use crate::draft07::track_status::{TrackStatusError, TrackStatusStateMachine};
12use moqtap_codec::draft07::message::{
13    self, Announce, AnnounceCancel, AnnounceError, AnnounceOk, ClientSetup, ControlMessage, Fetch,
14    FetchCancel, GoAway, MaxSubscribeId, ServerSetup, Subscribe, SubscribeAnnounces,
15    SubscribeAnnouncesError, SubscribeAnnouncesOk, SubscribeDone, SubscribeError, SubscribeOk,
16    SubscribeUpdate, TrackStatus, TrackStatusRequest, Unannounce, Unsubscribe,
17    UnsubscribeAnnounces,
18};
19use moqtap_codec::kvp::{KeyValuePair, KvpValue};
20use moqtap_codec::types::*;
21use moqtap_codec::varint::VarInt;
22
/// Map key for namespace-scoped flows (ANNOUNCE and SUBSCRIBE_ANNOUNCES):
/// the namespace's list of byte-string elements.
type NamespaceKey = Vec<Vec<u8>>;

/// Map key for track-scoped flows (TRACK_STATUS_REQUEST): a namespace key
/// paired with the track name bytes.
type TrackKey = (NamespaceKey, Vec<u8>);
28
29/// Errors that can occur during draft-07 endpoint operations.
30#[derive(Debug, thiserror::Error)]
31pub enum EndpointError {
32    /// A session-level state machine error.
33    #[error("session error: {0}")]
34    Session(#[from] SessionError),
35    /// A subscribe ID allocation or validation error.
36    #[error("subscribe ID error: {0}")]
37    SubscribeId(#[from] SubscribeIdError),
38    /// A subscription state machine error.
39    #[error("subscription error: {0}")]
40    Subscription(#[from] SubscriptionError),
41    /// A fetch state machine error.
42    #[error("fetch error: {0}")]
43    Fetch(#[from] FetchError),
44    /// A namespace state machine error.
45    #[error("namespace error: {0}")]
46    Namespace(#[from] NamespaceError),
47    /// A track status state machine error.
48    #[error("track status error: {0}")]
49    TrackStatus(#[from] TrackStatusError),
50    /// A setup negotiation error.
51    #[error("setup error: {0}")]
52    Setup(#[from] SetupError),
53    /// The subscribe ID does not match any known state machine.
54    #[error("unknown subscribe ID: {0}")]
55    UnknownSubscribe(u64),
56    /// The track namespace does not match any known state machine.
57    #[error("unknown namespace")]
58    UnknownNamespace,
59    /// The (namespace, track) pair does not match any known track status request.
60    #[error("unknown track status request")]
61    UnknownTrackStatus,
62    /// The session is not in the Active state.
63    #[error("session not active")]
64    NotActive,
65    /// The session is draining and cannot accept new requests.
66    #[error("session is draining, no new requests allowed")]
67    Draining,
68}
69
70/// Unified draft-07 MoQT endpoint wrapping session lifecycle, subscribe ID
71/// allocation, and all per-flow state machines (subscriptions, fetches,
72/// announces, subscribe-announces, track statuses).
73pub struct Endpoint {
74    session: SessionStateMachine,
75    subscribe_ids: SubscribeIdAllocator,
76    /// Tracks the MAX_SUBSCRIBE_ID we have advertised to the peer.
77    advertised_max_id: u64,
78    subscriptions: HashMap<u64, SubscriptionStateMachine>,
79    fetches: HashMap<u64, FetchStateMachine>,
80    subscribe_announces: HashMap<NamespaceKey, SubscribeAnnouncesStateMachine>,
81    announces: HashMap<NamespaceKey, AnnounceStateMachine>,
82    track_statuses: HashMap<TrackKey, TrackStatusStateMachine>,
83    negotiated_version: Option<VarInt>,
84    offered_versions: Vec<VarInt>,
85    goaway_uri: Option<Vec<u8>>,
86}
87
88impl Default for Endpoint {
89    fn default() -> Self {
90        Self::new()
91    }
92}
93
94impl Endpoint {
95    /// Create a new draft-07 endpoint.
96    pub fn new() -> Self {
97        Self {
98            session: SessionStateMachine::new(),
99            subscribe_ids: SubscribeIdAllocator::new(),
100            advertised_max_id: 0,
101            subscriptions: HashMap::new(),
102            fetches: HashMap::new(),
103            subscribe_announces: HashMap::new(),
104            announces: HashMap::new(),
105            track_statuses: HashMap::new(),
106            negotiated_version: None,
107            offered_versions: Vec::new(),
108            goaway_uri: None,
109        }
110    }
111
112    // ── Accessors ──────────────────────────────────────────────
113
114    /// Returns the current session state.
115    pub fn session_state(&self) -> SessionState {
116        self.session.state()
117    }
118
119    /// Returns the negotiated MoQT version, if setup is complete.
120    pub fn negotiated_version(&self) -> Option<VarInt> {
121        self.negotiated_version
122    }
123
124    /// Returns the URI from a received GOAWAY message, if any.
125    pub fn goaway_uri(&self) -> Option<&[u8]> {
126        self.goaway_uri.as_deref()
127    }
128
129    /// Returns whether this endpoint is blocked on subscribe ID allocation.
130    pub fn is_blocked(&self) -> bool {
131        self.subscribe_ids.is_blocked()
132    }
133
134    /// Returns the number of active subscription state machines.
135    pub fn active_subscription_count(&self) -> usize {
136        self.subscriptions.len()
137    }
138
139    /// Returns the number of active fetch state machines.
140    pub fn active_fetch_count(&self) -> usize {
141        self.fetches.len()
142    }
143
144    /// Returns the number of active subscribe-announces state machines.
145    pub fn active_subscribe_announces_count(&self) -> usize {
146        self.subscribe_announces.len()
147    }
148
149    /// Returns the number of active announce state machines.
150    pub fn active_announce_count(&self) -> usize {
151        self.announces.len()
152    }
153
154    /// Returns the number of active track status state machines.
155    pub fn active_track_status_count(&self) -> usize {
156        self.track_statuses.len()
157    }
158
159    // ── Session lifecycle ──────────────────────────────────────
160
161    /// Transition from Connecting to SetupExchange.
162    pub fn connect(&mut self) -> Result<(), EndpointError> {
163        self.session.on_connect()?;
164        Ok(())
165    }
166
167    /// Close the session (Active or Draining -> Closed).
168    pub fn close(&mut self) -> Result<(), EndpointError> {
169        self.session.on_close()?;
170        Ok(())
171    }
172
173    // ── Client setup ───────────────────────────────────────────
174
175    /// Generate a CLIENT_SETUP message (client-side).
176    pub fn send_client_setup(
177        &mut self,
178        versions: Vec<VarInt>,
179        parameters: Vec<KeyValuePair>,
180    ) -> Result<ControlMessage, EndpointError> {
181        self.offered_versions = versions.clone();
182        let msg = ClientSetup { supported_versions: versions, parameters };
183        setup::validate_client_setup(&msg)?;
184        Ok(ControlMessage::ClientSetup(msg))
185    }
186
187    /// Process a SERVER_SETUP message (client-side). Transitions to Active.
188    /// If the server includes a MAX_SUBSCRIBE_ID parameter (key 0x02), the
189    /// subscribe ID allocator is initialized with that value.
190    pub fn receive_server_setup(&mut self, msg: &ServerSetup) -> Result<(), EndpointError> {
191        setup::validate_server_setup(msg)?;
192        let version = setup::negotiate_version(&self.offered_versions, msg.selected_version)?;
193        self.negotiated_version = Some(version);
194        self.session.on_setup_complete()?;
195        // Extract MAX_SUBSCRIBE_ID (key 0x02) from setup parameters if present
196        for param in &msg.parameters {
197            if param.key == VarInt::from_u64(0x02).unwrap() {
198                if let KvpValue::Varint(v) = &param.value {
199                    self.subscribe_ids.update_max(v.into_inner())?;
200                }
201            }
202        }
203        Ok(())
204    }
205
206    // ── Server setup ───────────────────────────────────────────
207
208    /// Process CLIENT_SETUP and generate SERVER_SETUP (server-side).
209    pub fn receive_client_setup_and_respond(
210        &mut self,
211        client_setup: &ClientSetup,
212        selected_version: VarInt,
213    ) -> Result<ControlMessage, EndpointError> {
214        setup::validate_client_setup(client_setup)?;
215        let version = setup::negotiate_version(&client_setup.supported_versions, selected_version)?;
216        self.negotiated_version = Some(version);
217        self.session.on_setup_complete()?;
218        let msg = ServerSetup { selected_version: version, parameters: vec![] };
219        Ok(ControlMessage::ServerSetup(msg))
220    }
221
222    // ── MAX_SUBSCRIBE_ID ───────────────────────────────────────
223
224    /// Process an incoming MAX_SUBSCRIBE_ID message.
225    pub fn receive_max_subscribe_id(&mut self, msg: &MaxSubscribeId) -> Result<(), EndpointError> {
226        self.subscribe_ids.update_max(msg.subscribe_id.into_inner())?;
227        Ok(())
228    }
229
230    /// Generate a MAX_SUBSCRIBE_ID message (typically server-side).
231    /// The value must strictly increase over previous sends.
232    pub fn send_max_subscribe_id(
233        &mut self,
234        max_id: VarInt,
235    ) -> Result<ControlMessage, EndpointError> {
236        let new_val = max_id.into_inner();
237        if new_val <= self.advertised_max_id && self.advertised_max_id > 0 {
238            return Err(EndpointError::SubscribeId(SubscribeIdError::Decreased(
239                self.advertised_max_id,
240                new_val,
241            )));
242        }
243        self.advertised_max_id = new_val;
244        Ok(ControlMessage::MaxSubscribeId(MaxSubscribeId { subscribe_id: max_id }))
245    }
246
247    // ── GoAway ─────────────────────────────────────────────────
248
249    /// Process an incoming GOAWAY message. Transitions to Draining.
250    pub fn receive_goaway(&mut self, msg: &GoAway) -> Result<(), EndpointError> {
251        self.session.on_goaway()?;
252        self.goaway_uri = Some(msg.new_session_uri.clone());
253        Ok(())
254    }
255
256    // ── Subscribe flow ─────────────────────────────────────────
257
258    fn require_active_or_err(&self) -> Result<(), EndpointError> {
259        match self.session.state() {
260            SessionState::Active => Ok(()),
261            SessionState::Draining => Err(EndpointError::Draining),
262            _ => Err(EndpointError::NotActive),
263        }
264    }
265
266    /// Send a SUBSCRIBE message. Allocates a subscribe ID and creates a
267    /// subscription state machine.
268    #[allow(clippy::too_many_arguments)]
269    pub fn subscribe(
270        &mut self,
271        track_alias: VarInt,
272        track_namespace: TrackNamespace,
273        track_name: Vec<u8>,
274        subscriber_priority: u8,
275        group_order: GroupOrder,
276        filter_type: FilterType,
277    ) -> Result<(VarInt, ControlMessage), EndpointError> {
278        self.require_active_or_err()?;
279        let sub_id = self.subscribe_ids.allocate()?;
280
281        let mut sm = SubscriptionStateMachine::new();
282        sm.on_subscribe_sent()?;
283        self.subscriptions.insert(sub_id.into_inner(), sm);
284
285        let msg = ControlMessage::Subscribe(Subscribe {
286            subscribe_id: sub_id,
287            track_alias,
288            track_namespace,
289            track_name,
290            subscriber_priority,
291            group_order,
292            filter_type,
293            start_location: None,
294            end_group: None,
295            end_object: None,
296            parameters: vec![],
297        });
298        Ok((sub_id, msg))
299    }
300
301    /// Process an incoming SUBSCRIBE_OK.
302    pub fn receive_subscribe_ok(&mut self, msg: &SubscribeOk) -> Result<(), EndpointError> {
303        let id = msg.subscribe_id.into_inner();
304        let sm = self.subscriptions.get_mut(&id).ok_or(EndpointError::UnknownSubscribe(id))?;
305        sm.on_subscribe_ok()?;
306        Ok(())
307    }
308
309    /// Process an incoming SUBSCRIBE_ERROR.
310    pub fn receive_subscribe_error(&mut self, msg: &SubscribeError) -> Result<(), EndpointError> {
311        let id = msg.subscribe_id.into_inner();
312        let sm = self.subscriptions.get_mut(&id).ok_or(EndpointError::UnknownSubscribe(id))?;
313        sm.on_subscribe_error()?;
314        Ok(())
315    }
316
317    /// Send an UNSUBSCRIBE message for an active subscription.
318    pub fn unsubscribe(&mut self, subscribe_id: VarInt) -> Result<ControlMessage, EndpointError> {
319        let id = subscribe_id.into_inner();
320        let sm = self.subscriptions.get_mut(&id).ok_or(EndpointError::UnknownSubscribe(id))?;
321        sm.on_unsubscribe()?;
322        Ok(ControlMessage::Unsubscribe(Unsubscribe { subscribe_id }))
323    }
324
325    /// Process an incoming SUBSCRIBE_UPDATE.
326    pub fn receive_subscribe_update(&mut self, msg: &SubscribeUpdate) -> Result<(), EndpointError> {
327        let id = msg.subscribe_id.into_inner();
328        let sm = self.subscriptions.get_mut(&id).ok_or(EndpointError::UnknownSubscribe(id))?;
329        sm.on_subscribe_update()?;
330        Ok(())
331    }
332
333    /// Process an incoming SUBSCRIBE_DONE (subscriber side — publisher finished).
334    pub fn receive_subscribe_done(&mut self, msg: &SubscribeDone) -> Result<(), EndpointError> {
335        let id = msg.subscribe_id.into_inner();
336        let sm = self.subscriptions.get_mut(&id).ok_or(EndpointError::UnknownSubscribe(id))?;
337        sm.on_subscribe_done()?;
338        Ok(())
339    }
340
341    // ── Fetch flow ─────────────────────────────────────────────
342
343    /// Send a FETCH message. Allocates a subscribe ID and creates a fetch state machine.
344    #[allow(clippy::too_many_arguments)]
345    pub fn fetch(
346        &mut self,
347        track_namespace: TrackNamespace,
348        track_name: Vec<u8>,
349        subscriber_priority: u8,
350        group_order: GroupOrder,
351        start_group: VarInt,
352        start_object: VarInt,
353        end_group: VarInt,
354        end_object: VarInt,
355    ) -> Result<(VarInt, ControlMessage), EndpointError> {
356        self.require_active_or_err()?;
357        let sub_id = self.subscribe_ids.allocate()?;
358
359        let mut sm = FetchStateMachine::new();
360        sm.on_fetch_sent()?;
361        self.fetches.insert(sub_id.into_inner(), sm);
362
363        let msg = ControlMessage::Fetch(Fetch {
364            subscribe_id: sub_id,
365            track_namespace,
366            track_name,
367            subscriber_priority,
368            group_order,
369            start_group,
370            start_object,
371            end_group,
372            end_object,
373            parameters: vec![],
374        });
375        Ok((sub_id, msg))
376    }
377
378    /// Process an incoming FETCH_OK.
379    pub fn receive_fetch_ok(&mut self, msg: &message::FetchOk) -> Result<(), EndpointError> {
380        let id = msg.subscribe_id.into_inner();
381        let sm = self.fetches.get_mut(&id).ok_or(EndpointError::UnknownSubscribe(id))?;
382        sm.on_fetch_ok()?;
383        Ok(())
384    }
385
386    /// Process an incoming FETCH_ERROR.
387    pub fn receive_fetch_error(&mut self, msg: &message::FetchError) -> Result<(), EndpointError> {
388        let id = msg.subscribe_id.into_inner();
389        let sm = self.fetches.get_mut(&id).ok_or(EndpointError::UnknownSubscribe(id))?;
390        sm.on_fetch_error()?;
391        Ok(())
392    }
393
394    /// Send a FETCH_CANCEL message.
395    pub fn fetch_cancel(&mut self, subscribe_id: VarInt) -> Result<ControlMessage, EndpointError> {
396        let id = subscribe_id.into_inner();
397        let sm = self.fetches.get_mut(&id).ok_or(EndpointError::UnknownSubscribe(id))?;
398        sm.on_fetch_cancel()?;
399        Ok(ControlMessage::FetchCancel(FetchCancel { subscribe_id }))
400    }
401
402    /// Notify that a fetch data stream received FIN.
403    pub fn on_fetch_stream_fin(&mut self, subscribe_id: VarInt) -> Result<(), EndpointError> {
404        let id = subscribe_id.into_inner();
405        let sm = self.fetches.get_mut(&id).ok_or(EndpointError::UnknownSubscribe(id))?;
406        sm.on_stream_fin()?;
407        Ok(())
408    }
409
410    /// Notify that a fetch data stream was reset.
411    pub fn on_fetch_stream_reset(&mut self, subscribe_id: VarInt) -> Result<(), EndpointError> {
412        let id = subscribe_id.into_inner();
413        let sm = self.fetches.get_mut(&id).ok_or(EndpointError::UnknownSubscribe(id))?;
414        sm.on_stream_reset()?;
415        Ok(())
416    }
417
418    // ── Subscribe Announces flow ───────────────────────────────
419
420    /// Send a SUBSCRIBE_ANNOUNCES message.
421    pub fn subscribe_announces(
422        &mut self,
423        track_namespace_prefix: TrackNamespace,
424    ) -> Result<ControlMessage, EndpointError> {
425        self.require_active_or_err()?;
426        let key = track_namespace_prefix.0.clone();
427        let mut sm = SubscribeAnnouncesStateMachine::new();
428        sm.on_subscribe_announces_sent()?;
429        self.subscribe_announces.insert(key, sm);
430        Ok(ControlMessage::SubscribeAnnounces(SubscribeAnnounces {
431            track_namespace_prefix,
432            parameters: vec![],
433        }))
434    }
435
436    /// Process an incoming SUBSCRIBE_ANNOUNCES_OK.
437    pub fn receive_subscribe_announces_ok(
438        &mut self,
439        msg: &SubscribeAnnouncesOk,
440    ) -> Result<(), EndpointError> {
441        let sm = self
442            .subscribe_announces
443            .get_mut(&msg.track_namespace_prefix.0)
444            .ok_or(EndpointError::UnknownNamespace)?;
445        sm.on_subscribe_announces_ok()?;
446        Ok(())
447    }
448
449    /// Process an incoming SUBSCRIBE_ANNOUNCES_ERROR.
450    pub fn receive_subscribe_announces_error(
451        &mut self,
452        msg: &SubscribeAnnouncesError,
453    ) -> Result<(), EndpointError> {
454        let sm = self
455            .subscribe_announces
456            .get_mut(&msg.track_namespace_prefix.0)
457            .ok_or(EndpointError::UnknownNamespace)?;
458        sm.on_subscribe_announces_error()?;
459        Ok(())
460    }
461
462    /// Send an UNSUBSCRIBE_ANNOUNCES message.
463    pub fn unsubscribe_announces(
464        &mut self,
465        track_namespace_prefix: TrackNamespace,
466    ) -> Result<ControlMessage, EndpointError> {
467        let sm = self
468            .subscribe_announces
469            .get_mut(&track_namespace_prefix.0)
470            .ok_or(EndpointError::UnknownNamespace)?;
471        sm.on_unsubscribe_announces()?;
472        Ok(ControlMessage::UnsubscribeAnnounces(UnsubscribeAnnounces { track_namespace_prefix }))
473    }
474
475    // ── Announce flow ──────────────────────────────────────────
476
477    /// Send an ANNOUNCE message.
478    pub fn announce(
479        &mut self,
480        track_namespace: TrackNamespace,
481    ) -> Result<ControlMessage, EndpointError> {
482        self.require_active_or_err()?;
483        let key = track_namespace.0.clone();
484        let mut sm = AnnounceStateMachine::new();
485        sm.on_announce_sent()?;
486        self.announces.insert(key, sm);
487        Ok(ControlMessage::Announce(Announce { track_namespace, parameters: vec![] }))
488    }
489
490    /// Process an incoming ANNOUNCE_OK.
491    pub fn receive_announce_ok(&mut self, msg: &AnnounceOk) -> Result<(), EndpointError> {
492        let sm = self
493            .announces
494            .get_mut(&msg.track_namespace.0)
495            .ok_or(EndpointError::UnknownNamespace)?;
496        sm.on_announce_ok()?;
497        Ok(())
498    }
499
500    /// Process an incoming ANNOUNCE_ERROR.
501    pub fn receive_announce_error(&mut self, msg: &AnnounceError) -> Result<(), EndpointError> {
502        let sm = self
503            .announces
504            .get_mut(&msg.track_namespace.0)
505            .ok_or(EndpointError::UnknownNamespace)?;
506        sm.on_announce_error()?;
507        Ok(())
508    }
509
510    /// Process an incoming ANNOUNCE_CANCEL.
511    pub fn receive_announce_cancel(&mut self, msg: &AnnounceCancel) -> Result<(), EndpointError> {
512        let sm = self
513            .announces
514            .get_mut(&msg.track_namespace.0)
515            .ok_or(EndpointError::UnknownNamespace)?;
516        sm.on_announce_cancel()?;
517        Ok(())
518    }
519
520    /// Send an UNANNOUNCE message (publisher withdrawing).
521    pub fn unannounce(
522        &mut self,
523        track_namespace: TrackNamespace,
524    ) -> Result<ControlMessage, EndpointError> {
525        let sm =
526            self.announces.get_mut(&track_namespace.0).ok_or(EndpointError::UnknownNamespace)?;
527        sm.on_unannounce()?;
528        Ok(ControlMessage::Unannounce(Unannounce { track_namespace }))
529    }
530
531    // ── Track Status flow ──────────────────────────────────────
532
533    /// Send a TRACK_STATUS_REQUEST message.
534    pub fn track_status_request(
535        &mut self,
536        track_namespace: TrackNamespace,
537        track_name: Vec<u8>,
538    ) -> Result<ControlMessage, EndpointError> {
539        self.require_active_or_err()?;
540        let key = (track_namespace.0.clone(), track_name.clone());
541        let mut sm = TrackStatusStateMachine::new();
542        sm.on_track_status_request_sent()?;
543        self.track_statuses.insert(key, sm);
544        Ok(ControlMessage::TrackStatusRequest(TrackStatusRequest { track_namespace, track_name }))
545    }
546
547    /// Process an incoming TRACK_STATUS reply.
548    pub fn receive_track_status(&mut self, msg: &TrackStatus) -> Result<(), EndpointError> {
549        let key = (msg.track_namespace.0.clone(), msg.track_name.clone());
550        let sm = self.track_statuses.get_mut(&key).ok_or(EndpointError::UnknownTrackStatus)?;
551        sm.on_track_status()?;
552        Ok(())
553    }
554
555    // ── Unified message dispatch ───────────────────────────────
556
557    /// Dispatch an incoming control message to the appropriate handler.
558    pub fn receive_message(&mut self, msg: ControlMessage) -> Result<(), EndpointError> {
559        match msg {
560            ControlMessage::GoAway(ref m) => self.receive_goaway(m),
561            ControlMessage::MaxSubscribeId(ref m) => self.receive_max_subscribe_id(m),
562            ControlMessage::SubscribeOk(ref m) => self.receive_subscribe_ok(m),
563            ControlMessage::SubscribeError(ref m) => self.receive_subscribe_error(m),
564            ControlMessage::SubscribeUpdate(ref m) => self.receive_subscribe_update(m),
565            ControlMessage::SubscribeDone(ref m) => self.receive_subscribe_done(m),
566            ControlMessage::FetchOk(ref m) => self.receive_fetch_ok(m),
567            ControlMessage::FetchError(ref m) => self.receive_fetch_error(m),
568            ControlMessage::SubscribeAnnouncesOk(ref m) => self.receive_subscribe_announces_ok(m),
569            ControlMessage::SubscribeAnnouncesError(ref m) => {
570                self.receive_subscribe_announces_error(m)
571            }
572            ControlMessage::AnnounceOk(ref m) => self.receive_announce_ok(m),
573            ControlMessage::AnnounceError(ref m) => self.receive_announce_error(m),
574            ControlMessage::AnnounceCancel(ref m) => self.receive_announce_cancel(m),
575            ControlMessage::TrackStatus(ref m) => self.receive_track_status(m),
576            _ => Ok(()),
577        }
578    }
579}