1use std::collections::HashMap;
2
3use crate::draft07::fetch::{FetchError, FetchStateMachine};
4use crate::draft07::namespace::{
5 AnnounceStateMachine, NamespaceError, SubscribeAnnouncesStateMachine,
6};
7use crate::draft07::session::setup::{self, SetupError};
8use crate::draft07::session::state::{SessionError, SessionState, SessionStateMachine};
9use crate::draft07::session::subscribe_id::{SubscribeIdAllocator, SubscribeIdError};
10use crate::draft07::subscription::{SubscriptionError, SubscriptionStateMachine};
11use crate::draft07::track_status::{TrackStatusError, TrackStatusStateMachine};
12use moqtap_codec::draft07::message::{
13 self, Announce, AnnounceCancel, AnnounceError, AnnounceOk, ClientSetup, ControlMessage, Fetch,
14 FetchCancel, GoAway, MaxSubscribeId, ServerSetup, Subscribe, SubscribeAnnounces,
15 SubscribeAnnouncesError, SubscribeAnnouncesOk, SubscribeDone, SubscribeError, SubscribeOk,
16 SubscribeUpdate, TrackStatus, TrackStatusRequest, Unannounce, Unsubscribe,
17 UnsubscribeAnnounces,
18};
19use moqtap_codec::kvp::{KeyValuePair, KvpValue};
20use moqtap_codec::types::*;
21use moqtap_codec::varint::VarInt;
22
/// Map key for namespace-indexed state: the inner value (`.0`) of a `TrackNamespace`.
type NamespaceKey = Vec<Vec<u8>>;
25
/// Map key for track-status requests: (inner value of a `TrackNamespace`, track name bytes).
type TrackKey = (Vec<Vec<u8>>, Vec<u8>);
28
/// Errors returned by [`Endpoint`] operations.
///
/// The `#[from]` variants wrap errors raised by the underlying state machines and
/// helpers; the remaining variants are produced directly by `Endpoint` methods.
#[derive(Debug, thiserror::Error)]
pub enum EndpointError {
    /// The session state machine rejected a transition.
    #[error("session error: {0}")]
    Session(#[from] SessionError),
    /// The subscribe-ID allocator rejected an operation, or a locally advertised
    /// maximum subscribe ID did not increase.
    #[error("subscribe ID error: {0}")]
    SubscribeId(#[from] SubscribeIdError),
    /// A subscription state machine rejected a transition.
    #[error("subscription error: {0}")]
    Subscription(#[from] SubscriptionError),
    /// A fetch state machine rejected a transition.
    #[error("fetch error: {0}")]
    Fetch(#[from] FetchError),
    /// An announce or subscribe-announces state machine rejected a transition.
    #[error("namespace error: {0}")]
    Namespace(#[from] NamespaceError),
    /// A track-status state machine rejected a transition.
    #[error("track status error: {0}")]
    TrackStatus(#[from] TrackStatusError),
    /// Setup message validation or version negotiation failed.
    #[error("setup error: {0}")]
    Setup(#[from] SetupError),
    /// No subscription or fetch is tracked under the given subscribe ID.
    #[error("unknown subscribe ID: {0}")]
    UnknownSubscribe(u64),
    /// No announce or subscribe-announces entry is tracked for the given namespace.
    #[error("unknown namespace")]
    UnknownNamespace,
    /// No track-status request is tracked for the given (namespace, track name) pair.
    #[error("unknown track status request")]
    UnknownTrackStatus,
    /// A new request was attempted while the session was neither active nor draining.
    #[error("session not active")]
    NotActive,
    /// A new request was attempted while the session was draining.
    #[error("session is draining, no new requests allowed")]
    Draining,
}
69
/// Tracks the state of one session: the session state machine, subscribe-ID
/// accounting, and one state machine per outstanding request.
///
/// Methods build or consume control messages and update this state; none of the
/// methods in this file perform I/O themselves.
pub struct Endpoint {
    /// Session-level state machine (connect, setup, goaway, close).
    session: SessionStateMachine,
    /// Allocator for subscribe IDs used by outgoing SUBSCRIBE and FETCH requests;
    /// its maximum is raised by values received from the peer.
    subscribe_ids: SubscribeIdAllocator,
    /// Last maximum subscribe ID passed to `send_max_subscribe_id`.
    /// `0` doubles as "nothing advertised yet".
    advertised_max_id: u64,
    /// Subscription state machines keyed by subscribe ID.
    subscriptions: HashMap<u64, SubscriptionStateMachine>,
    /// Fetch state machines keyed by subscribe ID.
    fetches: HashMap<u64, FetchStateMachine>,
    /// Subscribe-announces state machines keyed by namespace prefix.
    subscribe_announces: HashMap<NamespaceKey, SubscribeAnnouncesStateMachine>,
    /// Announce state machines keyed by namespace.
    announces: HashMap<NamespaceKey, AnnounceStateMachine>,
    /// Track-status request state machines keyed by (namespace, track name).
    track_statuses: HashMap<TrackKey, TrackStatusStateMachine>,
    /// Version recorded by a completed setup exchange; `None` until then.
    negotiated_version: Option<VarInt>,
    /// Versions most recently passed to `send_client_setup`.
    offered_versions: Vec<VarInt>,
    /// New-session URI bytes from the last accepted GOAWAY, if any.
    goaway_uri: Option<Vec<u8>>,
}
87
88impl Default for Endpoint {
89 fn default() -> Self {
90 Self::new()
91 }
92}
93
94impl Endpoint {
95 pub fn new() -> Self {
97 Self {
98 session: SessionStateMachine::new(),
99 subscribe_ids: SubscribeIdAllocator::new(),
100 advertised_max_id: 0,
101 subscriptions: HashMap::new(),
102 fetches: HashMap::new(),
103 subscribe_announces: HashMap::new(),
104 announces: HashMap::new(),
105 track_statuses: HashMap::new(),
106 negotiated_version: None,
107 offered_versions: Vec::new(),
108 goaway_uri: None,
109 }
110 }
111
    /// Returns the current state reported by the session state machine.
    pub fn session_state(&self) -> SessionState {
        self.session.state()
    }
118
    /// Returns the version recorded by a setup exchange, or `None` if none has been
    /// recorded yet.
    pub fn negotiated_version(&self) -> Option<VarInt> {
        self.negotiated_version
    }
123
    /// Returns the new-session URI bytes stored by the last accepted GOAWAY, if any.
    pub fn goaway_uri(&self) -> Option<&[u8]> {
        self.goaway_uri.as_deref()
    }
128
    /// Returns whether the subscribe-ID allocator reports itself as blocked.
    ///
    /// NOTE(review): presumably this means no further IDs can be allocated until the
    /// peer raises the maximum — confirm against `SubscribeIdAllocator::is_blocked`.
    pub fn is_blocked(&self) -> bool {
        self.subscribe_ids.is_blocked()
    }
133
    /// Returns the number of subscription state machines currently tracked.
    ///
    /// NOTE(review): the methods in this file never remove entries, so this may
    /// include subscriptions that have already finished — confirm.
    pub fn active_subscription_count(&self) -> usize {
        self.subscriptions.len()
    }
138
    /// Returns the number of fetch state machines currently tracked.
    ///
    /// NOTE(review): the methods in this file never remove entries, so this may
    /// include fetches that have already finished — confirm.
    pub fn active_fetch_count(&self) -> usize {
        self.fetches.len()
    }
143
    /// Returns the number of subscribe-announces state machines currently tracked,
    /// one per namespace prefix.
    pub fn active_subscribe_announces_count(&self) -> usize {
        self.subscribe_announces.len()
    }
148
    /// Returns the number of announce state machines currently tracked, one per
    /// namespace.
    pub fn active_announce_count(&self) -> usize {
        self.announces.len()
    }
153
    /// Returns the number of track-status request state machines currently tracked,
    /// one per (namespace, track name) pair.
    pub fn active_track_status_count(&self) -> usize {
        self.track_statuses.len()
    }
158
159 pub fn connect(&mut self) -> Result<(), EndpointError> {
163 self.session.on_connect()?;
164 Ok(())
165 }
166
167 pub fn close(&mut self) -> Result<(), EndpointError> {
169 self.session.on_close()?;
170 Ok(())
171 }
172
173 pub fn send_client_setup(
177 &mut self,
178 versions: Vec<VarInt>,
179 parameters: Vec<KeyValuePair>,
180 ) -> Result<ControlMessage, EndpointError> {
181 self.offered_versions = versions.clone();
182 let msg = ClientSetup { supported_versions: versions, parameters };
183 setup::validate_client_setup(&msg)?;
184 Ok(ControlMessage::ClientSetup(msg))
185 }
186
187 pub fn receive_server_setup(&mut self, msg: &ServerSetup) -> Result<(), EndpointError> {
191 setup::validate_server_setup(msg)?;
192 let version = setup::negotiate_version(&self.offered_versions, msg.selected_version)?;
193 self.negotiated_version = Some(version);
194 self.session.on_setup_complete()?;
195 for param in &msg.parameters {
197 if param.key == VarInt::from_u64(0x02).unwrap() {
198 if let KvpValue::Varint(v) = ¶m.value {
199 self.subscribe_ids.update_max(v.into_inner())?;
200 }
201 }
202 }
203 Ok(())
204 }
205
206 pub fn receive_client_setup_and_respond(
210 &mut self,
211 client_setup: &ClientSetup,
212 selected_version: VarInt,
213 ) -> Result<ControlMessage, EndpointError> {
214 setup::validate_client_setup(client_setup)?;
215 let version = setup::negotiate_version(&client_setup.supported_versions, selected_version)?;
216 self.negotiated_version = Some(version);
217 self.session.on_setup_complete()?;
218 let msg = ServerSetup { selected_version: version, parameters: vec![] };
219 Ok(ControlMessage::ServerSetup(msg))
220 }
221
222 pub fn receive_max_subscribe_id(&mut self, msg: &MaxSubscribeId) -> Result<(), EndpointError> {
226 self.subscribe_ids.update_max(msg.subscribe_id.into_inner())?;
227 Ok(())
228 }
229
230 pub fn send_max_subscribe_id(
233 &mut self,
234 max_id: VarInt,
235 ) -> Result<ControlMessage, EndpointError> {
236 let new_val = max_id.into_inner();
237 if new_val <= self.advertised_max_id && self.advertised_max_id > 0 {
238 return Err(EndpointError::SubscribeId(SubscribeIdError::Decreased(
239 self.advertised_max_id,
240 new_val,
241 )));
242 }
243 self.advertised_max_id = new_val;
244 Ok(ControlMessage::MaxSubscribeId(MaxSubscribeId { subscribe_id: max_id }))
245 }
246
247 pub fn receive_goaway(&mut self, msg: &GoAway) -> Result<(), EndpointError> {
251 self.session.on_goaway()?;
252 self.goaway_uri = Some(msg.new_session_uri.clone());
253 Ok(())
254 }
255
256 fn require_active_or_err(&self) -> Result<(), EndpointError> {
259 match self.session.state() {
260 SessionState::Active => Ok(()),
261 SessionState::Draining => Err(EndpointError::Draining),
262 _ => Err(EndpointError::NotActive),
263 }
264 }
265
266 #[allow(clippy::too_many_arguments)]
269 pub fn subscribe(
270 &mut self,
271 track_alias: VarInt,
272 track_namespace: TrackNamespace,
273 track_name: Vec<u8>,
274 subscriber_priority: u8,
275 group_order: GroupOrder,
276 filter_type: FilterType,
277 ) -> Result<(VarInt, ControlMessage), EndpointError> {
278 self.require_active_or_err()?;
279 let sub_id = self.subscribe_ids.allocate()?;
280
281 let mut sm = SubscriptionStateMachine::new();
282 sm.on_subscribe_sent()?;
283 self.subscriptions.insert(sub_id.into_inner(), sm);
284
285 let msg = ControlMessage::Subscribe(Subscribe {
286 subscribe_id: sub_id,
287 track_alias,
288 track_namespace,
289 track_name,
290 subscriber_priority,
291 group_order,
292 filter_type,
293 start_location: None,
294 end_group: None,
295 end_object: None,
296 parameters: vec![],
297 });
298 Ok((sub_id, msg))
299 }
300
301 pub fn receive_subscribe_ok(&mut self, msg: &SubscribeOk) -> Result<(), EndpointError> {
303 let id = msg.subscribe_id.into_inner();
304 let sm = self.subscriptions.get_mut(&id).ok_or(EndpointError::UnknownSubscribe(id))?;
305 sm.on_subscribe_ok()?;
306 Ok(())
307 }
308
309 pub fn receive_subscribe_error(&mut self, msg: &SubscribeError) -> Result<(), EndpointError> {
311 let id = msg.subscribe_id.into_inner();
312 let sm = self.subscriptions.get_mut(&id).ok_or(EndpointError::UnknownSubscribe(id))?;
313 sm.on_subscribe_error()?;
314 Ok(())
315 }
316
317 pub fn unsubscribe(&mut self, subscribe_id: VarInt) -> Result<ControlMessage, EndpointError> {
319 let id = subscribe_id.into_inner();
320 let sm = self.subscriptions.get_mut(&id).ok_or(EndpointError::UnknownSubscribe(id))?;
321 sm.on_unsubscribe()?;
322 Ok(ControlMessage::Unsubscribe(Unsubscribe { subscribe_id }))
323 }
324
325 pub fn receive_subscribe_update(&mut self, msg: &SubscribeUpdate) -> Result<(), EndpointError> {
327 let id = msg.subscribe_id.into_inner();
328 let sm = self.subscriptions.get_mut(&id).ok_or(EndpointError::UnknownSubscribe(id))?;
329 sm.on_subscribe_update()?;
330 Ok(())
331 }
332
333 pub fn receive_subscribe_done(&mut self, msg: &SubscribeDone) -> Result<(), EndpointError> {
335 let id = msg.subscribe_id.into_inner();
336 let sm = self.subscriptions.get_mut(&id).ok_or(EndpointError::UnknownSubscribe(id))?;
337 sm.on_subscribe_done()?;
338 Ok(())
339 }
340
341 #[allow(clippy::too_many_arguments)]
345 pub fn fetch(
346 &mut self,
347 track_namespace: TrackNamespace,
348 track_name: Vec<u8>,
349 subscriber_priority: u8,
350 group_order: GroupOrder,
351 start_group: VarInt,
352 start_object: VarInt,
353 end_group: VarInt,
354 end_object: VarInt,
355 ) -> Result<(VarInt, ControlMessage), EndpointError> {
356 self.require_active_or_err()?;
357 let sub_id = self.subscribe_ids.allocate()?;
358
359 let mut sm = FetchStateMachine::new();
360 sm.on_fetch_sent()?;
361 self.fetches.insert(sub_id.into_inner(), sm);
362
363 let msg = ControlMessage::Fetch(Fetch {
364 subscribe_id: sub_id,
365 track_namespace,
366 track_name,
367 subscriber_priority,
368 group_order,
369 start_group,
370 start_object,
371 end_group,
372 end_object,
373 parameters: vec![],
374 });
375 Ok((sub_id, msg))
376 }
377
378 pub fn receive_fetch_ok(&mut self, msg: &message::FetchOk) -> Result<(), EndpointError> {
380 let id = msg.subscribe_id.into_inner();
381 let sm = self.fetches.get_mut(&id).ok_or(EndpointError::UnknownSubscribe(id))?;
382 sm.on_fetch_ok()?;
383 Ok(())
384 }
385
386 pub fn receive_fetch_error(&mut self, msg: &message::FetchError) -> Result<(), EndpointError> {
388 let id = msg.subscribe_id.into_inner();
389 let sm = self.fetches.get_mut(&id).ok_or(EndpointError::UnknownSubscribe(id))?;
390 sm.on_fetch_error()?;
391 Ok(())
392 }
393
394 pub fn fetch_cancel(&mut self, subscribe_id: VarInt) -> Result<ControlMessage, EndpointError> {
396 let id = subscribe_id.into_inner();
397 let sm = self.fetches.get_mut(&id).ok_or(EndpointError::UnknownSubscribe(id))?;
398 sm.on_fetch_cancel()?;
399 Ok(ControlMessage::FetchCancel(FetchCancel { subscribe_id }))
400 }
401
402 pub fn on_fetch_stream_fin(&mut self, subscribe_id: VarInt) -> Result<(), EndpointError> {
404 let id = subscribe_id.into_inner();
405 let sm = self.fetches.get_mut(&id).ok_or(EndpointError::UnknownSubscribe(id))?;
406 sm.on_stream_fin()?;
407 Ok(())
408 }
409
410 pub fn on_fetch_stream_reset(&mut self, subscribe_id: VarInt) -> Result<(), EndpointError> {
412 let id = subscribe_id.into_inner();
413 let sm = self.fetches.get_mut(&id).ok_or(EndpointError::UnknownSubscribe(id))?;
414 sm.on_stream_reset()?;
415 Ok(())
416 }
417
418 pub fn subscribe_announces(
422 &mut self,
423 track_namespace_prefix: TrackNamespace,
424 ) -> Result<ControlMessage, EndpointError> {
425 self.require_active_or_err()?;
426 let key = track_namespace_prefix.0.clone();
427 let mut sm = SubscribeAnnouncesStateMachine::new();
428 sm.on_subscribe_announces_sent()?;
429 self.subscribe_announces.insert(key, sm);
430 Ok(ControlMessage::SubscribeAnnounces(SubscribeAnnounces {
431 track_namespace_prefix,
432 parameters: vec![],
433 }))
434 }
435
436 pub fn receive_subscribe_announces_ok(
438 &mut self,
439 msg: &SubscribeAnnouncesOk,
440 ) -> Result<(), EndpointError> {
441 let sm = self
442 .subscribe_announces
443 .get_mut(&msg.track_namespace_prefix.0)
444 .ok_or(EndpointError::UnknownNamespace)?;
445 sm.on_subscribe_announces_ok()?;
446 Ok(())
447 }
448
449 pub fn receive_subscribe_announces_error(
451 &mut self,
452 msg: &SubscribeAnnouncesError,
453 ) -> Result<(), EndpointError> {
454 let sm = self
455 .subscribe_announces
456 .get_mut(&msg.track_namespace_prefix.0)
457 .ok_or(EndpointError::UnknownNamespace)?;
458 sm.on_subscribe_announces_error()?;
459 Ok(())
460 }
461
462 pub fn unsubscribe_announces(
464 &mut self,
465 track_namespace_prefix: TrackNamespace,
466 ) -> Result<ControlMessage, EndpointError> {
467 let sm = self
468 .subscribe_announces
469 .get_mut(&track_namespace_prefix.0)
470 .ok_or(EndpointError::UnknownNamespace)?;
471 sm.on_unsubscribe_announces()?;
472 Ok(ControlMessage::UnsubscribeAnnounces(UnsubscribeAnnounces { track_namespace_prefix }))
473 }
474
475 pub fn announce(
479 &mut self,
480 track_namespace: TrackNamespace,
481 ) -> Result<ControlMessage, EndpointError> {
482 self.require_active_or_err()?;
483 let key = track_namespace.0.clone();
484 let mut sm = AnnounceStateMachine::new();
485 sm.on_announce_sent()?;
486 self.announces.insert(key, sm);
487 Ok(ControlMessage::Announce(Announce { track_namespace, parameters: vec![] }))
488 }
489
490 pub fn receive_announce_ok(&mut self, msg: &AnnounceOk) -> Result<(), EndpointError> {
492 let sm = self
493 .announces
494 .get_mut(&msg.track_namespace.0)
495 .ok_or(EndpointError::UnknownNamespace)?;
496 sm.on_announce_ok()?;
497 Ok(())
498 }
499
500 pub fn receive_announce_error(&mut self, msg: &AnnounceError) -> Result<(), EndpointError> {
502 let sm = self
503 .announces
504 .get_mut(&msg.track_namespace.0)
505 .ok_or(EndpointError::UnknownNamespace)?;
506 sm.on_announce_error()?;
507 Ok(())
508 }
509
510 pub fn receive_announce_cancel(&mut self, msg: &AnnounceCancel) -> Result<(), EndpointError> {
512 let sm = self
513 .announces
514 .get_mut(&msg.track_namespace.0)
515 .ok_or(EndpointError::UnknownNamespace)?;
516 sm.on_announce_cancel()?;
517 Ok(())
518 }
519
520 pub fn unannounce(
522 &mut self,
523 track_namespace: TrackNamespace,
524 ) -> Result<ControlMessage, EndpointError> {
525 let sm =
526 self.announces.get_mut(&track_namespace.0).ok_or(EndpointError::UnknownNamespace)?;
527 sm.on_unannounce()?;
528 Ok(ControlMessage::Unannounce(Unannounce { track_namespace }))
529 }
530
531 pub fn track_status_request(
535 &mut self,
536 track_namespace: TrackNamespace,
537 track_name: Vec<u8>,
538 ) -> Result<ControlMessage, EndpointError> {
539 self.require_active_or_err()?;
540 let key = (track_namespace.0.clone(), track_name.clone());
541 let mut sm = TrackStatusStateMachine::new();
542 sm.on_track_status_request_sent()?;
543 self.track_statuses.insert(key, sm);
544 Ok(ControlMessage::TrackStatusRequest(TrackStatusRequest { track_namespace, track_name }))
545 }
546
547 pub fn receive_track_status(&mut self, msg: &TrackStatus) -> Result<(), EndpointError> {
549 let key = (msg.track_namespace.0.clone(), msg.track_name.clone());
550 let sm = self.track_statuses.get_mut(&key).ok_or(EndpointError::UnknownTrackStatus)?;
551 sm.on_track_status()?;
552 Ok(())
553 }
554
555 pub fn receive_message(&mut self, msg: ControlMessage) -> Result<(), EndpointError> {
559 match msg {
560 ControlMessage::GoAway(ref m) => self.receive_goaway(m),
561 ControlMessage::MaxSubscribeId(ref m) => self.receive_max_subscribe_id(m),
562 ControlMessage::SubscribeOk(ref m) => self.receive_subscribe_ok(m),
563 ControlMessage::SubscribeError(ref m) => self.receive_subscribe_error(m),
564 ControlMessage::SubscribeUpdate(ref m) => self.receive_subscribe_update(m),
565 ControlMessage::SubscribeDone(ref m) => self.receive_subscribe_done(m),
566 ControlMessage::FetchOk(ref m) => self.receive_fetch_ok(m),
567 ControlMessage::FetchError(ref m) => self.receive_fetch_error(m),
568 ControlMessage::SubscribeAnnouncesOk(ref m) => self.receive_subscribe_announces_ok(m),
569 ControlMessage::SubscribeAnnouncesError(ref m) => {
570 self.receive_subscribe_announces_error(m)
571 }
572 ControlMessage::AnnounceOk(ref m) => self.receive_announce_ok(m),
573 ControlMessage::AnnounceError(ref m) => self.receive_announce_error(m),
574 ControlMessage::AnnounceCancel(ref m) => self.receive_announce_cancel(m),
575 ControlMessage::TrackStatus(ref m) => self.receive_track_status(m),
576 _ => Ok(()),
577 }
578 }
579}