1use std::collections::HashMap;
2
3use crate::draft15::fetch::{FetchError, FetchStateMachine};
4use crate::draft15::namespace::{
5 NamespaceError, PublishNamespaceStateMachine, SubscribeNamespaceStateMachine,
6};
7use crate::draft15::publish::{PublishError as PublishFlowError, PublishStateMachine};
8use crate::draft15::session::request_id::{RequestIdAllocator, RequestIdError, Role};
9use crate::draft15::session::setup::{self, SetupError};
10use crate::draft15::session::state::{SessionError, SessionState, SessionStateMachine};
11use crate::draft15::subscription::{SubscriptionError, SubscriptionStateMachine};
12use crate::draft15::track_status::{TrackStatusError, TrackStatusStateMachine};
13use moqtap_codec::draft15::message::{
14 self, ClientSetup, ControlMessage, Fetch, FetchCancel, FetchPayload, FetchType, GoAway,
15 MaxRequestId, PublishDone, PublishNamespace, PublishNamespaceCancel, PublishNamespaceDone,
16 RequestError, RequestOk, RequestsBlocked, ServerSetup, Subscribe, SubscribeNamespace,
17 SubscribeOk, SubscribeUpdate, Unsubscribe, UnsubscribeNamespace,
18};
19use moqtap_codec::kvp::{KeyValuePair, KvpValue};
20use moqtap_codec::types::*;
21use moqtap_codec::varint::VarInt;
22
23#[derive(Debug, thiserror::Error)]
25pub enum EndpointError {
26 #[error("session error: {0}")]
28 Session(#[from] SessionError),
29 #[error("request ID error: {0}")]
31 RequestId(#[from] RequestIdError),
32 #[error("subscription error: {0}")]
34 Subscription(#[from] SubscriptionError),
35 #[error("fetch error: {0}")]
37 Fetch(#[from] FetchError),
38 #[error("namespace error: {0}")]
40 Namespace(#[from] NamespaceError),
41 #[error("track status error: {0}")]
43 TrackStatus(#[from] TrackStatusError),
44 #[error("publish flow error: {0}")]
46 PublishFlow(#[from] PublishFlowError),
47 #[error("setup error: {0}")]
49 Setup(#[from] SetupError),
50 #[error("unknown request ID: {0}")]
52 UnknownRequest(u64),
53 #[error("session not active")]
55 NotActive,
56 #[error("session is draining, no new requests allowed")]
58 Draining,
59}
60
61pub struct Endpoint {
64 role: Role,
65 session: SessionStateMachine,
66 request_ids: RequestIdAllocator,
67 advertised_max_id: u64,
69 subscriptions: HashMap<u64, SubscriptionStateMachine>,
70 fetches: HashMap<u64, FetchStateMachine>,
71 subscribe_namespaces: HashMap<u64, SubscribeNamespaceStateMachine>,
72 publish_namespaces: HashMap<u64, PublishNamespaceStateMachine>,
73 publish_namespace_namespaces: HashMap<u64, TrackNamespace>,
77 track_statuses: HashMap<u64, TrackStatusStateMachine>,
78 publishes: HashMap<u64, PublishStateMachine>,
79 goaway_uri: Option<Vec<u8>>,
80}
81
82impl Endpoint {
83 pub fn new(role: Role) -> Self {
85 Self {
86 role,
87 session: SessionStateMachine::new(),
88 request_ids: RequestIdAllocator::new(role),
89 advertised_max_id: 0,
90 subscriptions: HashMap::new(),
91 fetches: HashMap::new(),
92 subscribe_namespaces: HashMap::new(),
93 publish_namespaces: HashMap::new(),
94 publish_namespace_namespaces: HashMap::new(),
95 track_statuses: HashMap::new(),
96 publishes: HashMap::new(),
97 goaway_uri: None,
98 }
99 }
100
101 pub fn role(&self) -> Role {
105 self.role
106 }
107
108 pub fn session_state(&self) -> SessionState {
110 self.session.state()
111 }
112
113 pub fn goaway_uri(&self) -> Option<&[u8]> {
115 self.goaway_uri.as_deref()
116 }
117
118 pub fn is_blocked(&self) -> bool {
120 self.request_ids.is_blocked()
121 }
122
123 pub fn active_subscription_count(&self) -> usize {
125 self.subscriptions.len()
126 }
127
128 pub fn active_fetch_count(&self) -> usize {
130 self.fetches.len()
131 }
132
133 pub fn active_subscribe_namespace_count(&self) -> usize {
135 self.subscribe_namespaces.len()
136 }
137
138 pub fn active_publish_namespace_count(&self) -> usize {
140 self.publish_namespaces.len()
141 }
142
143 pub fn active_track_status_count(&self) -> usize {
145 self.track_statuses.len()
146 }
147
148 pub fn active_publish_count(&self) -> usize {
150 self.publishes.len()
151 }
152
153 pub fn connect(&mut self) -> Result<(), EndpointError> {
157 self.session.on_connect()?;
158 Ok(())
159 }
160
161 pub fn close(&mut self) -> Result<(), EndpointError> {
163 self.session.on_close()?;
164 Ok(())
165 }
166
167 pub fn send_client_setup(
172 &mut self,
173 parameters: Vec<KeyValuePair>,
174 ) -> Result<ControlMessage, EndpointError> {
175 let msg = ClientSetup { parameters };
176 setup::validate_client_setup(&msg)?;
177 Ok(ControlMessage::ClientSetup(msg))
178 }
179
180 pub fn receive_server_setup(&mut self, msg: &ServerSetup) -> Result<(), EndpointError> {
184 setup::validate_server_setup(msg)?;
185 self.session.on_setup_complete()?;
186 for param in &msg.parameters {
188 if param.key == VarInt::from_u64(0x02).unwrap() {
189 if let KvpValue::Varint(v) = ¶m.value {
190 self.request_ids.update_max(v.into_inner())?;
191 }
192 }
193 }
194 Ok(())
195 }
196
197 pub fn receive_client_setup_and_respond(
203 &mut self,
204 client_setup: &ClientSetup,
205 ) -> Result<ControlMessage, EndpointError> {
206 setup::validate_client_setup(client_setup)?;
207 self.session.on_setup_complete()?;
208 let msg = ServerSetup { parameters: vec![] };
209 Ok(ControlMessage::ServerSetup(msg))
210 }
211
212 pub fn receive_max_request_id(&mut self, msg: &MaxRequestId) -> Result<(), EndpointError> {
216 self.request_ids.update_max(msg.request_id.into_inner())?;
217 Ok(())
218 }
219
220 pub fn send_max_request_id(&mut self, max_id: VarInt) -> Result<ControlMessage, EndpointError> {
223 let new_val = max_id.into_inner();
224 if new_val <= self.advertised_max_id && self.advertised_max_id > 0 {
225 return Err(EndpointError::RequestId(RequestIdError::Decreased(
226 self.advertised_max_id,
227 new_val,
228 )));
229 }
230 self.advertised_max_id = new_val;
231 Ok(ControlMessage::MaxRequestId(MaxRequestId { request_id: max_id }))
232 }
233
234 pub fn send_requests_blocked(&self) -> Result<ControlMessage, EndpointError> {
238 let max_id = self.request_ids.max_id();
239 Ok(ControlMessage::RequestsBlocked(RequestsBlocked {
240 maximum_request_id: VarInt::from_u64(max_id).unwrap(),
241 }))
242 }
243
244 pub fn receive_requests_blocked(&self, _msg: &RequestsBlocked) -> Result<(), EndpointError> {
246 Ok(())
247 }
248
249 pub fn receive_goaway(&mut self, msg: &GoAway) -> Result<(), EndpointError> {
253 self.session.on_goaway()?;
254 self.goaway_uri = Some(msg.new_session_uri.clone());
255 Ok(())
256 }
257
258 fn require_active_or_err(&self) -> Result<(), EndpointError> {
261 match self.session.state() {
262 SessionState::Active => Ok(()),
263 SessionState::Draining => Err(EndpointError::Draining),
264 _ => Err(EndpointError::NotActive),
265 }
266 }
267
268 pub fn subscribe(
272 &mut self,
273 track_namespace: TrackNamespace,
274 track_name: Vec<u8>,
275 parameters: Vec<KeyValuePair>,
276 ) -> Result<(VarInt, ControlMessage), EndpointError> {
277 self.require_active_or_err()?;
278 let req_id = self.request_ids.allocate()?;
279
280 let mut sm = SubscriptionStateMachine::new();
281 sm.on_subscribe_sent()?;
282 self.subscriptions.insert(req_id.into_inner(), sm);
283
284 let msg = ControlMessage::Subscribe(Subscribe {
285 request_id: req_id,
286 track_namespace,
287 track_name,
288 parameters,
289 });
290 Ok((req_id, msg))
291 }
292
293 pub fn receive_subscribe_ok(&mut self, msg: &SubscribeOk) -> Result<(), EndpointError> {
296 let id = msg.request_id.into_inner();
297 let sm = self.subscriptions.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
298 sm.on_subscribe_ok()?;
299 Ok(())
300 }
301
302 pub fn unsubscribe(&mut self, request_id: VarInt) -> Result<ControlMessage, EndpointError> {
304 let id = request_id.into_inner();
305 let sm = self.subscriptions.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
306 sm.on_unsubscribe()?;
307 Ok(ControlMessage::Unsubscribe(Unsubscribe { request_id }))
308 }
309
310 pub fn receive_subscribe_update(&mut self, msg: &SubscribeUpdate) -> Result<(), EndpointError> {
313 let id = msg.subscription_request_id.into_inner();
314 let sm = self.subscriptions.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
315 sm.on_subscribe_update()?;
316 Ok(())
317 }
318
319 pub fn receive_publish_done(&mut self, msg: &PublishDone) -> Result<(), EndpointError> {
323 let id = msg.request_id.into_inner();
324 let sm = self.subscriptions.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
325 sm.on_publish_done()?;
326 Ok(())
327 }
328
329 pub fn fetch(
334 &mut self,
335 track_namespace: TrackNamespace,
336 track_name: Vec<u8>,
337 start_group: VarInt,
338 start_object: VarInt,
339 end_group: VarInt,
340 end_object: VarInt,
341 ) -> Result<(VarInt, ControlMessage), EndpointError> {
342 self.require_active_or_err()?;
343 let req_id = self.request_ids.allocate()?;
344
345 let mut sm = FetchStateMachine::new();
346 sm.on_fetch_sent()?;
347 self.fetches.insert(req_id.into_inner(), sm);
348
349 let msg = ControlMessage::Fetch(Fetch {
350 request_id: req_id,
351 fetch_type: FetchType::Standalone,
352 fetch_payload: FetchPayload::Standalone {
353 track_namespace,
354 track_name,
355 start_group,
356 start_object,
357 end_group,
358 end_object,
359 },
360 parameters: vec![],
361 });
362 Ok((req_id, msg))
363 }
364
365 pub fn joining_fetch(
367 &mut self,
368 joining_request_id: VarInt,
369 joining_start: VarInt,
370 ) -> Result<(VarInt, ControlMessage), EndpointError> {
371 self.require_active_or_err()?;
372 let req_id = self.request_ids.allocate()?;
373
374 let mut sm = FetchStateMachine::new();
375 sm.on_fetch_sent()?;
376 self.fetches.insert(req_id.into_inner(), sm);
377
378 let msg = ControlMessage::Fetch(Fetch {
379 request_id: req_id,
380 fetch_type: FetchType::RelativeJoining,
381 fetch_payload: FetchPayload::Joining { joining_request_id, joining_start },
382 parameters: vec![],
383 });
384 Ok((req_id, msg))
385 }
386
387 pub fn receive_fetch_ok(&mut self, msg: &message::FetchOk) -> Result<(), EndpointError> {
390 let id = msg.request_id.into_inner();
391 let sm = self.fetches.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
392 sm.on_fetch_ok()?;
393 Ok(())
394 }
395
396 pub fn fetch_cancel(&mut self, request_id: VarInt) -> Result<ControlMessage, EndpointError> {
398 let id = request_id.into_inner();
399 let sm = self.fetches.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
400 sm.on_fetch_cancel()?;
401 Ok(ControlMessage::FetchCancel(FetchCancel { request_id }))
402 }
403
404 pub fn on_fetch_stream_fin(&mut self, request_id: VarInt) -> Result<(), EndpointError> {
406 let id = request_id.into_inner();
407 let sm = self.fetches.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
408 sm.on_stream_fin()?;
409 Ok(())
410 }
411
412 pub fn on_fetch_stream_reset(&mut self, request_id: VarInt) -> Result<(), EndpointError> {
414 let id = request_id.into_inner();
415 let sm = self.fetches.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
416 sm.on_stream_reset()?;
417 Ok(())
418 }
419
420 pub fn subscribe_namespace(
424 &mut self,
425 namespace_prefix: TrackNamespace,
426 parameters: Vec<KeyValuePair>,
427 ) -> Result<(VarInt, ControlMessage), EndpointError> {
428 self.require_active_or_err()?;
429 let req_id = self.request_ids.allocate()?;
430
431 let mut sm = SubscribeNamespaceStateMachine::new();
432 sm.on_subscribe_namespace_sent()?;
433 self.subscribe_namespaces.insert(req_id.into_inner(), sm);
434
435 let msg = ControlMessage::SubscribeNamespace(SubscribeNamespace {
436 request_id: req_id,
437 namespace_prefix,
438 parameters,
439 });
440 Ok((req_id, msg))
441 }
442
443 pub fn unsubscribe_namespace(
445 &mut self,
446 request_id: VarInt,
447 ) -> Result<ControlMessage, EndpointError> {
448 let id = request_id.into_inner();
449 let sm = self.subscribe_namespaces.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
450 sm.on_unsubscribe_namespace()?;
451 Ok(ControlMessage::UnsubscribeNamespace(UnsubscribeNamespace { request_id }))
452 }
453
454 pub fn publish_namespace(
458 &mut self,
459 track_namespace: TrackNamespace,
460 parameters: Vec<KeyValuePair>,
461 ) -> Result<(VarInt, ControlMessage), EndpointError> {
462 self.require_active_or_err()?;
463 let req_id = self.request_ids.allocate()?;
464
465 let mut sm = PublishNamespaceStateMachine::new();
466 sm.on_publish_namespace_sent()?;
467 self.publish_namespaces.insert(req_id.into_inner(), sm);
468 self.publish_namespace_namespaces.insert(req_id.into_inner(), track_namespace.clone());
469
470 let msg = ControlMessage::PublishNamespace(PublishNamespace {
471 request_id: req_id,
472 track_namespace,
473 parameters,
474 });
475 Ok((req_id, msg))
476 }
477
478 pub fn receive_publish_namespace_done(
481 &mut self,
482 msg: &PublishNamespaceDone,
483 ) -> Result<(), EndpointError> {
484 for (id, sm) in self.publish_namespaces.iter_mut() {
485 if let Some(ns) = self.publish_namespace_namespaces.get(id) {
486 if *ns == msg.track_namespace {
487 sm.on_publish_namespace_done()?;
488 return Ok(());
489 }
490 }
491 }
492 Ok(())
494 }
495
496 pub fn publish_namespace_cancel(
499 &mut self,
500 track_namespace: TrackNamespace,
501 error_code: VarInt,
502 reason_phrase: Vec<u8>,
503 ) -> Result<ControlMessage, EndpointError> {
504 let req_id = self
506 .publish_namespace_namespaces
507 .iter()
508 .find(|(_, ns)| **ns == track_namespace)
509 .map(|(id, _)| *id)
510 .ok_or(EndpointError::UnknownRequest(0))?;
511 let sm = self
512 .publish_namespaces
513 .get_mut(&req_id)
514 .ok_or(EndpointError::UnknownRequest(req_id))?;
515 sm.on_publish_namespace_cancel()?;
516 Ok(ControlMessage::PublishNamespaceCancel(PublishNamespaceCancel {
517 track_namespace,
518 error_code,
519 reason_phrase,
520 }))
521 }
522
523 pub fn track_status(
528 &mut self,
529 track_namespace: TrackNamespace,
530 track_name: Vec<u8>,
531 parameters: Vec<KeyValuePair>,
532 ) -> Result<(VarInt, ControlMessage), EndpointError> {
533 self.require_active_or_err()?;
534 let req_id = self.request_ids.allocate()?;
535 let mut sm = TrackStatusStateMachine::new();
536 sm.on_track_status_sent()?;
537 self.track_statuses.insert(req_id.into_inner(), sm);
538 let msg = ControlMessage::TrackStatus(message::TrackStatus {
539 request_id: req_id,
540 track_namespace,
541 track_name,
542 parameters,
543 });
544 Ok((req_id, msg))
545 }
546
547 pub fn publish(
552 &mut self,
553 track_namespace: TrackNamespace,
554 track_name: Vec<u8>,
555 track_alias: VarInt,
556 parameters: Vec<KeyValuePair>,
557 ) -> Result<(VarInt, ControlMessage), EndpointError> {
558 self.require_active_or_err()?;
559 let req_id = self.request_ids.allocate()?;
560 let mut sm = PublishStateMachine::new();
561 sm.on_publish_sent()?;
562 self.publishes.insert(req_id.into_inner(), sm);
563 let msg = ControlMessage::Publish(message::Publish {
564 request_id: req_id,
565 track_namespace,
566 track_name,
567 track_alias,
568 parameters,
569 });
570 Ok((req_id, msg))
571 }
572
573 pub fn receive_publish_ok(&mut self, msg: &message::PublishOk) -> Result<(), EndpointError> {
576 let id = msg.request_id.into_inner();
577 let sm = self.publishes.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
578 sm.on_publish_ok()?;
579 Ok(())
580 }
581
582 pub fn send_publish_done(
585 &mut self,
586 request_id: VarInt,
587 status_code: VarInt,
588 stream_count: VarInt,
589 reason_phrase: Vec<u8>,
590 ) -> Result<ControlMessage, EndpointError> {
591 let id = request_id.into_inner();
592 let sm = self.publishes.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
593 sm.on_publish_done_sent()?;
594 Ok(ControlMessage::PublishDone(PublishDone {
595 request_id,
596 status_code,
597 stream_count,
598 reason_phrase,
599 }))
600 }
601
602 pub fn receive_request_ok(&mut self, msg: &RequestOk) -> Result<(), EndpointError> {
607 let id = msg.request_id.into_inner();
608 if let Some(sm) = self.subscribe_namespaces.get_mut(&id) {
610 sm.on_subscribe_namespace_ok()?;
611 return Ok(());
612 }
613 if let Some(sm) = self.publish_namespaces.get_mut(&id) {
615 sm.on_publish_namespace_ok()?;
616 return Ok(());
617 }
618 if let Some(sm) = self.track_statuses.get_mut(&id) {
620 sm.on_track_status_ok()?;
621 return Ok(());
622 }
623 Err(EndpointError::UnknownRequest(id))
624 }
625
626 pub fn receive_request_error(&mut self, msg: &RequestError) -> Result<(), EndpointError> {
628 let id = msg.request_id.into_inner();
629 if let Some(sm) = self.subscriptions.get_mut(&id) {
631 sm.on_subscribe_error()?;
632 return Ok(());
633 }
634 if let Some(sm) = self.fetches.get_mut(&id) {
636 sm.on_fetch_error()?;
637 return Ok(());
638 }
639 if let Some(sm) = self.publishes.get_mut(&id) {
641 sm.on_publish_error()?;
642 return Ok(());
643 }
644 if let Some(sm) = self.subscribe_namespaces.get_mut(&id) {
646 sm.on_subscribe_namespace_error()?;
647 return Ok(());
648 }
649 if let Some(sm) = self.publish_namespaces.get_mut(&id) {
651 sm.on_publish_namespace_error()?;
652 return Ok(());
653 }
654 if let Some(sm) = self.track_statuses.get_mut(&id) {
656 sm.on_track_status_error()?;
657 return Ok(());
658 }
659 Ok(())
661 }
662
663 pub fn receive_message(&mut self, msg: ControlMessage) -> Result<(), EndpointError> {
667 match msg {
668 ControlMessage::GoAway(ref m) => self.receive_goaway(m),
669 ControlMessage::MaxRequestId(ref m) => self.receive_max_request_id(m),
670 ControlMessage::RequestsBlocked(ref m) => self.receive_requests_blocked(m),
671 ControlMessage::SubscribeOk(ref m) => self.receive_subscribe_ok(m),
672 ControlMessage::SubscribeUpdate(ref m) => self.receive_subscribe_update(m),
673 ControlMessage::PublishDone(ref m) => self.receive_publish_done(m),
674 ControlMessage::PublishOk(ref m) => self.receive_publish_ok(m),
675 ControlMessage::FetchOk(ref m) => self.receive_fetch_ok(m),
676 ControlMessage::RequestOk(ref m) => self.receive_request_ok(m),
677 ControlMessage::RequestError(ref m) => self.receive_request_error(m),
678 ControlMessage::PublishNamespaceDone(ref m) => self.receive_publish_namespace_done(m),
679 _ => Ok(()),
680 }
681 }
682}