1use std::collections::HashMap;
2
3use crate::draft16::fetch::{FetchError, FetchStateMachine};
4use crate::draft16::namespace::{
5 NamespaceError, PublishNamespaceStateMachine, SubscribeNamespaceStateMachine,
6};
7use crate::draft16::publish::{PublishError as PublishFlowError, PublishStateMachine};
8use crate::draft16::session::request_id::{RequestIdAllocator, RequestIdError, Role};
9use crate::draft16::session::setup::{self, SetupError};
10use crate::draft16::session::state::{SessionError, SessionState, SessionStateMachine};
11use crate::draft16::subscription::{SubscriptionError, SubscriptionStateMachine};
12use crate::draft16::track_status::{TrackStatusError, TrackStatusStateMachine};
13use moqtap_codec::draft16::message::{
14 self, ClientSetup, ControlMessage, Fetch, FetchCancel, FetchPayload, FetchType, GoAway,
15 MaxRequestId, PublishDone, PublishNamespace, PublishNamespaceCancel, PublishNamespaceDone,
16 RequestError, RequestOk, RequestUpdate, RequestsBlocked, ServerSetup, Subscribe,
17 SubscribeNamespace, SubscribeOk, Unsubscribe,
18};
19use moqtap_codec::kvp::{KeyValuePair, KvpValue};
20use moqtap_codec::types::*;
21use moqtap_codec::varint::VarInt;
22
23#[derive(Debug, thiserror::Error)]
25pub enum EndpointError {
26 #[error("session error: {0}")]
28 Session(#[from] SessionError),
29 #[error("request ID error: {0}")]
31 RequestId(#[from] RequestIdError),
32 #[error("subscription error: {0}")]
34 Subscription(#[from] SubscriptionError),
35 #[error("fetch error: {0}")]
37 Fetch(#[from] FetchError),
38 #[error("namespace error: {0}")]
40 Namespace(#[from] NamespaceError),
41 #[error("track status error: {0}")]
43 TrackStatus(#[from] TrackStatusError),
44 #[error("publish flow error: {0}")]
46 PublishFlow(#[from] PublishFlowError),
47 #[error("setup error: {0}")]
49 Setup(#[from] SetupError),
50 #[error("unknown request ID: {0}")]
52 UnknownRequest(u64),
53 #[error("session not active")]
55 NotActive,
56 #[error("session is draining, no new requests allowed")]
58 Draining,
59}
60
61pub struct Endpoint {
64 role: Role,
65 session: SessionStateMachine,
66 request_ids: RequestIdAllocator,
67 advertised_max_id: u64,
69 subscriptions: HashMap<u64, SubscriptionStateMachine>,
70 fetches: HashMap<u64, FetchStateMachine>,
71 subscribe_namespaces: HashMap<u64, SubscribeNamespaceStateMachine>,
72 publish_namespaces: HashMap<u64, PublishNamespaceStateMachine>,
73 track_statuses: HashMap<u64, TrackStatusStateMachine>,
74 publishes: HashMap<u64, PublishStateMachine>,
75 goaway_uri: Option<Vec<u8>>,
76}
77
78impl Endpoint {
79 pub fn new(role: Role) -> Self {
81 Self {
82 role,
83 session: SessionStateMachine::new(),
84 request_ids: RequestIdAllocator::new(role),
85 advertised_max_id: 0,
86 subscriptions: HashMap::new(),
87 fetches: HashMap::new(),
88 subscribe_namespaces: HashMap::new(),
89 publish_namespaces: HashMap::new(),
90 track_statuses: HashMap::new(),
91 publishes: HashMap::new(),
92 goaway_uri: None,
93 }
94 }
95
96 pub fn role(&self) -> Role {
100 self.role
101 }
102
103 pub fn session_state(&self) -> SessionState {
105 self.session.state()
106 }
107
108 pub fn goaway_uri(&self) -> Option<&[u8]> {
110 self.goaway_uri.as_deref()
111 }
112
113 pub fn is_blocked(&self) -> bool {
115 self.request_ids.is_blocked()
116 }
117
118 pub fn active_subscription_count(&self) -> usize {
120 self.subscriptions.len()
121 }
122
123 pub fn active_fetch_count(&self) -> usize {
125 self.fetches.len()
126 }
127
128 pub fn active_subscribe_namespace_count(&self) -> usize {
130 self.subscribe_namespaces.len()
131 }
132
133 pub fn active_publish_namespace_count(&self) -> usize {
135 self.publish_namespaces.len()
136 }
137
138 pub fn active_track_status_count(&self) -> usize {
140 self.track_statuses.len()
141 }
142
143 pub fn active_publish_count(&self) -> usize {
145 self.publishes.len()
146 }
147
148 pub fn connect(&mut self) -> Result<(), EndpointError> {
152 self.session.on_connect()?;
153 Ok(())
154 }
155
156 pub fn close(&mut self) -> Result<(), EndpointError> {
158 self.session.on_close()?;
159 Ok(())
160 }
161
162 pub fn send_client_setup(
167 &mut self,
168 parameters: Vec<KeyValuePair>,
169 ) -> Result<ControlMessage, EndpointError> {
170 let msg = ClientSetup { parameters };
171 setup::validate_client_setup(&msg)?;
172 Ok(ControlMessage::ClientSetup(msg))
173 }
174
175 pub fn receive_server_setup(&mut self, msg: &ServerSetup) -> Result<(), EndpointError> {
179 setup::validate_server_setup(msg)?;
180 self.session.on_setup_complete()?;
181 for param in &msg.parameters {
183 if param.key == VarInt::from_u64(0x02).unwrap() {
184 if let KvpValue::Varint(v) = ¶m.value {
185 self.request_ids.update_max(v.into_inner())?;
186 }
187 }
188 }
189 Ok(())
190 }
191
192 pub fn receive_client_setup_and_respond(
198 &mut self,
199 client_setup: &ClientSetup,
200 ) -> Result<ControlMessage, EndpointError> {
201 setup::validate_client_setup(client_setup)?;
202 self.session.on_setup_complete()?;
203 let msg = ServerSetup { parameters: vec![] };
204 Ok(ControlMessage::ServerSetup(msg))
205 }
206
207 pub fn receive_max_request_id(&mut self, msg: &MaxRequestId) -> Result<(), EndpointError> {
211 self.request_ids.update_max(msg.request_id.into_inner())?;
212 Ok(())
213 }
214
215 pub fn send_max_request_id(&mut self, max_id: VarInt) -> Result<ControlMessage, EndpointError> {
218 let new_val = max_id.into_inner();
219 if new_val <= self.advertised_max_id && self.advertised_max_id > 0 {
220 return Err(EndpointError::RequestId(RequestIdError::Decreased(
221 self.advertised_max_id,
222 new_val,
223 )));
224 }
225 self.advertised_max_id = new_val;
226 Ok(ControlMessage::MaxRequestId(MaxRequestId { request_id: max_id }))
227 }
228
229 pub fn send_requests_blocked(&self) -> Result<ControlMessage, EndpointError> {
233 let max_id = self.request_ids.max_id();
234 Ok(ControlMessage::RequestsBlocked(RequestsBlocked {
235 maximum_request_id: VarInt::from_u64(max_id).unwrap(),
236 }))
237 }
238
239 pub fn receive_requests_blocked(&self, _msg: &RequestsBlocked) -> Result<(), EndpointError> {
241 Ok(())
242 }
243
244 pub fn receive_goaway(&mut self, msg: &GoAway) -> Result<(), EndpointError> {
248 self.session.on_goaway()?;
249 self.goaway_uri = Some(msg.new_session_uri.clone());
250 Ok(())
251 }
252
253 fn require_active_or_err(&self) -> Result<(), EndpointError> {
256 match self.session.state() {
257 SessionState::Active => Ok(()),
258 SessionState::Draining => Err(EndpointError::Draining),
259 _ => Err(EndpointError::NotActive),
260 }
261 }
262
263 pub fn subscribe(
267 &mut self,
268 track_namespace: TrackNamespace,
269 track_name: Vec<u8>,
270 parameters: Vec<KeyValuePair>,
271 ) -> Result<(VarInt, ControlMessage), EndpointError> {
272 self.require_active_or_err()?;
273 let req_id = self.request_ids.allocate()?;
274
275 let mut sm = SubscriptionStateMachine::new();
276 sm.on_subscribe_sent()?;
277 self.subscriptions.insert(req_id.into_inner(), sm);
278
279 let msg = ControlMessage::Subscribe(Subscribe {
280 request_id: req_id,
281 track_namespace,
282 track_name,
283 parameters,
284 });
285 Ok((req_id, msg))
286 }
287
288 pub fn receive_subscribe_ok(&mut self, msg: &SubscribeOk) -> Result<(), EndpointError> {
291 let id = msg.request_id.into_inner();
292 let sm = self.subscriptions.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
293 sm.on_subscribe_ok()?;
294 Ok(())
295 }
296
297 pub fn unsubscribe(&mut self, request_id: VarInt) -> Result<ControlMessage, EndpointError> {
299 let id = request_id.into_inner();
300 let sm = self.subscriptions.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
301 sm.on_unsubscribe()?;
302 Ok(ControlMessage::Unsubscribe(Unsubscribe { request_id }))
303 }
304
305 pub fn receive_request_update(&mut self, msg: &RequestUpdate) -> Result<(), EndpointError> {
309 let id = msg.existing_request_id.into_inner();
310 let sm = self.subscriptions.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
311 sm.on_subscribe_update()?;
312 Ok(())
313 }
314
315 pub fn receive_publish_done(&mut self, msg: &PublishDone) -> Result<(), EndpointError> {
319 let id = msg.request_id.into_inner();
320 let sm = self.subscriptions.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
321 sm.on_publish_done()?;
322 Ok(())
323 }
324
325 pub fn fetch(
330 &mut self,
331 track_namespace: TrackNamespace,
332 track_name: Vec<u8>,
333 start_group: VarInt,
334 start_object: VarInt,
335 end_group: VarInt,
336 end_object: VarInt,
337 ) -> Result<(VarInt, ControlMessage), EndpointError> {
338 self.require_active_or_err()?;
339 let req_id = self.request_ids.allocate()?;
340
341 let mut sm = FetchStateMachine::new();
342 sm.on_fetch_sent()?;
343 self.fetches.insert(req_id.into_inner(), sm);
344
345 let msg = ControlMessage::Fetch(Fetch {
346 request_id: req_id,
347 fetch_type: FetchType::Standalone,
348 fetch_payload: FetchPayload::Standalone {
349 track_namespace,
350 track_name,
351 start_group,
352 start_object,
353 end_group,
354 end_object,
355 },
356 parameters: vec![],
357 });
358 Ok((req_id, msg))
359 }
360
361 pub fn joining_fetch(
363 &mut self,
364 joining_request_id: VarInt,
365 joining_start: VarInt,
366 ) -> Result<(VarInt, ControlMessage), EndpointError> {
367 self.require_active_or_err()?;
368 let req_id = self.request_ids.allocate()?;
369
370 let mut sm = FetchStateMachine::new();
371 sm.on_fetch_sent()?;
372 self.fetches.insert(req_id.into_inner(), sm);
373
374 let msg = ControlMessage::Fetch(Fetch {
375 request_id: req_id,
376 fetch_type: FetchType::RelativeJoining,
377 fetch_payload: FetchPayload::Joining { joining_request_id, joining_start },
378 parameters: vec![],
379 });
380 Ok((req_id, msg))
381 }
382
383 pub fn receive_fetch_ok(&mut self, msg: &message::FetchOk) -> Result<(), EndpointError> {
386 let id = msg.request_id.into_inner();
387 let sm = self.fetches.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
388 sm.on_fetch_ok()?;
389 Ok(())
390 }
391
392 pub fn fetch_cancel(&mut self, request_id: VarInt) -> Result<ControlMessage, EndpointError> {
394 let id = request_id.into_inner();
395 let sm = self.fetches.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
396 sm.on_fetch_cancel()?;
397 Ok(ControlMessage::FetchCancel(FetchCancel { request_id }))
398 }
399
400 pub fn on_fetch_stream_fin(&mut self, request_id: VarInt) -> Result<(), EndpointError> {
402 let id = request_id.into_inner();
403 let sm = self.fetches.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
404 sm.on_stream_fin()?;
405 Ok(())
406 }
407
408 pub fn on_fetch_stream_reset(&mut self, request_id: VarInt) -> Result<(), EndpointError> {
410 let id = request_id.into_inner();
411 let sm = self.fetches.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
412 sm.on_stream_reset()?;
413 Ok(())
414 }
415
416 pub fn subscribe_namespace(
421 &mut self,
422 namespace_prefix: TrackNamespace,
423 subscribe_options: VarInt,
424 parameters: Vec<KeyValuePair>,
425 ) -> Result<(VarInt, ControlMessage), EndpointError> {
426 self.require_active_or_err()?;
427 let req_id = self.request_ids.allocate()?;
428
429 let mut sm = SubscribeNamespaceStateMachine::new();
430 sm.on_subscribe_namespace_sent()?;
431 self.subscribe_namespaces.insert(req_id.into_inner(), sm);
432
433 let msg = ControlMessage::SubscribeNamespace(SubscribeNamespace {
434 request_id: req_id,
435 namespace_prefix,
436 subscribe_options,
437 parameters,
438 });
439 Ok((req_id, msg))
440 }
441
442 pub fn publish_namespace(
446 &mut self,
447 track_namespace: TrackNamespace,
448 parameters: Vec<KeyValuePair>,
449 ) -> Result<(VarInt, ControlMessage), EndpointError> {
450 self.require_active_or_err()?;
451 let req_id = self.request_ids.allocate()?;
452
453 let mut sm = PublishNamespaceStateMachine::new();
454 sm.on_publish_namespace_sent()?;
455 self.publish_namespaces.insert(req_id.into_inner(), sm);
456
457 let msg = ControlMessage::PublishNamespace(PublishNamespace {
458 request_id: req_id,
459 track_namespace,
460 parameters,
461 });
462 Ok((req_id, msg))
463 }
464
465 pub fn receive_publish_namespace_done(
468 &mut self,
469 msg: &PublishNamespaceDone,
470 ) -> Result<(), EndpointError> {
471 let id = msg.request_id.into_inner();
472 let sm = self.publish_namespaces.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
473 sm.on_publish_namespace_done()?;
474 Ok(())
475 }
476
477 pub fn publish_namespace_cancel(
480 &mut self,
481 request_id: VarInt,
482 error_code: VarInt,
483 reason_phrase: Vec<u8>,
484 ) -> Result<ControlMessage, EndpointError> {
485 let id = request_id.into_inner();
486 let sm = self.publish_namespaces.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
487 sm.on_publish_namespace_cancel()?;
488 Ok(ControlMessage::PublishNamespaceCancel(PublishNamespaceCancel {
489 request_id,
490 error_code,
491 reason_phrase,
492 }))
493 }
494
495 pub fn track_status(
500 &mut self,
501 track_namespace: TrackNamespace,
502 track_name: Vec<u8>,
503 parameters: Vec<KeyValuePair>,
504 ) -> Result<(VarInt, ControlMessage), EndpointError> {
505 self.require_active_or_err()?;
506 let req_id = self.request_ids.allocate()?;
507 let mut sm = TrackStatusStateMachine::new();
508 sm.on_track_status_sent()?;
509 self.track_statuses.insert(req_id.into_inner(), sm);
510 let msg = ControlMessage::TrackStatus(message::TrackStatus {
511 request_id: req_id,
512 track_namespace,
513 track_name,
514 parameters,
515 });
516 Ok((req_id, msg))
517 }
518
519 pub fn publish(
524 &mut self,
525 track_namespace: TrackNamespace,
526 track_name: Vec<u8>,
527 track_alias: VarInt,
528 track_extensions: Vec<KeyValuePair>,
529 parameters: Vec<KeyValuePair>,
530 ) -> Result<(VarInt, ControlMessage), EndpointError> {
531 self.require_active_or_err()?;
532 let req_id = self.request_ids.allocate()?;
533 let mut sm = PublishStateMachine::new();
534 sm.on_publish_sent()?;
535 self.publishes.insert(req_id.into_inner(), sm);
536 let msg = ControlMessage::Publish(message::Publish {
537 request_id: req_id,
538 track_namespace,
539 track_name,
540 track_alias,
541 track_extensions,
542 parameters,
543 });
544 Ok((req_id, msg))
545 }
546
547 pub fn receive_publish_ok(&mut self, msg: &message::PublishOk) -> Result<(), EndpointError> {
550 let id = msg.request_id.into_inner();
551 let sm = self.publishes.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
552 sm.on_publish_ok()?;
553 Ok(())
554 }
555
556 pub fn send_publish_done(
559 &mut self,
560 request_id: VarInt,
561 status_code: VarInt,
562 stream_count: VarInt,
563 reason_phrase: Vec<u8>,
564 ) -> Result<ControlMessage, EndpointError> {
565 let id = request_id.into_inner();
566 let sm = self.publishes.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
567 sm.on_publish_done_sent()?;
568 Ok(ControlMessage::PublishDone(PublishDone {
569 request_id,
570 status_code,
571 stream_count,
572 reason_phrase,
573 }))
574 }
575
576 pub fn receive_request_ok(&mut self, msg: &RequestOk) -> Result<(), EndpointError> {
581 let id = msg.request_id.into_inner();
582 if let Some(sm) = self.subscribe_namespaces.get_mut(&id) {
584 sm.on_subscribe_namespace_ok()?;
585 return Ok(());
586 }
587 if let Some(sm) = self.publish_namespaces.get_mut(&id) {
589 sm.on_publish_namespace_ok()?;
590 return Ok(());
591 }
592 if let Some(sm) = self.track_statuses.get_mut(&id) {
594 sm.on_track_status_ok()?;
595 return Ok(());
596 }
597 Err(EndpointError::UnknownRequest(id))
598 }
599
600 pub fn receive_request_error(&mut self, msg: &RequestError) -> Result<(), EndpointError> {
602 let id = msg.request_id.into_inner();
603 if let Some(sm) = self.subscriptions.get_mut(&id) {
605 sm.on_subscribe_error()?;
606 return Ok(());
607 }
608 if let Some(sm) = self.fetches.get_mut(&id) {
610 sm.on_fetch_error()?;
611 return Ok(());
612 }
613 if let Some(sm) = self.publishes.get_mut(&id) {
615 sm.on_publish_error()?;
616 return Ok(());
617 }
618 if let Some(sm) = self.subscribe_namespaces.get_mut(&id) {
620 sm.on_subscribe_namespace_error()?;
621 return Ok(());
622 }
623 if let Some(sm) = self.publish_namespaces.get_mut(&id) {
625 sm.on_publish_namespace_error()?;
626 return Ok(());
627 }
628 if let Some(sm) = self.track_statuses.get_mut(&id) {
630 sm.on_track_status_error()?;
631 return Ok(());
632 }
633 Ok(())
635 }
636
637 pub fn receive_message(&mut self, msg: ControlMessage) -> Result<(), EndpointError> {
641 match msg {
642 ControlMessage::GoAway(ref m) => self.receive_goaway(m),
643 ControlMessage::MaxRequestId(ref m) => self.receive_max_request_id(m),
644 ControlMessage::RequestsBlocked(ref m) => self.receive_requests_blocked(m),
645 ControlMessage::SubscribeOk(ref m) => self.receive_subscribe_ok(m),
646 ControlMessage::RequestUpdate(ref m) => self.receive_request_update(m),
647 ControlMessage::PublishDone(ref m) => self.receive_publish_done(m),
648 ControlMessage::PublishOk(ref m) => self.receive_publish_ok(m),
649 ControlMessage::FetchOk(ref m) => self.receive_fetch_ok(m),
650 ControlMessage::RequestOk(ref m) => self.receive_request_ok(m),
651 ControlMessage::RequestError(ref m) => self.receive_request_error(m),
652 ControlMessage::PublishNamespaceDone(ref m) => self.receive_publish_namespace_done(m),
653 _ => Ok(()),
654 }
655 }
656}