1#![allow(missing_docs)]
2use std::collections::HashMap;
23
24use crate::draft17::fetch::{FetchError, FetchStateMachine};
25use crate::draft17::namespace::{
26 NamespaceError, PublishNamespaceStateMachine, SubscribeNamespaceStateMachine,
27};
28use crate::draft17::publish::{PublishError as PublishFlowError, PublishStateMachine};
29use crate::draft17::session::request_id::{RequestIdAllocator, RequestIdError, Role};
30use crate::draft17::session::setup::{self, SetupError};
31use crate::draft17::session::state::{SessionError, SessionState, SessionStateMachine};
32use crate::draft17::subscription::{SubscriptionError, SubscriptionStateMachine};
33use crate::draft17::track_status::{TrackStatusError, TrackStatusStateMachine};
34use moqtap_codec::draft17::message::{
35 self, ControlMessage, Fetch, FetchPayload, FetchType, GoAway, Publish, PublishBlocked,
36 PublishDone, PublishNamespace, RequestError, RequestOk, RequestUpdate, Setup, Subscribe,
37 SubscribeNamespace, SubscribeOk,
38};
39use moqtap_codec::kvp::KeyValuePair;
40use moqtap_codec::types::*;
41use moqtap_codec::varint::VarInt;
42
43#[derive(Debug, thiserror::Error)]
45pub enum EndpointError {
46 #[error("session error: {0}")]
47 Session(#[from] SessionError),
48 #[error("request ID error: {0}")]
49 RequestId(#[from] RequestIdError),
50 #[error("subscription error: {0}")]
51 Subscription(#[from] SubscriptionError),
52 #[error("fetch error: {0}")]
53 Fetch(#[from] FetchError),
54 #[error("namespace error: {0}")]
55 Namespace(#[from] NamespaceError),
56 #[error("track status error: {0}")]
57 TrackStatus(#[from] TrackStatusError),
58 #[error("publish flow error: {0}")]
59 PublishFlow(#[from] PublishFlowError),
60 #[error("setup error: {0}")]
61 Setup(#[from] SetupError),
62 #[error("unknown request ID: {0}")]
63 UnknownRequest(u64),
64 #[error(
65 "response message received on control stream; d17 responses belong on bidi request streams"
66 )]
67 ResponseOnControlStream,
68 #[error("session not active")]
69 NotActive,
70 #[error("session is draining, no new requests allowed")]
71 Draining,
72}
73
74pub struct Endpoint {
75 role: Role,
76 session: SessionStateMachine,
77 request_ids: RequestIdAllocator,
78 subscriptions: HashMap<u64, SubscriptionStateMachine>,
79 fetches: HashMap<u64, FetchStateMachine>,
80 subscribe_namespaces: HashMap<u64, SubscribeNamespaceStateMachine>,
81 publish_namespaces: HashMap<u64, PublishNamespaceStateMachine>,
82 track_statuses: HashMap<u64, TrackStatusStateMachine>,
83 publishes: HashMap<u64, PublishStateMachine>,
84 goaway_uri: Option<Vec<u8>>,
85}
86
87impl Endpoint {
88 pub fn new(role: Role) -> Self {
89 Self {
90 role,
91 session: SessionStateMachine::new(),
92 request_ids: RequestIdAllocator::new(role),
93 subscriptions: HashMap::new(),
94 fetches: HashMap::new(),
95 subscribe_namespaces: HashMap::new(),
96 publish_namespaces: HashMap::new(),
97 track_statuses: HashMap::new(),
98 publishes: HashMap::new(),
99 goaway_uri: None,
100 }
101 }
102
103 pub fn role(&self) -> Role {
104 self.role
105 }
106
107 pub fn session_state(&self) -> SessionState {
108 self.session.state()
109 }
110
111 pub fn goaway_uri(&self) -> Option<&[u8]> {
112 self.goaway_uri.as_deref()
113 }
114
115 pub fn active_subscription_count(&self) -> usize {
116 self.subscriptions.len()
117 }
118
119 pub fn active_fetch_count(&self) -> usize {
120 self.fetches.len()
121 }
122
123 pub fn active_subscribe_namespace_count(&self) -> usize {
124 self.subscribe_namespaces.len()
125 }
126
127 pub fn active_publish_namespace_count(&self) -> usize {
128 self.publish_namespaces.len()
129 }
130
131 pub fn active_track_status_count(&self) -> usize {
132 self.track_statuses.len()
133 }
134
135 pub fn active_publish_count(&self) -> usize {
136 self.publishes.len()
137 }
138
139 pub fn connect(&mut self) -> Result<(), EndpointError> {
142 self.session.on_connect()?;
143 Ok(())
144 }
145
146 pub fn close(&mut self) -> Result<(), EndpointError> {
147 self.session.on_close()?;
148 Ok(())
149 }
150
151 pub fn send_setup(
156 &mut self,
157 options: Vec<KeyValuePair>,
158 ) -> Result<ControlMessage, EndpointError> {
159 let msg = Setup { options };
160 setup::validate_setup(&msg)?;
161 Ok(ControlMessage::Setup(msg))
162 }
163
164 pub fn receive_setup(&mut self, msg: &Setup) -> Result<(), EndpointError> {
166 setup::validate_setup(msg)?;
167 self.session.on_setup_complete()?;
168 Ok(())
169 }
170
171 pub fn receive_goaway(&mut self, msg: &GoAway) -> Result<(), EndpointError> {
174 self.session.on_goaway()?;
175 self.goaway_uri = Some(msg.new_session_uri.clone());
176 Ok(())
177 }
178
179 fn delta() -> VarInt {
185 VarInt::from_u64(0).unwrap()
186 }
187
188 fn require_active_or_err(&self) -> Result<(), EndpointError> {
189 match self.session.state() {
190 SessionState::Active => Ok(()),
191 SessionState::Draining => Err(EndpointError::Draining),
192 _ => Err(EndpointError::NotActive),
193 }
194 }
195
196 pub fn subscribe(
199 &mut self,
200 track_namespace: TrackNamespace,
201 track_name: Vec<u8>,
202 parameters: Vec<KeyValuePair>,
203 ) -> Result<(VarInt, ControlMessage), EndpointError> {
204 self.require_active_or_err()?;
205 let req_id = self.request_ids.allocate()?;
206
207 let mut sm = SubscriptionStateMachine::new();
208 sm.on_subscribe_sent()?;
209 self.subscriptions.insert(req_id.into_inner(), sm);
210
211 let msg = ControlMessage::Subscribe(Subscribe {
212 request_id: req_id,
213 required_request_id_delta: Self::delta(),
214 track_namespace,
215 track_name,
216 parameters,
217 });
218 Ok((req_id, msg))
219 }
220
221 pub fn receive_subscribe_ok(
225 &mut self,
226 request_id: VarInt,
227 _msg: &SubscribeOk,
228 ) -> Result<(), EndpointError> {
229 let id = request_id.into_inner();
230 let sm = self.subscriptions.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
231 sm.on_subscribe_ok()?;
232 Ok(())
233 }
234
235 pub fn receive_request_update(&mut self, msg: &RequestUpdate) -> Result<(), EndpointError> {
238 let id = msg.request_id.into_inner();
239 let sm = self.subscriptions.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
240 sm.on_subscribe_update()?;
241 Ok(())
242 }
243
244 pub fn receive_publish_done(
248 &mut self,
249 request_id: VarInt,
250 _msg: &PublishDone,
251 ) -> Result<(), EndpointError> {
252 let id = request_id.into_inner();
253 let sm = self.subscriptions.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
254 sm.on_publish_done()?;
255 Ok(())
256 }
257
258 pub fn fetch(
261 &mut self,
262 track_namespace: TrackNamespace,
263 track_name: Vec<u8>,
264 start_group: VarInt,
265 start_object: VarInt,
266 end_group: VarInt,
267 end_object: VarInt,
268 ) -> Result<(VarInt, ControlMessage), EndpointError> {
269 self.require_active_or_err()?;
270 let req_id = self.request_ids.allocate()?;
271
272 let mut sm = FetchStateMachine::new();
273 sm.on_fetch_sent()?;
274 self.fetches.insert(req_id.into_inner(), sm);
275
276 let msg = ControlMessage::Fetch(Fetch {
277 request_id: req_id,
278 required_request_id_delta: Self::delta(),
279 fetch_type: FetchType::Standalone,
280 fetch_payload: FetchPayload::Standalone {
281 track_namespace,
282 track_name,
283 start_group,
284 start_object,
285 end_group,
286 end_object,
287 },
288 parameters: vec![],
289 });
290 Ok((req_id, msg))
291 }
292
293 pub fn joining_fetch(
294 &mut self,
295 joining_request_id: VarInt,
296 joining_start: VarInt,
297 ) -> Result<(VarInt, ControlMessage), EndpointError> {
298 self.require_active_or_err()?;
299 let req_id = self.request_ids.allocate()?;
300
301 let mut sm = FetchStateMachine::new();
302 sm.on_fetch_sent()?;
303 self.fetches.insert(req_id.into_inner(), sm);
304
305 let msg = ControlMessage::Fetch(Fetch {
306 request_id: req_id,
307 required_request_id_delta: Self::delta(),
308 fetch_type: FetchType::RelativeJoining,
309 fetch_payload: FetchPayload::Joining { joining_request_id, joining_start },
310 parameters: vec![],
311 });
312 Ok((req_id, msg))
313 }
314
315 pub fn receive_fetch_ok(
316 &mut self,
317 request_id: VarInt,
318 _msg: &message::FetchOk,
319 ) -> Result<(), EndpointError> {
320 let id = request_id.into_inner();
321 let sm = self.fetches.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
322 sm.on_fetch_ok()?;
323 Ok(())
324 }
325
326 pub fn on_fetch_stream_fin(&mut self, request_id: VarInt) -> Result<(), EndpointError> {
327 let id = request_id.into_inner();
328 let sm = self.fetches.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
329 sm.on_stream_fin()?;
330 Ok(())
331 }
332
333 pub fn on_fetch_stream_reset(&mut self, request_id: VarInt) -> Result<(), EndpointError> {
334 let id = request_id.into_inner();
335 let sm = self.fetches.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
336 sm.on_stream_reset()?;
337 Ok(())
338 }
339
340 pub fn subscribe_namespace(
343 &mut self,
344 namespace_prefix: TrackNamespace,
345 subscribe_options: VarInt,
346 parameters: Vec<KeyValuePair>,
347 ) -> Result<(VarInt, ControlMessage), EndpointError> {
348 self.require_active_or_err()?;
349 let req_id = self.request_ids.allocate()?;
350
351 let mut sm = SubscribeNamespaceStateMachine::new();
352 sm.on_subscribe_namespace_sent()?;
353 self.subscribe_namespaces.insert(req_id.into_inner(), sm);
354
355 let msg = ControlMessage::SubscribeNamespace(SubscribeNamespace {
356 request_id: req_id,
357 required_request_id_delta: Self::delta(),
358 namespace_prefix,
359 subscribe_options,
360 parameters,
361 });
362 Ok((req_id, msg))
363 }
364
365 pub fn publish_namespace(
368 &mut self,
369 track_namespace: TrackNamespace,
370 parameters: Vec<KeyValuePair>,
371 ) -> Result<(VarInt, ControlMessage), EndpointError> {
372 self.require_active_or_err()?;
373 let req_id = self.request_ids.allocate()?;
374
375 let mut sm = PublishNamespaceStateMachine::new();
376 sm.on_publish_namespace_sent()?;
377 self.publish_namespaces.insert(req_id.into_inner(), sm);
378
379 let msg = ControlMessage::PublishNamespace(PublishNamespace {
380 request_id: req_id,
381 required_request_id_delta: Self::delta(),
382 track_namespace,
383 parameters,
384 });
385 Ok((req_id, msg))
386 }
387
388 pub fn track_status(
391 &mut self,
392 track_namespace: TrackNamespace,
393 track_name: Vec<u8>,
394 parameters: Vec<KeyValuePair>,
395 ) -> Result<(VarInt, ControlMessage), EndpointError> {
396 self.require_active_or_err()?;
397 let req_id = self.request_ids.allocate()?;
398 let mut sm = TrackStatusStateMachine::new();
399 sm.on_track_status_sent()?;
400 self.track_statuses.insert(req_id.into_inner(), sm);
401
402 let msg = ControlMessage::TrackStatus(message::TrackStatus {
403 request_id: req_id,
404 required_request_id_delta: Self::delta(),
405 track_namespace,
406 track_name,
407 parameters,
408 });
409 Ok((req_id, msg))
410 }
411
412 pub fn publish(
415 &mut self,
416 track_namespace: TrackNamespace,
417 track_name: Vec<u8>,
418 track_alias: VarInt,
419 parameters: Vec<KeyValuePair>,
420 track_properties: Vec<KeyValuePair>,
421 ) -> Result<(VarInt, ControlMessage), EndpointError> {
422 self.require_active_or_err()?;
423 let req_id = self.request_ids.allocate()?;
424 let mut sm = PublishStateMachine::new();
425 sm.on_publish_sent()?;
426 self.publishes.insert(req_id.into_inner(), sm);
427
428 let msg = ControlMessage::Publish(Publish {
429 request_id: req_id,
430 required_request_id_delta: Self::delta(),
431 track_namespace,
432 track_name,
433 track_alias,
434 parameters,
435 track_properties,
436 });
437 Ok((req_id, msg))
438 }
439
440 pub fn receive_publish_ok(
441 &mut self,
442 request_id: VarInt,
443 _msg: &message::PublishOk,
444 ) -> Result<(), EndpointError> {
445 let id = request_id.into_inner();
446 let sm = self.publishes.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
447 sm.on_publish_ok()?;
448 Ok(())
449 }
450
451 pub fn send_publish_done(
452 &mut self,
453 request_id: VarInt,
454 status_code: VarInt,
455 stream_count: VarInt,
456 reason_phrase: Vec<u8>,
457 ) -> Result<ControlMessage, EndpointError> {
458 let id = request_id.into_inner();
459 let sm = self.publishes.get_mut(&id).ok_or(EndpointError::UnknownRequest(id))?;
460 sm.on_publish_done_sent()?;
461 Ok(ControlMessage::PublishDone(PublishDone { status_code, stream_count, reason_phrase }))
462 }
463
464 pub fn receive_request_ok(
470 &mut self,
471 request_id: VarInt,
472 _msg: &RequestOk,
473 ) -> Result<(), EndpointError> {
474 let id = request_id.into_inner();
475 if let Some(sm) = self.subscribe_namespaces.get_mut(&id) {
476 sm.on_subscribe_namespace_ok()?;
477 return Ok(());
478 }
479 if let Some(sm) = self.publish_namespaces.get_mut(&id) {
480 sm.on_publish_namespace_ok()?;
481 return Ok(());
482 }
483 if let Some(sm) = self.track_statuses.get_mut(&id) {
484 sm.on_track_status_ok()?;
485 return Ok(());
486 }
487 Err(EndpointError::UnknownRequest(id))
488 }
489
490 pub fn receive_request_error(
493 &mut self,
494 request_id: VarInt,
495 _msg: &RequestError,
496 ) -> Result<(), EndpointError> {
497 let id = request_id.into_inner();
498 if let Some(sm) = self.subscriptions.get_mut(&id) {
499 sm.on_subscribe_error()?;
500 return Ok(());
501 }
502 if let Some(sm) = self.fetches.get_mut(&id) {
503 sm.on_fetch_error()?;
504 return Ok(());
505 }
506 if let Some(sm) = self.publishes.get_mut(&id) {
507 sm.on_publish_error()?;
508 return Ok(());
509 }
510 if let Some(sm) = self.subscribe_namespaces.get_mut(&id) {
511 sm.on_subscribe_namespace_error()?;
512 return Ok(());
513 }
514 if let Some(sm) = self.publish_namespaces.get_mut(&id) {
515 sm.on_publish_namespace_error()?;
516 return Ok(());
517 }
518 if let Some(sm) = self.track_statuses.get_mut(&id) {
519 sm.on_track_status_error()?;
520 return Ok(());
521 }
522 Err(EndpointError::UnknownRequest(id))
523 }
524
525 pub fn receive_namespace(&mut self, _msg: &message::Namespace) -> Result<(), EndpointError> {
530 Ok(())
531 }
532
533 pub fn receive_namespace_done(
535 &mut self,
536 _msg: &message::NamespaceDone,
537 ) -> Result<(), EndpointError> {
538 Ok(())
539 }
540
541 pub fn receive_publish_blocked(&mut self, _msg: &PublishBlocked) -> Result<(), EndpointError> {
543 Ok(())
544 }
545
546 pub fn receive_message(&mut self, msg: ControlMessage) -> Result<(), EndpointError> {
552 match msg {
553 ControlMessage::Setup(ref m) => self.receive_setup(m),
554 ControlMessage::GoAway(ref m) => self.receive_goaway(m),
555 ControlMessage::RequestUpdate(ref m) => self.receive_request_update(m),
556 ControlMessage::Namespace(ref m) => self.receive_namespace(m),
557 ControlMessage::NamespaceDone(ref m) => self.receive_namespace_done(m),
558 ControlMessage::PublishBlocked(ref m) => self.receive_publish_blocked(m),
559 ControlMessage::SubscribeOk(_)
560 | ControlMessage::PublishDone(_)
561 | ControlMessage::PublishOk(_)
562 | ControlMessage::FetchOk(_)
563 | ControlMessage::RequestOk(_)
564 | ControlMessage::RequestError(_) => Err(EndpointError::ResponseOnControlStream),
565 _ => Ok(()),
566 }
567 }
568
569 pub fn receive_response_on_stream(
573 &mut self,
574 request_id: VarInt,
575 msg: ControlMessage,
576 ) -> Result<(), EndpointError> {
577 match msg {
578 ControlMessage::SubscribeOk(ref m) => self.receive_subscribe_ok(request_id, m),
579 ControlMessage::PublishDone(ref m) => self.receive_publish_done(request_id, m),
580 ControlMessage::PublishOk(ref m) => self.receive_publish_ok(request_id, m),
581 ControlMessage::FetchOk(ref m) => self.receive_fetch_ok(request_id, m),
582 ControlMessage::RequestOk(ref m) => self.receive_request_ok(request_id, m),
583 ControlMessage::RequestError(ref m) => self.receive_request_error(request_id, m),
584 _ => Err(EndpointError::ResponseOnControlStream),
585 }
586 }
587}