server/event/
scalar.rs

1// SPDX-FileCopyrightText: 2023 Foundation Devices, Inc. <hello@foundation.xyz>
2// SPDX-License-Identifier: GPL-3.0-or-later
3
4use std::{any::type_name, marker::PhantomData};
5
6use rkyv::bytecheck::CheckBytes;
7use whence::WhenceExt;
8use xous_ipc::{XousDeserializer, XousValidator};
9
10use crate::{utils, Error, EventSubscriptionMessage, ScalarCodec, Server, ServerContext};
11
12/// Handle for a single event subscriber
13pub struct ScalarEventSubscriber<M>
14where
15    M: ScalarEvent,
16{
17    pid: xous::PID,
18    cid: xous::CID,
19    msg_id: xous::MessageId,
20    cancel_msg_id: xous::MessageId,
21    _phantom: PhantomData<M>,
22}
23
24impl<M> core::fmt::Debug for ScalarEventSubscriber<M>
25where
26    M: ScalarEvent,
27{
28    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29        f.debug_struct("ScalarEventSubscriber").field("pid", &self.pid).finish()
30    }
31}
32
33impl<M> ScalarEventSubscriber<M>
34where
35    M: ScalarEvent,
36{
37    /// Send the event to the subscriber.
38    /// Can be used in an IRQ handler context.
39    pub fn send(&self, msg: &M) -> Result<xous::Result, xous::Error> {
40        let msg = xous::Message::Scalar(utils::scalar_to_message(msg, self.msg_id));
41        xous::try_send_message(self.cid, msg)
42    }
43
44    pub fn pid(&self) -> xous::PID { self.pid }
45
46    pub fn cid(&self) -> xous::CID { self.cid }
47}
48
49impl<M> Drop for ScalarEventSubscriber<M>
50where
51    M: ScalarEvent,
52{
53    fn drop(&mut self) {
54        if let Err(e) =
55            xous::send_message(self.cid, super::cancellation_message(self.msg_id, self.cancel_msg_id))
56        {
57            log::debug!("Error sending cancellation message {self:?}: {e:?}")
58        }
59        if let Err(e) = xous::disconnect(self.cid) {
60            log::error!("Error disconnecting {self:?}: {e:?}")
61        }
62    }
63}
64
65/// A message which can be serialized and deserialized using scalar encoding.
66pub trait ScalarEvent: ScalarCodec {}
67
68impl<M> ScalarEvent for M where M: ScalarCodec {}
69
70pub trait ScalarSubscription
71where
72    Self: crate::MessageId + crate::ArchiveCodec,
73    <Self::Error as rkyv::Archive>::Archived:
74        rkyv::Deserialize<Self::Error, XousDeserializer> + for<'a> CheckBytes<XousValidator<'a>>,
75    <Result<(), Self::Error> as rkyv::Archive>::Archived:
76        rkyv::Deserialize<Result<(), Self::Error>, XousDeserializer> + for<'a> CheckBytes<XousValidator<'a>>,
77{
78    type Event: ScalarEvent;
79    type Error: super::SubscriptionError;
80}
81
82/// A [`ScalarSubscription`] handler.
83pub trait ScalarEventSubscriptionHandler<M>
84where
85    Self: Server,
86    M: ScalarSubscription,
87{
88    /// Handle the subscription.
89    ///
90    /// The `subscriber` parameter can be used to store the subscriber info and send events to them
91    /// later. Once their subscription is not used, the object can be dropped.
92    fn handle(
93        &mut self,
94        msg: M,
95        subscriber: ScalarEventSubscriber<M::Event>,
96        context: &mut ServerContext<Self>,
97    ) -> Result<(), M::Error>;
98}
99
100/// Handler for an incoming [`ScalarEvent`]
101pub trait ScalarEventHandler<M>
102where
103    Self: Server,
104    M: ScalarEvent,
105{
106    fn handle(&mut self, msg: M, sender: xous::PID, context: &mut ServerContext<Self>);
107}
108
109/// Message handler, used by ServerMessages::messages()
110pub fn handle_scalar_subscription<M, S>(
111    handler: &mut S,
112    raw: xous::MessageEnvelope,
113    context: &mut ServerContext<S>,
114) where
115    M: ScalarSubscription + 'static,
116    S: ScalarEventSubscriptionHandler<M>,
117    <M as rkyv::Archive>::Archived:
118        rkyv::Deserialize<M, XousDeserializer> + for<'a> CheckBytes<XousValidator<'a>>,
119{
120    let pid = raw.sender.pid().unwrap();
121    if let Err(e) = try_handle_scalar_subscription(pid, handler, raw, context) {
122        log::warn!("archive sub handle error (PID {pid}) for {}: {e}", type_name::<M>());
123    }
124}
125
126fn try_handle_scalar_subscription<M, S>(
127    pid: xous::PID,
128    handler: &mut S,
129    mut raw: xous::MessageEnvelope,
130    context: &mut ServerContext<S>,
131) -> whence::Result<(), Error>
132where
133    M: ScalarSubscription + 'static,
134    S: ScalarEventSubscriptionHandler<M>,
135    <M as rkyv::Archive>::Archived:
136        rkyv::Deserialize<M, XousDeserializer> + for<'a> CheckBytes<XousValidator<'a>>,
137{
138    let mut buf = utils::extract_borrow_mut_message(&mut raw).whence()?;
139    let msg: EventSubscriptionMessage<M> = buf.to_original().whence()?;
140    let res = handler.handle(
141        msg.msg,
142        ScalarEventSubscriber::<M::Event> {
143            pid,
144            msg_id: msg.msg_id,
145            cancel_msg_id: msg.cancel_msg_id,
146            cid: msg.cid,
147            _phantom: PhantomData,
148        },
149        context,
150    );
151    buf.replace(&res).whence()?;
152    Ok(())
153}
154
155pub fn decode_scalar_event<M>(raw: xous::MessageEnvelope) -> M
156where
157    M: ScalarEvent,
158{
159    try_decode_scalar_event(raw).unwrap()
160}
161
162pub fn try_decode_scalar_event<M>(mut raw: xous::MessageEnvelope) -> whence::Result<M, crate::Error>
163where
164    M: ScalarEvent,
165{
166    let scalar = utils::extract_scalar_message(&mut raw).whence()?;
167    Ok(M::from_scalar(scalar))
168}
169
170pub(crate) fn scalar_event_handler<M, S>(
171    handler: &mut S,
172    raw: xous::MessageEnvelope,
173    context: &mut ServerContext<S>,
174) where
175    M: ScalarEvent,
176    S: ScalarEventHandler<M>,
177{
178    let sender = raw.sender.pid().unwrap();
179    let msg = decode_scalar_event::<M>(raw);
180    handler.handle(msg, sender, context);
181}
182
183/// Subscribe to a [`ScalarEvent`] event.
184///
185/// # Arguments
186///
187/// * `cid` - The connection ID to the event sending server.
188/// * `sid` - The server ID of the event receiving server.
189///
190/// # Returns
191///
192/// A tuple containing two unique message IDs (to this process) for the incoming events:
193/// - The first ID is for the event message.
194/// - The second ID is for the cancellation message.
195pub fn subscribe_scalar<M>(cid: xous::CID, msg: M, sid: xous::SID) -> Result<(usize, usize), M::Error>
196where
197    M: ScalarSubscription + 'static,
198{
199    try_subscribe_scalar(cid, msg, sid).unwrap()
200}
201
202pub fn try_subscribe_scalar<M>(
203    cid: xous::CID,
204    msg: M,
205    sid: xous::SID,
206) -> whence::Result<Result<(usize, usize), M::Error>, crate::Error>
207where
208    M: ScalarSubscription + 'static,
209{
210    let msg_id = crate::next_dynamic_message_id();
211    let cancel_msg_id = crate::next_dynamic_message_id();
212    let pid = xous::get_remote_pid(cid).whence()?;
213    let cid_remote = xous::connect_for_process(pid, sid).whence()?;
214    xous::allow_messages_on_connection(pid, cid_remote, msg_id..(cancel_msg_id + 1)).whence()?;
215    let msg = EventSubscriptionMessage { cid: cid_remote, msg_id, cancel_msg_id, msg };
216    let result = msg.send_scalar(cid)?;
217    Ok(result.map(|_| (msg_id, cancel_msg_id)))
218}
219
220/// A list of scalar event subscribers.
221pub struct ScalarSubList<T: ScalarCodec> {
222    inner: Vec<ScalarEventSubscriber<T>>,
223}
224
225impl<T: ScalarCodec> Default for ScalarSubList<T> {
226    fn default() -> Self { Self { inner: Default::default() } }
227}
228
229impl<T: ScalarCodec> ScalarSubList<T> {
230    pub fn push(&mut self, sub: ScalarEventSubscriber<T>) { self.inner.push(sub); }
231
232    pub fn send(&mut self, msg: &T) { self.inner.retain(|sub| sub.send(msg).is_ok()) }
233
234    pub fn send_nowait(&mut self, msg: &T) {
235        self.inner.retain(|sub| match sub.send(msg) {
236            Ok(_) => true,
237            Err(xous::Error::ServerQueueFull) => {
238                log::warn!("scalar event send_nowait error for pid {} {}", sub.pid(), type_name::<T>());
239                true
240            }
241            Err(_) => false,
242        })
243    }
244
245    pub fn remove_cid(&mut self, cid: xous::CID) { self.inner.retain(|s| s.cid() != cid) }
246}