1use std::collections::HashMap;
2
3use crate::draft11::fetch::{FetchError, FetchStateMachine};
4use crate::draft11::namespace::{
5 AnnounceStateMachine, NamespaceError, SubscribeAnnouncesStateMachine,
6};
7use crate::draft11::session::request_id::{RequestIdAllocator, RequestIdError};
8use crate::draft11::session::setup::{self, SetupError};
9use crate::draft11::session::state::{SessionError, SessionState, SessionStateMachine};
10use crate::draft11::subscription::{SubscriptionError, SubscriptionStateMachine};
11use crate::draft11::track_status::{TrackStatusError, TrackStatusStateMachine};
12use moqtap_codec::draft11::message::{
13 self, Announce, AnnounceCancel, AnnounceError, AnnounceOk, ClientSetup, ControlMessage, Fetch,
14 FetchCancel, FetchPayload, FetchType, GoAway, MaxRequestId, RequestsBlocked, ServerSetup,
15 Subscribe, SubscribeAnnounces, SubscribeAnnouncesError, SubscribeAnnouncesOk, SubscribeDone,
16 SubscribeError, SubscribeOk, SubscribeUpdate, TrackStatus, TrackStatusRequest, Unannounce,
17 Unsubscribe, UnsubscribeAnnounces,
18};
19use moqtap_codec::kvp::{KeyValuePair, KvpValue};
20use moqtap_codec::types::*;
21use moqtap_codec::varint::VarInt;
22
/// Lookup key for namespace-addressed requests: the inner tuple of a
/// `TrackNamespace` (its `.0` field), cloned when an ANNOUNCE or
/// SUBSCRIBE_ANNOUNCES is started so later namespace-only messages can be
/// mapped back to the originating request ID.
type NamespaceKey = Vec<Vec<u8>>;
25
/// Errors returned by [`Endpoint`] operations.
///
/// Most variants wrap (via `#[from]`) an error produced by one of the
/// session / per-request state machines or setup helpers; the last four are
/// raised by the endpoint itself.
#[derive(Debug, thiserror::Error)]
pub enum EndpointError {
    /// Error from the session lifecycle state machine.
    #[error("session error: {0}")]
    Session(#[from] SessionError),
    /// Error from the request-ID allocator, or a MAX_REQUEST_ID that would not
    /// increase the advertised limit.
    #[error("request ID error: {0}")]
    RequestId(#[from] RequestIdError),
    /// Error from a subscription state machine.
    #[error("subscription error: {0}")]
    Subscription(#[from] SubscriptionError),
    /// Error from a fetch state machine.
    #[error("fetch error: {0}")]
    Fetch(#[from] FetchError),
    /// Error from an announce / subscribe-announces state machine.
    #[error("namespace error: {0}")]
    Namespace(#[from] NamespaceError),
    /// Error from a track-status state machine.
    #[error("track status error: {0}")]
    TrackStatus(#[from] TrackStatusError),
    /// Setup message validation or version negotiation failed.
    #[error("setup error: {0}")]
    Setup(#[from] SetupError),
    /// No tracked request has this request ID.
    #[error("unknown request ID: {0}")]
    UnknownRequest(u64),
    /// No tracked announce / subscribe-announces request matches the namespace.
    #[error("unknown namespace")]
    UnknownNamespace,
    /// The session is neither `Active` nor `Draining`. Note: `joining_fetch`
    /// also returns this for a non-joining fetch type.
    #[error("session not active")]
    NotActive,
    /// The session is in the `Draining` state, so no new requests may start.
    #[error("session is draining, no new requests allowed")]
    Draining,
}
63
64pub struct Endpoint {
68 session: SessionStateMachine,
69 request_ids: RequestIdAllocator,
70 advertised_max_id: u64,
72 subscriptions: HashMap<u64, SubscriptionStateMachine>,
73 fetches: HashMap<u64, FetchStateMachine>,
74 subscribe_announces: HashMap<u64, SubscribeAnnouncesStateMachine>,
75 announces: HashMap<u64, AnnounceStateMachine>,
76 announce_ids: HashMap<NamespaceKey, u64>,
79 subscribe_announces_ids: HashMap<NamespaceKey, u64>,
81 track_statuses: HashMap<u64, TrackStatusStateMachine>,
82 negotiated_version: Option<VarInt>,
83 offered_versions: Vec<VarInt>,
84 goaway_uri: Option<Vec<u8>>,
85 peer_reported_max_request_id: Option<VarInt>,
88}
89
90impl Default for Endpoint {
91 fn default() -> Self {
92 Self::new()
93 }
94}
95
96impl Endpoint {
97 pub fn new() -> Self {
99 Self {
100 session: SessionStateMachine::new(),
101 request_ids: RequestIdAllocator::new(),
102 advertised_max_id: 0,
103 subscriptions: HashMap::new(),
104 fetches: HashMap::new(),
105 subscribe_announces: HashMap::new(),
106 announces: HashMap::new(),
107 announce_ids: HashMap::new(),
108 subscribe_announces_ids: HashMap::new(),
109 track_statuses: HashMap::new(),
110 negotiated_version: None,
111 offered_versions: Vec::new(),
112 goaway_uri: None,
113 peer_reported_max_request_id: None,
114 }
115 }
116
117 pub fn session_state(&self) -> SessionState {
121 self.session.state()
122 }
123
124 pub fn negotiated_version(&self) -> Option<VarInt> {
126 self.negotiated_version
127 }
128
129 pub fn goaway_uri(&self) -> Option<&[u8]> {
131 self.goaway_uri.as_deref()
132 }
133
134 pub fn is_blocked(&self) -> bool {
136 self.request_ids.is_blocked()
137 }
138
139 pub fn active_subscription_count(&self) -> usize {
141 self.subscriptions.len()
142 }
143
144 pub fn active_fetch_count(&self) -> usize {
146 self.fetches.len()
147 }
148
149 pub fn active_subscribe_announces_count(&self) -> usize {
151 self.subscribe_announces.len()
152 }
153
154 pub fn active_announce_count(&self) -> usize {
156 self.announces.len()
157 }
158
159 pub fn active_track_status_count(&self) -> usize {
161 self.track_statuses.len()
162 }
163
164 pub fn connect(&mut self) -> Result<(), EndpointError> {
168 self.session.on_connect()?;
169 Ok(())
170 }
171
172 pub fn close(&mut self) -> Result<(), EndpointError> {
174 self.session.on_close()?;
175 Ok(())
176 }
177
178 pub fn send_client_setup(
182 &mut self,
183 versions: Vec<VarInt>,
184 parameters: Vec<KeyValuePair>,
185 ) -> Result<ControlMessage, EndpointError> {
186 self.offered_versions = versions.clone();
187 let msg = ClientSetup { supported_versions: versions, parameters };
188 setup::validate_client_setup(&msg)?;
189 Ok(ControlMessage::ClientSetup(msg))
190 }
191
192 pub fn receive_server_setup(&mut self, msg: &ServerSetup) -> Result<(), EndpointError> {
196 setup::validate_server_setup(msg)?;
197 let version = setup::negotiate_version(&self.offered_versions, msg.selected_version)?;
198 self.negotiated_version = Some(version);
199 self.session.on_setup_complete()?;
200 for param in &msg.parameters {
202 if param.key == VarInt::from_u64(0x02).unwrap() {
203 if let KvpValue::Varint(v) = ¶m.value {
204 self.request_ids.update_max(v.into_inner())?;
205 }
206 }
207 }
208 Ok(())
209 }
210
211 pub fn receive_client_setup_and_respond(
215 &mut self,
216 client_setup: &ClientSetup,
217 selected_version: VarInt,
218 ) -> Result<ControlMessage, EndpointError> {
219 setup::validate_client_setup(client_setup)?;
220 let version = setup::negotiate_version(&client_setup.supported_versions, selected_version)?;
221 self.negotiated_version = Some(version);
222 self.session.on_setup_complete()?;
223 let msg = ServerSetup { selected_version: version, parameters: vec![] };
224 Ok(ControlMessage::ServerSetup(msg))
225 }
226
227 pub fn receive_max_request_id(&mut self, msg: &MaxRequestId) -> Result<(), EndpointError> {
231 self.request_ids.update_max(msg.request_id.into_inner())?;
232 Ok(())
233 }
234
235 pub fn send_max_request_id(&mut self, max_id: VarInt) -> Result<ControlMessage, EndpointError> {
238 let new_val = max_id.into_inner();
239 if new_val <= self.advertised_max_id && self.advertised_max_id > 0 {
240 return Err(EndpointError::RequestId(RequestIdError::Decreased(
241 self.advertised_max_id,
242 new_val,
243 )));
244 }
245 self.advertised_max_id = new_val;
246 Ok(ControlMessage::MaxRequestId(MaxRequestId { request_id: max_id }))
247 }
248
249 pub fn receive_goaway(&mut self, msg: &GoAway) -> Result<(), EndpointError> {
253 self.session.on_goaway()?;
254 self.goaway_uri = Some(msg.new_session_uri.clone());
255 Ok(())
256 }
257
258 fn require_active_or_err(&self) -> Result<(), EndpointError> {
261 match self.session.state() {
262 SessionState::Active => Ok(()),
263 SessionState::Draining => Err(EndpointError::Draining),
264 _ => Err(EndpointError::NotActive),
265 }
266 }
267
268 #[allow(clippy::too_many_arguments)]
273 pub fn subscribe(
274 &mut self,
275 track_alias: VarInt,
276 track_namespace: TrackNamespace,
277 track_name: Vec<u8>,
278 subscriber_priority: u8,
279 group_order: VarInt,
280 filter_type: VarInt,
281 ) -> Result<(VarInt, ControlMessage), EndpointError> {
282 self.require_active_or_err()?;
283 let req_id = self.request_ids.allocate()?;
284
285 let mut sm = SubscriptionStateMachine::new();
286 sm.on_subscribe_sent()?;
287 self.subscriptions.insert(req_id.into_inner(), sm);
288
289 let msg = ControlMessage::Subscribe(Subscribe {
290 request_id: req_id,
291 track_alias,
292 track_namespace,
293 track_name,
294 subscriber_priority,
295 group_order,
296 forward: VarInt::from_u64(1).unwrap(),
297 filter_type,
298 start_group: None,
299 start_object: None,
300 end_group: None,
301 parameters: vec![],
302 });
303 Ok((req_id, msg))
304 }
305
306 pub fn receive_subscribe_ok(&mut self, msg: &SubscribeOk) -> Result<(), EndpointError> {
308 let id = msg.request_id.into_inner();
309 let sm = self.subscriptions.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
310 sm.on_subscribe_ok()?;
311 Ok(())
312 }
313
314 pub fn receive_subscribe_error(&mut self, msg: &SubscribeError) -> Result<(), EndpointError> {
316 let id = msg.request_id.into_inner();
317 let sm = self.subscriptions.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
318 sm.on_subscribe_error()?;
319 Ok(())
320 }
321
322 pub fn unsubscribe(&mut self, request_id: VarInt) -> Result<ControlMessage, EndpointError> {
324 let id = request_id.into_inner();
325 let sm = self.subscriptions.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
326 sm.on_unsubscribe()?;
327 Ok(ControlMessage::Unsubscribe(Unsubscribe { request_id }))
328 }
329
330 pub fn receive_subscribe_update(&mut self, msg: &SubscribeUpdate) -> Result<(), EndpointError> {
332 let id = msg.request_id.into_inner();
333 let sm = self.subscriptions.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
334 sm.on_subscribe_update()?;
335 Ok(())
336 }
337
338 pub fn receive_subscribe_done(&mut self, msg: &SubscribeDone) -> Result<(), EndpointError> {
340 let id = msg.request_id.into_inner();
341 let sm = self.subscriptions.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
342 sm.on_subscribe_done()?;
343 Ok(())
344 }
345
346 #[allow(clippy::too_many_arguments)]
351 pub fn fetch(
352 &mut self,
353 track_namespace: TrackNamespace,
354 track_name: Vec<u8>,
355 subscriber_priority: u8,
356 group_order: VarInt,
357 start_group: VarInt,
358 start_object: VarInt,
359 end_group: VarInt,
360 end_object: VarInt,
361 ) -> Result<(VarInt, ControlMessage), EndpointError> {
362 self.require_active_or_err()?;
363 let req_id = self.request_ids.allocate()?;
364
365 let mut sm = FetchStateMachine::new();
366 sm.on_fetch_sent()?;
367 self.fetches.insert(req_id.into_inner(), sm);
368
369 let msg = ControlMessage::Fetch(Fetch {
370 request_id: req_id,
371 subscriber_priority,
372 group_order,
373 fetch_type: FetchType::Standalone,
374 fetch_payload: FetchPayload::Standalone {
375 track_namespace,
376 track_name,
377 start_group,
378 start_object,
379 end_group,
380 end_object,
381 },
382 parameters: vec![],
383 });
384 Ok((req_id, msg))
385 }
386
387 pub fn joining_fetch(
392 &mut self,
393 subscriber_priority: u8,
394 group_order: VarInt,
395 fetch_type: FetchType,
396 joining_subscribe_id: VarInt,
397 joining_start: VarInt,
398 ) -> Result<(VarInt, ControlMessage), EndpointError> {
399 self.require_active_or_err()?;
400 if !matches!(fetch_type, FetchType::RelativeJoining | FetchType::AbsoluteJoining) {
401 return Err(EndpointError::NotActive);
403 }
404 let req_id = self.request_ids.allocate()?;
405
406 let mut sm = FetchStateMachine::new();
407 sm.on_fetch_sent()?;
408 self.fetches.insert(req_id.into_inner(), sm);
409
410 let msg = ControlMessage::Fetch(Fetch {
411 request_id: req_id,
412 subscriber_priority,
413 group_order,
414 fetch_type,
415 fetch_payload: FetchPayload::Joining { joining_subscribe_id, joining_start },
416 parameters: vec![],
417 });
418 Ok((req_id, msg))
419 }
420
421 pub fn receive_fetch_ok(&mut self, msg: &message::FetchOk) -> Result<(), EndpointError> {
423 let id = msg.request_id.into_inner();
424 let sm = self.fetches.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
425 sm.on_fetch_ok()?;
426 Ok(())
427 }
428
429 pub fn receive_fetch_error(&mut self, msg: &message::FetchError) -> Result<(), EndpointError> {
431 let id = msg.request_id.into_inner();
432 let sm = self.fetches.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
433 sm.on_fetch_error()?;
434 Ok(())
435 }
436
437 pub fn fetch_cancel(&mut self, request_id: VarInt) -> Result<ControlMessage, EndpointError> {
439 let id = request_id.into_inner();
440 let sm = self.fetches.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
441 sm.on_fetch_cancel()?;
442 Ok(ControlMessage::FetchCancel(FetchCancel { request_id }))
443 }
444
445 pub fn on_fetch_stream_fin(&mut self, request_id: VarInt) -> Result<(), EndpointError> {
447 let id = request_id.into_inner();
448 let sm = self.fetches.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
449 sm.on_stream_fin()?;
450 Ok(())
451 }
452
453 pub fn on_fetch_stream_reset(&mut self, request_id: VarInt) -> Result<(), EndpointError> {
455 let id = request_id.into_inner();
456 let sm = self.fetches.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
457 sm.on_stream_reset()?;
458 Ok(())
459 }
460
461 pub fn subscribe_announces(
466 &mut self,
467 track_namespace_prefix: TrackNamespace,
468 ) -> Result<(VarInt, ControlMessage), EndpointError> {
469 self.require_active_or_err()?;
470 let req_id = self.request_ids.allocate()?;
471 let key = track_namespace_prefix.0.clone();
472 let mut sm = SubscribeAnnouncesStateMachine::new();
473 sm.on_subscribe_announces_sent()?;
474 self.subscribe_announces.insert(req_id.into_inner(), sm);
475 self.subscribe_announces_ids.insert(key, req_id.into_inner());
476 Ok((
477 req_id,
478 ControlMessage::SubscribeAnnounces(SubscribeAnnounces {
479 request_id: req_id,
480 track_namespace_prefix,
481 parameters: vec![],
482 }),
483 ))
484 }
485
486 pub fn receive_subscribe_announces_ok(
488 &mut self,
489 msg: &SubscribeAnnouncesOk,
490 ) -> Result<(), EndpointError> {
491 let id = msg.request_id.into_inner();
492 let sm = self.subscribe_announces.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
493 sm.on_subscribe_announces_ok()?;
494 Ok(())
495 }
496
497 pub fn receive_subscribe_announces_error(
499 &mut self,
500 msg: &SubscribeAnnouncesError,
501 ) -> Result<(), EndpointError> {
502 let id = msg.request_id.into_inner();
503 let sm = self.subscribe_announces.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
504 sm.on_subscribe_announces_error()?;
505 Ok(())
506 }
507
508 pub fn unsubscribe_announces(
510 &mut self,
511 track_namespace_prefix: TrackNamespace,
512 ) -> Result<ControlMessage, EndpointError> {
513 let id = *self
514 .subscribe_announces_ids
515 .get(&track_namespace_prefix.0)
516 .ok_or(EndpointError::UnknownNamespace)?;
517 let sm = self.subscribe_announces.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
518 sm.on_unsubscribe_announces()?;
519 Ok(ControlMessage::UnsubscribeAnnounces(UnsubscribeAnnounces { track_namespace_prefix }))
520 }
521
522 pub fn announce(
527 &mut self,
528 track_namespace: TrackNamespace,
529 ) -> Result<(VarInt, ControlMessage), EndpointError> {
530 self.require_active_or_err()?;
531 let req_id = self.request_ids.allocate()?;
532 let key = track_namespace.0.clone();
533 let mut sm = AnnounceStateMachine::new();
534 sm.on_announce_sent()?;
535 self.announces.insert(req_id.into_inner(), sm);
536 self.announce_ids.insert(key, req_id.into_inner());
537 Ok((
538 req_id,
539 ControlMessage::Announce(Announce {
540 request_id: req_id,
541 track_namespace,
542 parameters: vec![],
543 }),
544 ))
545 }
546
547 pub fn receive_announce_ok(&mut self, msg: &AnnounceOk) -> Result<(), EndpointError> {
549 let id = msg.request_id.into_inner();
550 let sm = self.announces.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
551 sm.on_announce_ok()?;
552 Ok(())
553 }
554
555 pub fn receive_announce_error(&mut self, msg: &AnnounceError) -> Result<(), EndpointError> {
557 let id = msg.request_id.into_inner();
558 let sm = self.announces.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
559 sm.on_announce_error()?;
560 Ok(())
561 }
562
563 pub fn receive_announce_cancel(&mut self, msg: &AnnounceCancel) -> Result<(), EndpointError> {
565 let id = *self
566 .announce_ids
567 .get(&msg.track_namespace.0)
568 .ok_or(EndpointError::UnknownNamespace)?;
569 let sm = self.announces.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
570 sm.on_announce_cancel()?;
571 Ok(())
572 }
573
574 pub fn unannounce(
576 &mut self,
577 track_namespace: TrackNamespace,
578 ) -> Result<ControlMessage, EndpointError> {
579 let id =
580 *self.announce_ids.get(&track_namespace.0).ok_or(EndpointError::UnknownNamespace)?;
581 let sm = self.announces.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
582 sm.on_unannounce()?;
583 Ok(ControlMessage::Unannounce(Unannounce { track_namespace }))
584 }
585
586 pub fn track_status_request(
591 &mut self,
592 track_namespace: TrackNamespace,
593 track_name: Vec<u8>,
594 ) -> Result<(VarInt, ControlMessage), EndpointError> {
595 self.require_active_or_err()?;
596 let req_id = self.request_ids.allocate()?;
597 let mut sm = TrackStatusStateMachine::new();
598 sm.on_track_status_request_sent()?;
599 self.track_statuses.insert(req_id.into_inner(), sm);
600 Ok((
601 req_id,
602 ControlMessage::TrackStatusRequest(TrackStatusRequest {
603 request_id: req_id,
604 track_namespace,
605 track_name,
606 parameters: vec![],
607 }),
608 ))
609 }
610
611 pub fn receive_track_status(&mut self, msg: &TrackStatus) -> Result<(), EndpointError> {
613 let id = msg.request_id.into_inner();
614 let sm = self.track_statuses.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
615 sm.on_track_status()?;
616 Ok(())
617 }
618
619 pub fn receive_requests_blocked(&mut self, msg: &RequestsBlocked) -> Result<(), EndpointError> {
628 self.peer_reported_max_request_id = Some(msg.maximum_request_id);
629 Ok(())
630 }
631
632 pub fn peer_reported_max_request_id(&self) -> Option<VarInt> {
635 self.peer_reported_max_request_id
636 }
637
638 pub fn receive_message(&mut self, msg: ControlMessage) -> Result<(), EndpointError> {
642 match msg {
643 ControlMessage::GoAway(ref m) => self.receive_goaway(m),
644 ControlMessage::MaxRequestId(ref m) => self.receive_max_request_id(m),
645 ControlMessage::RequestsBlocked(ref m) => self.receive_requests_blocked(m),
646 ControlMessage::SubscribeOk(ref m) => self.receive_subscribe_ok(m),
647 ControlMessage::SubscribeError(ref m) => self.receive_subscribe_error(m),
648 ControlMessage::SubscribeUpdate(ref m) => self.receive_subscribe_update(m),
649 ControlMessage::SubscribeDone(ref m) => self.receive_subscribe_done(m),
650 ControlMessage::FetchOk(ref m) => self.receive_fetch_ok(m),
651 ControlMessage::FetchError(ref m) => self.receive_fetch_error(m),
652 ControlMessage::SubscribeAnnouncesOk(ref m) => self.receive_subscribe_announces_ok(m),
653 ControlMessage::SubscribeAnnouncesError(ref m) => {
654 self.receive_subscribe_announces_error(m)
655 }
656 ControlMessage::AnnounceOk(ref m) => self.receive_announce_ok(m),
657 ControlMessage::AnnounceError(ref m) => self.receive_announce_error(m),
658 ControlMessage::AnnounceCancel(ref m) => self.receive_announce_cancel(m),
659 ControlMessage::TrackStatus(ref m) => self.receive_track_status(m),
660 _ => Ok(()),
661 }
662 }
663}