1use std::collections::HashMap;
2
3use crate::draft14::fetch::{FetchError, FetchStateMachine};
4use crate::draft14::namespace::{
5 NamespaceError, PublishNamespaceStateMachine, SubscribeNamespaceStateMachine,
6};
7use crate::draft14::publish::{PublishError as PublishFlowError, PublishStateMachine};
8use crate::draft14::session::request_id::{RequestIdAllocator, RequestIdError, Role};
9use crate::draft14::session::setup::{self, SetupError};
10use crate::draft14::session::state::{SessionError, SessionState, SessionStateMachine};
11use crate::draft14::subscription::{SubscriptionError, SubscriptionStateMachine};
12use crate::draft14::track_status::{TrackStatusError, TrackStatusStateMachine};
13use moqtap_codec::draft14::message::{
14 self, ClientSetup, ControlMessage, Fetch, FetchCancel, GoAway, MaxRequestId, PublishDone,
15 PublishNamespace, PublishNamespaceCancel, PublishNamespaceDone, PublishNamespaceError,
16 PublishNamespaceOk, RequestsBlocked, ServerSetup, Subscribe, SubscribeError,
17 SubscribeNamespace, SubscribeNamespaceError, SubscribeNamespaceOk, SubscribeOk,
18 SubscribeUpdate, Unsubscribe, UnsubscribeNamespace,
19};
20use moqtap_codec::kvp::{KeyValuePair, KvpValue};
21use moqtap_codec::types::*;
22use moqtap_codec::varint::VarInt;
23
24#[derive(Debug, thiserror::Error)]
26pub enum EndpointError {
27 #[error("session error: {0}")]
29 Session(#[from] SessionError),
30 #[error("request ID error: {0}")]
32 RequestId(#[from] RequestIdError),
33 #[error("subscription error: {0}")]
35 Subscription(#[from] SubscriptionError),
36 #[error("fetch error: {0}")]
38 Fetch(#[from] FetchError),
39 #[error("namespace error: {0}")]
41 Namespace(#[from] NamespaceError),
42 #[error("track status error: {0}")]
44 TrackStatus(#[from] TrackStatusError),
45 #[error("publish flow error: {0}")]
47 PublishFlow(#[from] PublishFlowError),
48 #[error("setup error: {0}")]
50 Setup(#[from] SetupError),
51 #[error("unknown request ID: {0}")]
53 UnknownRequest(u64),
54 #[error("session not active")]
56 NotActive,
57 #[error("session is draining, no new requests allowed")]
59 Draining,
60}
61
62pub struct Endpoint {
65 role: Role,
66 session: SessionStateMachine,
67 request_ids: RequestIdAllocator,
68 advertised_max_id: u64,
70 subscriptions: HashMap<u64, SubscriptionStateMachine>,
71 fetches: HashMap<u64, FetchStateMachine>,
72 subscribe_namespaces: HashMap<u64, SubscribeNamespaceStateMachine>,
73 publish_namespaces: HashMap<u64, PublishNamespaceStateMachine>,
74 track_statuses: HashMap<u64, TrackStatusStateMachine>,
75 publishes: HashMap<u64, PublishStateMachine>,
76 negotiated_version: Option<VarInt>,
77 offered_versions: Vec<VarInt>,
78 goaway_uri: Option<Vec<u8>>,
79}
80
81impl Endpoint {
82 pub fn new(role: Role) -> Self {
84 Self {
85 role,
86 session: SessionStateMachine::new(),
87 request_ids: RequestIdAllocator::new(role),
88 advertised_max_id: 0,
89 subscriptions: HashMap::new(),
90 fetches: HashMap::new(),
91 subscribe_namespaces: HashMap::new(),
92 publish_namespaces: HashMap::new(),
93 track_statuses: HashMap::new(),
94 publishes: HashMap::new(),
95 negotiated_version: None,
96 offered_versions: Vec::new(),
97 goaway_uri: None,
98 }
99 }
100
101 pub fn role(&self) -> Role {
105 self.role
106 }
107
108 pub fn session_state(&self) -> SessionState {
110 self.session.state()
111 }
112
113 pub fn negotiated_version(&self) -> Option<VarInt> {
115 self.negotiated_version
116 }
117
118 pub fn goaway_uri(&self) -> Option<&[u8]> {
120 self.goaway_uri.as_deref()
121 }
122
123 pub fn is_blocked(&self) -> bool {
125 self.request_ids.is_blocked()
126 }
127
128 pub fn active_subscription_count(&self) -> usize {
130 self.subscriptions.len()
131 }
132
133 pub fn active_fetch_count(&self) -> usize {
135 self.fetches.len()
136 }
137
138 pub fn active_subscribe_namespace_count(&self) -> usize {
140 self.subscribe_namespaces.len()
141 }
142
143 pub fn active_publish_namespace_count(&self) -> usize {
145 self.publish_namespaces.len()
146 }
147
148 pub fn active_track_status_count(&self) -> usize {
150 self.track_statuses.len()
151 }
152
153 pub fn active_publish_count(&self) -> usize {
155 self.publishes.len()
156 }
157
158 pub fn connect(&mut self) -> Result<(), EndpointError> {
162 self.session.on_connect()?;
163 Ok(())
164 }
165
166 pub fn close(&mut self) -> Result<(), EndpointError> {
168 self.session.on_close()?;
169 Ok(())
170 }
171
172 pub fn send_client_setup(
176 &mut self,
177 versions: Vec<VarInt>,
178 parameters: Vec<KeyValuePair>,
179 ) -> Result<ControlMessage, EndpointError> {
180 self.offered_versions = versions.clone();
181 let msg = ClientSetup { supported_versions: versions, parameters };
182 setup::validate_client_setup(&msg)?;
183 Ok(ControlMessage::ClientSetup(msg))
184 }
185
186 pub fn receive_server_setup(&mut self, msg: &ServerSetup) -> Result<(), EndpointError> {
190 setup::validate_server_setup(msg)?;
191 let version = setup::negotiate_version(&self.offered_versions, msg.selected_version)?;
192 self.negotiated_version = Some(version);
193 self.session.on_setup_complete()?;
194 for param in &msg.parameters {
196 if param.key == VarInt::from_u64(0x02).unwrap() {
197 if let KvpValue::Varint(v) = ¶m.value {
198 self.request_ids.update_max(v.into_inner())?;
199 }
200 }
201 }
202 Ok(())
203 }
204
205 pub fn receive_client_setup_and_respond(
209 &mut self,
210 client_setup: &ClientSetup,
211 selected_version: VarInt,
212 ) -> Result<ControlMessage, EndpointError> {
213 setup::validate_client_setup(client_setup)?;
214 let version = setup::negotiate_version(&client_setup.supported_versions, selected_version)?;
215 self.negotiated_version = Some(version);
216 self.session.on_setup_complete()?;
217 let msg = ServerSetup { selected_version: version, parameters: vec![] };
218 Ok(ControlMessage::ServerSetup(msg))
219 }
220
221 pub fn receive_max_request_id(&mut self, msg: &MaxRequestId) -> Result<(), EndpointError> {
225 self.request_ids.update_max(msg.request_id.into_inner())?;
226 Ok(())
227 }
228
229 pub fn send_max_request_id(&mut self, max_id: VarInt) -> Result<ControlMessage, EndpointError> {
232 let new_val = max_id.into_inner();
233 if new_val <= self.advertised_max_id && self.advertised_max_id > 0 {
234 return Err(EndpointError::RequestId(RequestIdError::Decreased(
235 self.advertised_max_id,
236 new_val,
237 )));
238 }
239 self.advertised_max_id = new_val;
240 Ok(ControlMessage::MaxRequestId(MaxRequestId { request_id: max_id }))
241 }
242
243 pub fn send_requests_blocked(&self) -> Result<ControlMessage, EndpointError> {
247 let max_id = self.request_ids.max_id();
248 Ok(ControlMessage::RequestsBlocked(RequestsBlocked {
249 maximum_request_id: VarInt::from_u64(max_id).unwrap(),
250 }))
251 }
252
253 pub fn receive_requests_blocked(&self, _msg: &RequestsBlocked) -> Result<(), EndpointError> {
257 Ok(())
260 }
261
262 pub fn receive_goaway(&mut self, msg: &GoAway) -> Result<(), EndpointError> {
266 self.session.on_goaway()?;
267 self.goaway_uri = Some(msg.new_session_uri.clone());
268 Ok(())
269 }
270
271 fn require_active_or_err(&self) -> Result<(), EndpointError> {
274 match self.session.state() {
275 SessionState::Active => Ok(()),
276 SessionState::Draining => Err(EndpointError::Draining),
277 _ => Err(EndpointError::NotActive),
278 }
279 }
280
281 pub fn subscribe(
284 &mut self,
285 track_namespace: TrackNamespace,
286 track_name: Vec<u8>,
287 subscriber_priority: u8,
288 group_order: GroupOrder,
289 filter_type: FilterType,
290 ) -> Result<(VarInt, ControlMessage), EndpointError> {
291 self.require_active_or_err()?;
292 let req_id = self.request_ids.allocate()?;
293
294 let mut sm = SubscriptionStateMachine::new();
295 sm.on_subscribe_sent()?;
296 self.subscriptions.insert(req_id.into_inner(), sm);
297
298 let msg = ControlMessage::Subscribe(Subscribe {
299 request_id: req_id,
300 track_namespace,
301 track_name,
302 subscriber_priority,
303 group_order,
304 forward: Forward::Forward,
305 filter_type,
306 start_location: None,
307 end_group: None,
308 parameters: vec![],
309 });
310 Ok((req_id, msg))
311 }
312
313 pub fn receive_subscribe_ok(&mut self, msg: &SubscribeOk) -> Result<(), EndpointError> {
315 let id = msg.request_id.into_inner();
316 let sm = self.subscriptions.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
317 sm.on_subscribe_ok()?;
318 Ok(())
319 }
320
321 pub fn receive_subscribe_error(&mut self, msg: &SubscribeError) -> Result<(), EndpointError> {
323 let id = msg.request_id.into_inner();
324 let sm = self.subscriptions.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
325 sm.on_subscribe_error()?;
326 Ok(())
327 }
328
329 pub fn unsubscribe(&mut self, request_id: VarInt) -> Result<ControlMessage, EndpointError> {
331 let id = request_id.into_inner();
332 let sm = self.subscriptions.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
333 sm.on_unsubscribe()?;
334 Ok(ControlMessage::Unsubscribe(Unsubscribe { request_id }))
335 }
336
337 pub fn receive_subscribe_update(&mut self, msg: &SubscribeUpdate) -> Result<(), EndpointError> {
339 let id = msg.subscription_request_id.into_inner();
340 let sm = self.subscriptions.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
341 sm.on_subscribe_update()?;
342 Ok(())
343 }
344
345 pub fn receive_publish_done(&mut self, msg: &PublishDone) -> Result<(), EndpointError> {
347 let id = msg.request_id.into_inner();
348 let sm = self.subscriptions.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
349 sm.on_publish_done()?;
350 Ok(())
351 }
352
353 pub fn fetch(
357 &mut self,
358 track_namespace: TrackNamespace,
359 track_name: Vec<u8>,
360 start_group: VarInt,
361 start_object: VarInt,
362 ) -> Result<(VarInt, ControlMessage), EndpointError> {
363 self.require_active_or_err()?;
364 let req_id = self.request_ids.allocate()?;
365
366 let mut sm = FetchStateMachine::new();
367 sm.on_fetch_sent()?;
368 self.fetches.insert(req_id.into_inner(), sm);
369
370 let msg = ControlMessage::Fetch(Fetch {
371 request_id: req_id,
372 subscriber_priority: 128,
373 group_order: GroupOrder::Ascending,
374 fetch_type: message::FetchType::Standalone,
375 fetch_payload: message::FetchPayload::Standalone {
376 track_namespace,
377 track_name,
378 start_group,
379 start_object,
380 end_group: VarInt::from_u64(0).unwrap(),
381 end_object: VarInt::from_u64(0).unwrap(),
382 },
383 parameters: vec![],
384 });
385 Ok((req_id, msg))
386 }
387
388 pub fn receive_fetch_ok(&mut self, msg: &message::FetchOk) -> Result<(), EndpointError> {
390 let id = msg.request_id.into_inner();
391 let sm = self.fetches.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
392 sm.on_fetch_ok()?;
393 Ok(())
394 }
395
396 pub fn receive_fetch_error(&mut self, msg: &message::FetchError) -> Result<(), EndpointError> {
398 let id = msg.request_id.into_inner();
399 let sm = self.fetches.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
400 sm.on_fetch_error()?;
401 Ok(())
402 }
403
404 pub fn fetch_cancel(&mut self, request_id: VarInt) -> Result<ControlMessage, EndpointError> {
406 let id = request_id.into_inner();
407 let sm = self.fetches.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
408 sm.on_fetch_cancel()?;
409 Ok(ControlMessage::FetchCancel(FetchCancel { request_id }))
410 }
411
412 pub fn on_fetch_stream_fin(&mut self, request_id: VarInt) -> Result<(), EndpointError> {
414 let id = request_id.into_inner();
415 let sm = self.fetches.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
416 sm.on_stream_fin()?;
417 Ok(())
418 }
419
420 pub fn on_fetch_stream_reset(&mut self, request_id: VarInt) -> Result<(), EndpointError> {
422 let id = request_id.into_inner();
423 let sm = self.fetches.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
424 sm.on_stream_reset()?;
425 Ok(())
426 }
427
428 pub fn subscribe_namespace(
432 &mut self,
433 track_namespace: TrackNamespace,
434 ) -> Result<(VarInt, ControlMessage), EndpointError> {
435 self.require_active_or_err()?;
436 let req_id = self.request_ids.allocate()?;
437
438 let mut sm = SubscribeNamespaceStateMachine::new();
439 sm.on_subscribe_namespace_sent()?;
440 self.subscribe_namespaces.insert(req_id.into_inner(), sm);
441
442 let msg = ControlMessage::SubscribeNamespace(SubscribeNamespace {
443 request_id: req_id,
444 track_namespace,
445 parameters: vec![],
446 });
447 Ok((req_id, msg))
448 }
449
450 pub fn receive_subscribe_namespace_ok(
452 &mut self,
453 msg: &SubscribeNamespaceOk,
454 ) -> Result<(), EndpointError> {
455 let id = msg.request_id.into_inner();
456 let sm = self.subscribe_namespaces.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
457 sm.on_subscribe_namespace_ok()?;
458 Ok(())
459 }
460
461 pub fn receive_subscribe_namespace_error(
463 &mut self,
464 msg: &SubscribeNamespaceError,
465 ) -> Result<(), EndpointError> {
466 let id = msg.request_id.into_inner();
467 let sm = self.subscribe_namespaces.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
468 sm.on_subscribe_namespace_error()?;
469 Ok(())
470 }
471
472 pub fn unsubscribe_namespace(
474 &mut self,
475 request_id: VarInt,
476 _track_namespace: TrackNamespace,
477 ) -> Result<ControlMessage, EndpointError> {
478 let id = request_id.into_inner();
479 let sm = self.subscribe_namespaces.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
480 sm.on_unsubscribe_namespace()?;
481 let _ = request_id;
482 Ok(ControlMessage::UnsubscribeNamespace(UnsubscribeNamespace {
483 track_namespace_prefix: _track_namespace,
484 }))
485 }
486
487 pub fn publish_namespace(
491 &mut self,
492 track_namespace: TrackNamespace,
493 ) -> Result<(VarInt, ControlMessage), EndpointError> {
494 self.require_active_or_err()?;
495 let req_id = self.request_ids.allocate()?;
496
497 let mut sm = PublishNamespaceStateMachine::new();
498 sm.on_publish_namespace_sent()?;
499 self.publish_namespaces.insert(req_id.into_inner(), sm);
500
501 let msg = ControlMessage::PublishNamespace(PublishNamespace {
502 request_id: req_id,
503 track_namespace,
504 parameters: vec![],
505 });
506 Ok((req_id, msg))
507 }
508
509 pub fn receive_publish_namespace_ok(
511 &mut self,
512 msg: &PublishNamespaceOk,
513 ) -> Result<(), EndpointError> {
514 let id = msg.request_id.into_inner();
515 let sm = self.publish_namespaces.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
516 sm.on_publish_namespace_ok()?;
517 Ok(())
518 }
519
520 pub fn receive_publish_namespace_error(
522 &mut self,
523 msg: &PublishNamespaceError,
524 ) -> Result<(), EndpointError> {
525 let id = msg.request_id.into_inner();
526 let sm = self.publish_namespaces.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
527 sm.on_publish_namespace_error()?;
528 Ok(())
529 }
530
531 pub fn receive_publish_namespace_done(
539 &mut self,
540 _msg: &PublishNamespaceDone,
541 ) -> Result<(), EndpointError> {
542 for sm in self.publish_namespaces.values_mut() {
543 let _ = sm.on_publish_namespace_done();
546 }
547 Ok(())
548 }
549
550 pub fn publish_namespace_cancel(
552 &mut self,
553 request_id: VarInt,
554 reason_phrase: Vec<u8>,
555 ) -> Result<ControlMessage, EndpointError> {
556 let id = request_id.into_inner();
557 let sm = self.publish_namespaces.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
558 sm.on_publish_namespace_cancel()?;
559 Ok(ControlMessage::PublishNamespaceCancel(PublishNamespaceCancel {
560 track_namespace: TrackNamespace(Vec::new()),
561 error_code: VarInt::from_u64(0).unwrap(),
562 reason_phrase,
563 }))
564 }
565
566 pub fn track_status(
570 &mut self,
571 track_namespace: TrackNamespace,
572 track_name: Vec<u8>,
573 ) -> Result<(VarInt, ControlMessage), EndpointError> {
574 self.require_active_or_err()?;
575 let req_id = self.request_ids.allocate()?;
576 let mut sm = TrackStatusStateMachine::new();
577 sm.on_track_status_sent()?;
578 self.track_statuses.insert(req_id.into_inner(), sm);
579 let msg = ControlMessage::TrackStatus(message::TrackStatus {
580 request_id: req_id,
581 track_namespace,
582 track_name,
583 subscriber_priority: 128,
584 group_order: GroupOrder::Ascending,
585 forward: Forward::Forward,
586 filter_type: FilterType::LargestObject,
587 start_location: None,
588 end_group: None,
589 parameters: vec![],
590 });
591 Ok((req_id, msg))
592 }
593
594 pub fn receive_track_status_ok(
596 &mut self,
597 msg: &message::TrackStatusOk,
598 ) -> Result<(), EndpointError> {
599 let id = msg.request_id.into_inner();
600 let sm = self.track_statuses.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
601 sm.on_track_status_ok()?;
602 Ok(())
603 }
604
605 pub fn receive_track_status_error(
607 &mut self,
608 msg: &message::TrackStatusError,
609 ) -> Result<(), EndpointError> {
610 let id = msg.request_id.into_inner();
611 let sm = self.track_statuses.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
612 sm.on_track_status_error()?;
613 Ok(())
614 }
615
616 pub fn publish(
620 &mut self,
621 track_namespace: TrackNamespace,
622 track_name: Vec<u8>,
623 forward: Forward,
624 ) -> Result<(VarInt, ControlMessage), EndpointError> {
625 self.require_active_or_err()?;
626 let req_id = self.request_ids.allocate()?;
627 let mut sm = PublishStateMachine::new();
628 sm.on_publish_sent()?;
629 self.publishes.insert(req_id.into_inner(), sm);
630 let msg = ControlMessage::Publish(message::Publish {
631 request_id: req_id,
632 track_namespace,
633 track_name,
634 track_alias: VarInt::from_u64(0).unwrap(),
635 group_order: GroupOrder::Ascending,
636 content_exists: ContentExists::NoLargestLocation,
637 largest_location: None,
638 forward,
639 parameters: vec![],
640 });
641 Ok((req_id, msg))
642 }
643
644 pub fn receive_publish_ok(&mut self, msg: &message::PublishOk) -> Result<(), EndpointError> {
646 let id = msg.request_id.into_inner();
647 let sm = self.publishes.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
648 sm.on_publish_ok()?;
649 Ok(())
650 }
651
652 pub fn send_publish_done(
654 &mut self,
655 request_id: VarInt,
656 status_code: VarInt,
657 reason_phrase: Vec<u8>,
658 ) -> Result<ControlMessage, EndpointError> {
659 let id = request_id.into_inner();
660 let sm = self.publishes.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
661 sm.on_publish_done_sent()?;
662 Ok(ControlMessage::PublishDone(PublishDone {
663 request_id,
664 status_code,
665 stream_count: VarInt::from_u64(0).unwrap(),
666 reason_phrase,
667 }))
668 }
669
670 pub fn send_publish_error(
675 &self,
676 request_id: VarInt,
677 error_code: VarInt,
678 reason_phrase: Vec<u8>,
679 ) -> Result<ControlMessage, EndpointError> {
680 Ok(ControlMessage::PublishError(message::PublishError {
681 request_id,
682 error_code,
683 reason_phrase,
684 }))
685 }
686
687 pub fn receive_publish_error(
691 &mut self,
692 msg: &message::PublishError,
693 ) -> Result<(), EndpointError> {
694 let id = msg.request_id.into_inner();
695 if let Some(sm) = self.publishes.get_mut(&id) {
697 sm.on_publish_error()?;
698 return Ok(());
699 }
700 if let Some(sm) = self.subscriptions.get_mut(&id) {
702 sm.on_subscribe_error()?;
703 }
704 Ok(())
705 }
706
707 pub fn receive_message(&mut self, msg: ControlMessage) -> Result<(), EndpointError> {
711 match msg {
712 ControlMessage::GoAway(ref m) => self.receive_goaway(m),
713 ControlMessage::MaxRequestId(ref m) => self.receive_max_request_id(m),
714 ControlMessage::RequestsBlocked(ref m) => self.receive_requests_blocked(m),
715 ControlMessage::SubscribeOk(ref m) => self.receive_subscribe_ok(m),
716 ControlMessage::SubscribeError(ref m) => self.receive_subscribe_error(m),
717 ControlMessage::SubscribeUpdate(ref m) => self.receive_subscribe_update(m),
718 ControlMessage::PublishDone(ref m) => self.receive_publish_done(m),
719 ControlMessage::PublishOk(ref m) => self.receive_publish_ok(m),
720 ControlMessage::PublishError(ref m) => self.receive_publish_error(m),
721 ControlMessage::FetchOk(ref m) => self.receive_fetch_ok(m),
722 ControlMessage::FetchError(ref m) => self.receive_fetch_error(m),
723 ControlMessage::SubscribeNamespaceOk(ref m) => self.receive_subscribe_namespace_ok(m),
724 ControlMessage::SubscribeNamespaceError(ref m) => {
725 self.receive_subscribe_namespace_error(m)
726 }
727 ControlMessage::PublishNamespaceOk(ref m) => self.receive_publish_namespace_ok(m),
728 ControlMessage::PublishNamespaceError(ref m) => self.receive_publish_namespace_error(m),
729 ControlMessage::PublishNamespaceDone(ref m) => self.receive_publish_namespace_done(m),
730 ControlMessage::TrackStatusOk(ref m) => self.receive_track_status_ok(m),
731 ControlMessage::TrackStatusError(ref m) => self.receive_track_status_error(m),
732 _ => Ok(()),
733 }
734 }
735}