1use std::collections::HashMap;
2
3use crate::draft09::fetch::{FetchError, FetchStateMachine};
4use crate::draft09::namespace::{
5 AnnounceStateMachine, NamespaceError, SubscribeAnnouncesStateMachine,
6};
7use crate::draft09::session::setup::{self, SetupError};
8use crate::draft09::session::state::{SessionError, SessionState, SessionStateMachine};
9use crate::draft09::session::subscribe_id::{SubscribeIdAllocator, SubscribeIdError};
10use crate::draft09::subscription::{SubscriptionError, SubscriptionStateMachine};
11use crate::draft09::track_status::{TrackStatusError, TrackStatusStateMachine};
12use moqtap_codec::draft09::message::{
13 self, Announce, AnnounceCancel, AnnounceError, AnnounceOk, ClientSetup, ControlMessage, Fetch,
14 FetchCancel, FetchType, GoAway, MaxSubscribeId, ServerSetup, Subscribe, SubscribeAnnounces,
15 SubscribeAnnouncesError, SubscribeAnnouncesOk, SubscribeDone, SubscribeError, SubscribeOk,
16 SubscribeUpdate, SubscribesBlocked, TrackStatus, TrackStatusRequest, Unannounce, Unsubscribe,
17 UnsubscribeAnnounces,
18};
19use moqtap_codec::kvp::{KeyValuePair, KvpValue};
20use moqtap_codec::types::*;
21use moqtap_codec::varint::VarInt;
22
/// Map key for a track namespace: the namespace's byte-string elements
/// (the inner value of `TrackNamespace`).
type NamespaceKey = Vec<Vec<u8>>;

/// Map key for a single track: its namespace elements paired with the track name.
type TrackKey = (NamespaceKey, Vec<u8>);
28
29#[derive(Debug, thiserror::Error)]
31pub enum EndpointError {
32 #[error("session error: {0}")]
34 Session(#[from] SessionError),
35 #[error("subscribe ID error: {0}")]
37 SubscribeId(#[from] SubscribeIdError),
38 #[error("subscription error: {0}")]
40 Subscription(#[from] SubscriptionError),
41 #[error("fetch error: {0}")]
43 Fetch(#[from] FetchError),
44 #[error("namespace error: {0}")]
46 Namespace(#[from] NamespaceError),
47 #[error("track status error: {0}")]
49 TrackStatus(#[from] TrackStatusError),
50 #[error("setup error: {0}")]
52 Setup(#[from] SetupError),
53 #[error("unknown subscribe ID: {0}")]
55 UnknownSubscribe(u64),
56 #[error("unknown namespace")]
58 UnknownNamespace,
59 #[error("unknown track status request")]
61 UnknownTrackStatus,
62 #[error("session not active")]
64 NotActive,
65 #[error("session is draining, no new requests allowed")]
67 Draining,
68}
69
70pub struct Endpoint {
74 session: SessionStateMachine,
75 subscribe_ids: SubscribeIdAllocator,
76 advertised_max_id: u64,
78 subscriptions: HashMap<u64, SubscriptionStateMachine>,
79 fetches: HashMap<u64, FetchStateMachine>,
80 subscribe_announces: HashMap<NamespaceKey, SubscribeAnnouncesStateMachine>,
81 announces: HashMap<NamespaceKey, AnnounceStateMachine>,
82 track_statuses: HashMap<TrackKey, TrackStatusStateMachine>,
83 negotiated_version: Option<VarInt>,
84 offered_versions: Vec<VarInt>,
85 goaway_uri: Option<Vec<u8>>,
86 peer_reported_max_subscribe_id: Option<VarInt>,
89}
90
91impl Default for Endpoint {
92 fn default() -> Self {
93 Self::new()
94 }
95}
96
97impl Endpoint {
98 pub fn new() -> Self {
100 Self {
101 session: SessionStateMachine::new(),
102 subscribe_ids: SubscribeIdAllocator::new(),
103 advertised_max_id: 0,
104 subscriptions: HashMap::new(),
105 fetches: HashMap::new(),
106 subscribe_announces: HashMap::new(),
107 announces: HashMap::new(),
108 track_statuses: HashMap::new(),
109 negotiated_version: None,
110 offered_versions: Vec::new(),
111 goaway_uri: None,
112 peer_reported_max_subscribe_id: None,
113 }
114 }
115
116 pub fn session_state(&self) -> SessionState {
120 self.session.state()
121 }
122
123 pub fn negotiated_version(&self) -> Option<VarInt> {
125 self.negotiated_version
126 }
127
128 pub fn goaway_uri(&self) -> Option<&[u8]> {
130 self.goaway_uri.as_deref()
131 }
132
133 pub fn is_blocked(&self) -> bool {
135 self.subscribe_ids.is_blocked()
136 }
137
138 pub fn active_subscription_count(&self) -> usize {
140 self.subscriptions.len()
141 }
142
143 pub fn active_fetch_count(&self) -> usize {
145 self.fetches.len()
146 }
147
148 pub fn active_subscribe_announces_count(&self) -> usize {
150 self.subscribe_announces.len()
151 }
152
153 pub fn active_announce_count(&self) -> usize {
155 self.announces.len()
156 }
157
158 pub fn active_track_status_count(&self) -> usize {
160 self.track_statuses.len()
161 }
162
163 pub fn connect(&mut self) -> Result<(), EndpointError> {
167 self.session.on_connect()?;
168 Ok(())
169 }
170
171 pub fn close(&mut self) -> Result<(), EndpointError> {
173 self.session.on_close()?;
174 Ok(())
175 }
176
177 pub fn send_client_setup(
181 &mut self,
182 versions: Vec<VarInt>,
183 parameters: Vec<KeyValuePair>,
184 ) -> Result<ControlMessage, EndpointError> {
185 self.offered_versions = versions.clone();
186 let msg = ClientSetup { supported_versions: versions, parameters };
187 setup::validate_client_setup(&msg)?;
188 Ok(ControlMessage::ClientSetup(msg))
189 }
190
191 pub fn receive_server_setup(&mut self, msg: &ServerSetup) -> Result<(), EndpointError> {
195 setup::validate_server_setup(msg)?;
196 let version = setup::negotiate_version(&self.offered_versions, msg.selected_version)?;
197 self.negotiated_version = Some(version);
198 self.session.on_setup_complete()?;
199 for param in &msg.parameters {
201 if param.key == VarInt::from_u64(0x02).unwrap() {
202 if let KvpValue::Varint(v) = ¶m.value {
203 self.subscribe_ids.update_max(v.into_inner())?;
204 }
205 }
206 }
207 Ok(())
208 }
209
210 pub fn receive_client_setup_and_respond(
214 &mut self,
215 client_setup: &ClientSetup,
216 selected_version: VarInt,
217 ) -> Result<ControlMessage, EndpointError> {
218 setup::validate_client_setup(client_setup)?;
219 let version = setup::negotiate_version(&client_setup.supported_versions, selected_version)?;
220 self.negotiated_version = Some(version);
221 self.session.on_setup_complete()?;
222 let msg = ServerSetup { selected_version: version, parameters: vec![] };
223 Ok(ControlMessage::ServerSetup(msg))
224 }
225
226 pub fn receive_max_subscribe_id(&mut self, msg: &MaxSubscribeId) -> Result<(), EndpointError> {
230 self.subscribe_ids.update_max(msg.subscribe_id.into_inner())?;
231 Ok(())
232 }
233
234 pub fn send_max_subscribe_id(
237 &mut self,
238 max_id: VarInt,
239 ) -> Result<ControlMessage, EndpointError> {
240 let new_val = max_id.into_inner();
241 if new_val <= self.advertised_max_id && self.advertised_max_id > 0 {
242 return Err(EndpointError::SubscribeId(SubscribeIdError::Decreased(
243 self.advertised_max_id,
244 new_val,
245 )));
246 }
247 self.advertised_max_id = new_val;
248 Ok(ControlMessage::MaxSubscribeId(MaxSubscribeId { subscribe_id: max_id }))
249 }
250
251 pub fn receive_goaway(&mut self, msg: &GoAway) -> Result<(), EndpointError> {
255 self.session.on_goaway()?;
256 self.goaway_uri = Some(msg.new_session_uri.clone());
257 Ok(())
258 }
259
260 fn require_active_or_err(&self) -> Result<(), EndpointError> {
263 match self.session.state() {
264 SessionState::Active => Ok(()),
265 SessionState::Draining => Err(EndpointError::Draining),
266 _ => Err(EndpointError::NotActive),
267 }
268 }
269
270 #[allow(clippy::too_many_arguments)]
273 pub fn subscribe(
274 &mut self,
275 track_alias: VarInt,
276 track_namespace: TrackNamespace,
277 track_name: Vec<u8>,
278 subscriber_priority: u8,
279 group_order: GroupOrder,
280 filter_type: FilterType,
281 ) -> Result<(VarInt, ControlMessage), EndpointError> {
282 self.require_active_or_err()?;
283 let sub_id = self.subscribe_ids.allocate()?;
284
285 let mut sm = SubscriptionStateMachine::new();
286 sm.on_subscribe_sent()?;
287 self.subscriptions.insert(sub_id.into_inner(), sm);
288
289 let msg = ControlMessage::Subscribe(Subscribe {
290 subscribe_id: sub_id,
291 track_alias,
292 track_namespace,
293 track_name,
294 subscriber_priority,
295 group_order,
296 filter_type,
297 start_location: None,
298 end_group: None,
299 parameters: vec![],
300 });
301 Ok((sub_id, msg))
302 }
303
304 pub fn receive_subscribe_ok(&mut self, msg: &SubscribeOk) -> Result<(), EndpointError> {
306 let id = msg.subscribe_id.into_inner();
307 let sm = self.subscriptions.get_mut(&id).ok_or(EndpointError::UnknownSubscribe(id))?;
308 sm.on_subscribe_ok()?;
309 Ok(())
310 }
311
312 pub fn receive_subscribe_error(&mut self, msg: &SubscribeError) -> Result<(), EndpointError> {
314 let id = msg.subscribe_id.into_inner();
315 let sm = self.subscriptions.get_mut(&id).ok_or(EndpointError::UnknownSubscribe(id))?;
316 sm.on_subscribe_error()?;
317 Ok(())
318 }
319
320 pub fn unsubscribe(&mut self, subscribe_id: VarInt) -> Result<ControlMessage, EndpointError> {
322 let id = subscribe_id.into_inner();
323 let sm = self.subscriptions.get_mut(&id).ok_or(EndpointError::UnknownSubscribe(id))?;
324 sm.on_unsubscribe()?;
325 Ok(ControlMessage::Unsubscribe(Unsubscribe { subscribe_id }))
326 }
327
328 pub fn receive_subscribe_update(&mut self, msg: &SubscribeUpdate) -> Result<(), EndpointError> {
330 let id = msg.subscribe_id.into_inner();
331 let sm = self.subscriptions.get_mut(&id).ok_or(EndpointError::UnknownSubscribe(id))?;
332 sm.on_subscribe_update()?;
333 Ok(())
334 }
335
336 pub fn receive_subscribe_done(&mut self, msg: &SubscribeDone) -> Result<(), EndpointError> {
338 let id = msg.subscribe_id.into_inner();
339 let sm = self.subscriptions.get_mut(&id).ok_or(EndpointError::UnknownSubscribe(id))?;
340 sm.on_subscribe_done()?;
341 Ok(())
342 }
343
344 #[allow(clippy::too_many_arguments)]
348 pub fn fetch(
349 &mut self,
350 track_namespace: TrackNamespace,
351 track_name: Vec<u8>,
352 subscriber_priority: u8,
353 group_order: GroupOrder,
354 start_group: VarInt,
355 start_object: VarInt,
356 end_group: VarInt,
357 end_object: VarInt,
358 ) -> Result<(VarInt, ControlMessage), EndpointError> {
359 self.require_active_or_err()?;
360 let sub_id = self.subscribe_ids.allocate()?;
361
362 let mut sm = FetchStateMachine::new();
363 sm.on_fetch_sent()?;
364 self.fetches.insert(sub_id.into_inner(), sm);
365
366 let msg = ControlMessage::Fetch(Fetch {
367 subscribe_id: sub_id,
368 subscriber_priority,
369 group_order,
370 fetch_type: FetchType::Standalone,
371 track_namespace: Some(track_namespace),
372 track_name: Some(track_name),
373 start_group: Some(start_group),
374 start_object: Some(start_object),
375 end_group: Some(end_group),
376 end_object: Some(end_object),
377 joining_subscribe_id: None,
378 preceding_group_offset: None,
379 parameters: vec![],
380 });
381 Ok((sub_id, msg))
382 }
383
384 pub fn joining_fetch(
388 &mut self,
389 subscriber_priority: u8,
390 group_order: GroupOrder,
391 joining_subscribe_id: VarInt,
392 preceding_group_offset: VarInt,
393 ) -> Result<(VarInt, ControlMessage), EndpointError> {
394 self.require_active_or_err()?;
395 let sub_id = self.subscribe_ids.allocate()?;
396
397 let mut sm = FetchStateMachine::new();
398 sm.on_fetch_sent()?;
399 self.fetches.insert(sub_id.into_inner(), sm);
400
401 let msg = ControlMessage::Fetch(Fetch {
402 subscribe_id: sub_id,
403 subscriber_priority,
404 group_order,
405 fetch_type: FetchType::Joining,
406 track_namespace: None,
407 track_name: None,
408 start_group: None,
409 start_object: None,
410 end_group: None,
411 end_object: None,
412 joining_subscribe_id: Some(joining_subscribe_id),
413 preceding_group_offset: Some(preceding_group_offset),
414 parameters: vec![],
415 });
416 Ok((sub_id, msg))
417 }
418
419 pub fn receive_fetch_ok(&mut self, msg: &message::FetchOk) -> Result<(), EndpointError> {
421 let id = msg.subscribe_id.into_inner();
422 let sm = self.fetches.get_mut(&id).ok_or(EndpointError::UnknownSubscribe(id))?;
423 sm.on_fetch_ok()?;
424 Ok(())
425 }
426
427 pub fn receive_fetch_error(&mut self, msg: &message::FetchError) -> Result<(), EndpointError> {
429 let id = msg.subscribe_id.into_inner();
430 let sm = self.fetches.get_mut(&id).ok_or(EndpointError::UnknownSubscribe(id))?;
431 sm.on_fetch_error()?;
432 Ok(())
433 }
434
435 pub fn fetch_cancel(&mut self, subscribe_id: VarInt) -> Result<ControlMessage, EndpointError> {
437 let id = subscribe_id.into_inner();
438 let sm = self.fetches.get_mut(&id).ok_or(EndpointError::UnknownSubscribe(id))?;
439 sm.on_fetch_cancel()?;
440 Ok(ControlMessage::FetchCancel(FetchCancel { subscribe_id }))
441 }
442
443 pub fn on_fetch_stream_fin(&mut self, subscribe_id: VarInt) -> Result<(), EndpointError> {
445 let id = subscribe_id.into_inner();
446 let sm = self.fetches.get_mut(&id).ok_or(EndpointError::UnknownSubscribe(id))?;
447 sm.on_stream_fin()?;
448 Ok(())
449 }
450
451 pub fn on_fetch_stream_reset(&mut self, subscribe_id: VarInt) -> Result<(), EndpointError> {
453 let id = subscribe_id.into_inner();
454 let sm = self.fetches.get_mut(&id).ok_or(EndpointError::UnknownSubscribe(id))?;
455 sm.on_stream_reset()?;
456 Ok(())
457 }
458
459 pub fn subscribe_announces(
463 &mut self,
464 track_namespace_prefix: TrackNamespace,
465 ) -> Result<ControlMessage, EndpointError> {
466 self.require_active_or_err()?;
467 let key = track_namespace_prefix.0.clone();
468 let mut sm = SubscribeAnnouncesStateMachine::new();
469 sm.on_subscribe_announces_sent()?;
470 self.subscribe_announces.insert(key, sm);
471 Ok(ControlMessage::SubscribeAnnounces(SubscribeAnnounces {
472 track_namespace_prefix,
473 parameters: vec![],
474 }))
475 }
476
477 pub fn receive_subscribe_announces_ok(
479 &mut self,
480 msg: &SubscribeAnnouncesOk,
481 ) -> Result<(), EndpointError> {
482 let sm = self
483 .subscribe_announces
484 .get_mut(&msg.track_namespace_prefix.0)
485 .ok_or(EndpointError::UnknownNamespace)?;
486 sm.on_subscribe_announces_ok()?;
487 Ok(())
488 }
489
490 pub fn receive_subscribe_announces_error(
492 &mut self,
493 msg: &SubscribeAnnouncesError,
494 ) -> Result<(), EndpointError> {
495 let sm = self
496 .subscribe_announces
497 .get_mut(&msg.track_namespace_prefix.0)
498 .ok_or(EndpointError::UnknownNamespace)?;
499 sm.on_subscribe_announces_error()?;
500 Ok(())
501 }
502
503 pub fn unsubscribe_announces(
505 &mut self,
506 track_namespace_prefix: TrackNamespace,
507 ) -> Result<ControlMessage, EndpointError> {
508 let sm = self
509 .subscribe_announces
510 .get_mut(&track_namespace_prefix.0)
511 .ok_or(EndpointError::UnknownNamespace)?;
512 sm.on_unsubscribe_announces()?;
513 Ok(ControlMessage::UnsubscribeAnnounces(UnsubscribeAnnounces { track_namespace_prefix }))
514 }
515
516 pub fn announce(
520 &mut self,
521 track_namespace: TrackNamespace,
522 ) -> Result<ControlMessage, EndpointError> {
523 self.require_active_or_err()?;
524 let key = track_namespace.0.clone();
525 let mut sm = AnnounceStateMachine::new();
526 sm.on_announce_sent()?;
527 self.announces.insert(key, sm);
528 Ok(ControlMessage::Announce(Announce { track_namespace, parameters: vec![] }))
529 }
530
531 pub fn receive_announce_ok(&mut self, msg: &AnnounceOk) -> Result<(), EndpointError> {
533 let sm = self
534 .announces
535 .get_mut(&msg.track_namespace.0)
536 .ok_or(EndpointError::UnknownNamespace)?;
537 sm.on_announce_ok()?;
538 Ok(())
539 }
540
541 pub fn receive_announce_error(&mut self, msg: &AnnounceError) -> Result<(), EndpointError> {
543 let sm = self
544 .announces
545 .get_mut(&msg.track_namespace.0)
546 .ok_or(EndpointError::UnknownNamespace)?;
547 sm.on_announce_error()?;
548 Ok(())
549 }
550
551 pub fn receive_announce_cancel(&mut self, msg: &AnnounceCancel) -> Result<(), EndpointError> {
553 let sm = self
554 .announces
555 .get_mut(&msg.track_namespace.0)
556 .ok_or(EndpointError::UnknownNamespace)?;
557 sm.on_announce_cancel()?;
558 Ok(())
559 }
560
561 pub fn unannounce(
563 &mut self,
564 track_namespace: TrackNamespace,
565 ) -> Result<ControlMessage, EndpointError> {
566 let sm =
567 self.announces.get_mut(&track_namespace.0).ok_or(EndpointError::UnknownNamespace)?;
568 sm.on_unannounce()?;
569 Ok(ControlMessage::Unannounce(Unannounce { track_namespace }))
570 }
571
572 pub fn track_status_request(
576 &mut self,
577 track_namespace: TrackNamespace,
578 track_name: Vec<u8>,
579 ) -> Result<ControlMessage, EndpointError> {
580 self.require_active_or_err()?;
581 let key = (track_namespace.0.clone(), track_name.clone());
582 let mut sm = TrackStatusStateMachine::new();
583 sm.on_track_status_request_sent()?;
584 self.track_statuses.insert(key, sm);
585 Ok(ControlMessage::TrackStatusRequest(TrackStatusRequest { track_namespace, track_name }))
586 }
587
588 pub fn receive_track_status(&mut self, msg: &TrackStatus) -> Result<(), EndpointError> {
590 let key = (msg.track_namespace.0.clone(), msg.track_name.clone());
591 let sm = self.track_statuses.get_mut(&key).ok_or(EndpointError::UnknownTrackStatus)?;
592 sm.on_track_status()?;
593 Ok(())
594 }
595
596 pub fn receive_subscribes_blocked(
605 &mut self,
606 msg: &SubscribesBlocked,
607 ) -> Result<(), EndpointError> {
608 self.peer_reported_max_subscribe_id = Some(msg.maximum_subscribe_id);
609 Ok(())
610 }
611
612 pub fn peer_reported_max_subscribe_id(&self) -> Option<VarInt> {
615 self.peer_reported_max_subscribe_id
616 }
617
618 pub fn receive_message(&mut self, msg: ControlMessage) -> Result<(), EndpointError> {
622 match msg {
623 ControlMessage::GoAway(ref m) => self.receive_goaway(m),
624 ControlMessage::MaxSubscribeId(ref m) => self.receive_max_subscribe_id(m),
625 ControlMessage::SubscribesBlocked(ref m) => self.receive_subscribes_blocked(m),
626 ControlMessage::SubscribeOk(ref m) => self.receive_subscribe_ok(m),
627 ControlMessage::SubscribeError(ref m) => self.receive_subscribe_error(m),
628 ControlMessage::SubscribeUpdate(ref m) => self.receive_subscribe_update(m),
629 ControlMessage::SubscribeDone(ref m) => self.receive_subscribe_done(m),
630 ControlMessage::FetchOk(ref m) => self.receive_fetch_ok(m),
631 ControlMessage::FetchError(ref m) => self.receive_fetch_error(m),
632 ControlMessage::SubscribeAnnouncesOk(ref m) => self.receive_subscribe_announces_ok(m),
633 ControlMessage::SubscribeAnnouncesError(ref m) => {
634 self.receive_subscribe_announces_error(m)
635 }
636 ControlMessage::AnnounceOk(ref m) => self.receive_announce_ok(m),
637 ControlMessage::AnnounceError(ref m) => self.receive_announce_error(m),
638 ControlMessage::AnnounceCancel(ref m) => self.receive_announce_cancel(m),
639 ControlMessage::TrackStatus(ref m) => self.receive_track_status(m),
640 _ => Ok(()),
641 }
642 }
643}