Handle out of sessions and out of exchanges
This commit is contained in:
parent
4c347c0c0b
commit
e171e33510
8 changed files with 360 additions and 171 deletions
|
@ -17,6 +17,8 @@
|
|||
|
||||
use core::{borrow::Borrow, cell::RefCell};
|
||||
|
||||
use embassy_sync::{blocking_mutex::raw::NoopRawMutex, mutex::Mutex};
|
||||
|
||||
use crate::{
|
||||
acl::AclMgr,
|
||||
data_model::{
|
||||
|
@ -61,6 +63,8 @@ pub struct Matter<'a> {
|
|||
dev_att: &'a dyn DevAttDataFetcher,
|
||||
pub(crate) port: u16,
|
||||
pub(crate) exchanges: RefCell<heapless::Vec<ExchangeCtx, MAX_EXCHANGES>>,
|
||||
pub(crate) ephemeral: RefCell<Option<ExchangeCtx>>,
|
||||
pub(crate) ephemeral_mutex: Mutex<NoopRawMutex, ()>,
|
||||
pub session_mgr: RefCell<SessionMgr>, // Public for tests
|
||||
}
|
||||
|
||||
|
@ -108,6 +112,8 @@ impl<'a> Matter<'a> {
|
|||
dev_att,
|
||||
port,
|
||||
exchanges: RefCell::new(heapless::Vec::new()),
|
||||
ephemeral: RefCell::new(None),
|
||||
ephemeral_mutex: Mutex::new(()),
|
||||
session_mgr: RefCell::new(SessionMgr::new(epoch, rand)),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -47,6 +47,8 @@ pub enum ErrorCode {
|
|||
NoMemory,
|
||||
NoSession,
|
||||
NoSpace,
|
||||
NoSpaceExchanges,
|
||||
NoSpaceSessions,
|
||||
NoSpaceAckTable,
|
||||
NoSpaceRetransTable,
|
||||
NoTagFound,
|
||||
|
|
|
@ -96,7 +96,7 @@ impl<'a> Case<'a> {
|
|||
) -> Result<(), Error> {
|
||||
rx.check_proto_opcode(OpCode::CASESigma3 as _)?;
|
||||
|
||||
let status = {
|
||||
let result = {
|
||||
let fabric_mgr = self.fabric_mgr.borrow();
|
||||
|
||||
let fabric = fabric_mgr.get_fabric(case_session.local_fabric_idx)?;
|
||||
|
@ -133,7 +133,7 @@ impl<'a> Case<'a> {
|
|||
|
||||
if let Err(e) = Case::validate_certs(fabric, &initiator_noc, initiator_icac_mut) {
|
||||
error!("Certificate Chain doesn't match: {}", e);
|
||||
SCStatusCodes::InvalidParameter
|
||||
Err(SCStatusCodes::InvalidParameter)
|
||||
} else if let Err(e) = Case::validate_sigma3_sign(
|
||||
d.initiator_noc.0,
|
||||
d.initiator_icac.map(|a| a.0),
|
||||
|
@ -142,32 +142,35 @@ impl<'a> Case<'a> {
|
|||
case_session,
|
||||
) {
|
||||
error!("Sigma3 Signature doesn't match: {}", e);
|
||||
SCStatusCodes::InvalidParameter
|
||||
Err(SCStatusCodes::InvalidParameter)
|
||||
} else {
|
||||
// Only now do we add this message to the TT Hash
|
||||
let mut peer_catids: NocCatIds = Default::default();
|
||||
initiator_noc.get_cat_ids(&mut peer_catids);
|
||||
case_session.tt_hash.update(rx.as_slice())?;
|
||||
let clone_data = Case::get_session_clone_data(
|
||||
|
||||
Ok(Case::get_session_clone_data(
|
||||
fabric.ipk.op_key(),
|
||||
fabric.get_node_id(),
|
||||
initiator_noc.get_node_id()?,
|
||||
exchange.with_session(|sess| Ok(sess.get_peer_addr()))?,
|
||||
case_session,
|
||||
&peer_catids,
|
||||
)?;
|
||||
|
||||
// TODO: Handle NoSpace
|
||||
exchange
|
||||
.with_session_mgr_mut(|sess_mgr| sess_mgr.clone_session(&clone_data))?;
|
||||
|
||||
SCStatusCodes::SessionEstablishmentSuccess
|
||||
)?)
|
||||
}
|
||||
} else {
|
||||
SCStatusCodes::NoSharedTrustRoots
|
||||
Err(SCStatusCodes::NoSharedTrustRoots)
|
||||
}
|
||||
};
|
||||
|
||||
let status = match result {
|
||||
Ok(clone_data) => {
|
||||
exchange.clone_session(tx, &clone_data).await?;
|
||||
SCStatusCodes::SessionEstablishmentSuccess
|
||||
}
|
||||
Err(status) => status,
|
||||
};
|
||||
|
||||
complete_with_status(exchange, tx, status, None).await
|
||||
}
|
||||
|
||||
|
@ -201,7 +204,7 @@ impl<'a> Case<'a> {
|
|||
return Ok(());
|
||||
}
|
||||
|
||||
let local_sessid = exchange.with_session_mgr_mut(|mgr| Ok(mgr.get_next_sess_id()))?;
|
||||
let local_sessid = exchange.get_next_sess_id();
|
||||
case_session.peer_sessid = r.initiator_sessid;
|
||||
case_session.local_sessid = local_sessid;
|
||||
case_session.tt_hash.update(rx_buf)?;
|
||||
|
|
|
@ -78,8 +78,8 @@ pub fn create_sc_status_report(
|
|||
// the session will be closed soon
|
||||
GeneralCode::Success
|
||||
}
|
||||
SCStatusCodes::Busy
|
||||
| SCStatusCodes::InvalidParameter
|
||||
SCStatusCodes::Busy => GeneralCode::Busy,
|
||||
SCStatusCodes::InvalidParameter
|
||||
| SCStatusCodes::NoSharedTrustRoots
|
||||
| SCStatusCodes::SessionNotFound => GeneralCode::Failure,
|
||||
};
|
||||
|
|
|
@ -167,9 +167,9 @@ impl<'a> Pake<'a> {
|
|||
self.update_timeout(exchange, tx, true).await?;
|
||||
|
||||
let cA = extract_pasepake_1_or_3_params(rx.as_slice())?;
|
||||
let (status_code, ke) = spake2p.handle_cA(cA);
|
||||
let (status, ke) = spake2p.handle_cA(cA);
|
||||
|
||||
let clone_data = if status_code == SCStatusCodes::SessionEstablishmentSuccess {
|
||||
let result = if status == SCStatusCodes::SessionEstablishmentSuccess {
|
||||
// Get the keys
|
||||
let ke = ke.ok_or(ErrorCode::Invalid)?;
|
||||
let mut session_keys: [u8; 48] = [0; 48];
|
||||
|
@ -194,22 +194,22 @@ impl<'a> Pake<'a> {
|
|||
.att_challenge
|
||||
.copy_from_slice(&session_keys[32..48]);
|
||||
|
||||
// Queue a transport mgr request to add a new session
|
||||
Some(clone_data)
|
||||
Ok(clone_data)
|
||||
} else {
|
||||
None
|
||||
Err(status)
|
||||
};
|
||||
|
||||
if let Some(clone_data) = clone_data {
|
||||
// TODO: Handle NoSpace
|
||||
exchange.with_session_mgr_mut(|sess_mgr| sess_mgr.clone_session(&clone_data))?;
|
||||
|
||||
let status = match result {
|
||||
Ok(clone_data) => {
|
||||
exchange.clone_session(tx, &clone_data).await?;
|
||||
self.pase.borrow_mut().disable_pase_session(mdns)?;
|
||||
|
||||
SCStatusCodes::SessionEstablishmentSuccess
|
||||
}
|
||||
Err(status) => status,
|
||||
};
|
||||
|
||||
complete_with_status(exchange, tx, status_code, None).await?;
|
||||
|
||||
Ok(())
|
||||
complete_with_status(exchange, tx, status, None).await
|
||||
}
|
||||
|
||||
#[allow(non_snake_case)]
|
||||
|
@ -273,7 +273,7 @@ impl<'a> Pake<'a> {
|
|||
let mut our_random: [u8; 32] = [0; 32];
|
||||
(self.pase.borrow().rand)(&mut our_random);
|
||||
|
||||
let local_sessid = exchange.with_session_mgr_mut(|mgr| Ok(mgr.get_next_sess_id()))?;
|
||||
let local_sessid = exchange.get_next_sess_id();
|
||||
let spake2p_data: u32 = ((local_sessid as u32) << 16) | a.initiator_ssid as u32;
|
||||
spake2p.set_app_data(spake2p_data);
|
||||
|
||||
|
|
|
@ -25,6 +25,9 @@ use embassy_time::{Duration, Timer};
|
|||
|
||||
use log::{error, info, warn};
|
||||
|
||||
use crate::interaction_model::core::IMStatusCode;
|
||||
use crate::secure_channel::common::SCStatusCodes;
|
||||
use crate::secure_channel::status_report::{create_status_report, GeneralCode};
|
||||
use crate::utils::select::Notification;
|
||||
use crate::CommissioningData;
|
||||
use crate::{
|
||||
|
@ -41,6 +44,7 @@ use crate::{
|
|||
Matter,
|
||||
};
|
||||
|
||||
use super::exchange::SessionId;
|
||||
use super::{
|
||||
exchange::{
|
||||
Exchange, ExchangeCtr, ExchangeCtx, ExchangeId, ExchangeState, Role, MAX_EXCHANGES,
|
||||
|
@ -97,7 +101,7 @@ impl RunBuffers {
|
|||
pub struct PacketBuffers {
|
||||
tx: [TxBuf; MAX_EXCHANGES],
|
||||
rx: [RxBuf; MAX_EXCHANGES],
|
||||
sx: [SxBuf; MAX_EXCHANGES],
|
||||
sx: [SxBuf; MAX_EXCHANGES + 1],
|
||||
}
|
||||
|
||||
impl PacketBuffers {
|
||||
|
@ -107,7 +111,7 @@ impl PacketBuffers {
|
|||
|
||||
const TX_INIT: [TxBuf; MAX_EXCHANGES] = [Self::TX_ELEM; MAX_EXCHANGES];
|
||||
const RX_INIT: [RxBuf; MAX_EXCHANGES] = [Self::RX_ELEM; MAX_EXCHANGES];
|
||||
const SX_INIT: [SxBuf; MAX_EXCHANGES] = [Self::SX_ELEM; MAX_EXCHANGES];
|
||||
const SX_INIT: [SxBuf; MAX_EXCHANGES + 1] = [Self::SX_ELEM; MAX_EXCHANGES + 1];
|
||||
|
||||
#[inline(always)]
|
||||
pub const fn new() -> Self {
|
||||
|
@ -266,7 +270,12 @@ impl<'a> Matter<'a> {
|
|||
.unwrap();
|
||||
}
|
||||
|
||||
let mut rx = pin!(self.handle_rx_multiplex(rx_pipe, construction_notification, &channel));
|
||||
let mut rx = pin!(self.handle_rx_multiplex(
|
||||
rx_pipe,
|
||||
unsafe { buffers.sx[MAX_EXCHANGES].assume_init_mut() },
|
||||
construction_notification,
|
||||
&channel,
|
||||
));
|
||||
|
||||
let result = select(&mut rx, select_slice(&mut handlers)).await;
|
||||
|
||||
|
@ -291,7 +300,7 @@ impl<'a> Matter<'a> {
|
|||
if data.chunk.is_none() {
|
||||
let mut tx = alloc!(Packet::new_tx(data.buf));
|
||||
|
||||
if self.pull_tx(&mut tx).await? {
|
||||
if self.pull_tx(&mut tx)? {
|
||||
data.chunk = Some(Chunk {
|
||||
start: tx.get_writebuf()?.get_start(),
|
||||
end: tx.get_writebuf()?.get_tail(),
|
||||
|
@ -315,12 +324,15 @@ impl<'a> Matter<'a> {
|
|||
pub async fn handle_rx_multiplex<'t, 'e, const N: usize>(
|
||||
&'t self,
|
||||
rx_pipe: &Pipe<'_>,
|
||||
sts_buf: &mut [u8; MAX_RX_STATUS_BUF_SIZE],
|
||||
construction_notification: &'e Notification,
|
||||
channel: &Channel<NoopRawMutex, ExchangeCtr<'e>, N>,
|
||||
) -> Result<(), Error>
|
||||
where
|
||||
't: 'e,
|
||||
{
|
||||
let mut sts_tx = alloc!(Packet::new_tx(sts_buf));
|
||||
|
||||
loop {
|
||||
info!("Transport: waiting for incoming packets");
|
||||
|
||||
|
@ -331,8 +343,9 @@ impl<'a> Matter<'a> {
|
|||
let mut rx = alloc!(Packet::new_rx(&mut data.buf[chunk.start..chunk.end]));
|
||||
rx.peer = chunk.addr;
|
||||
|
||||
if let Some(exchange_ctr) =
|
||||
self.process_rx(construction_notification, &mut rx)?
|
||||
if let Some(exchange_ctr) = self
|
||||
.process_rx(construction_notification, &mut rx, &mut sts_tx)
|
||||
.await?
|
||||
{
|
||||
let exchange_id = exchange_ctr.id().clone();
|
||||
|
||||
|
@ -444,24 +457,39 @@ impl<'a> Matter<'a> {
|
|||
self.session_mgr.borrow_mut().reset();
|
||||
}
|
||||
|
||||
pub fn process_rx<'r>(
|
||||
pub async fn process_rx<'r>(
|
||||
&'r self,
|
||||
construction_notification: &'r Notification,
|
||||
src_rx: &mut Packet<'_>,
|
||||
sts_tx: &mut Packet<'_>,
|
||||
) -> Result<Option<ExchangeCtr<'r>>, Error> {
|
||||
src_rx.plain_hdr_decode()?;
|
||||
|
||||
self.purge()?;
|
||||
|
||||
let mut exchanges = self.exchanges.borrow_mut();
|
||||
let (ctx, new) = match self.post_recv(&mut exchanges, src_rx) {
|
||||
Ok((ctx, new)) => (ctx, new),
|
||||
let (exchange_index, new) = loop {
|
||||
let result = self.assign_exchange(&mut self.exchanges.borrow_mut(), src_rx);
|
||||
|
||||
match result {
|
||||
Err(e) => match e.code() {
|
||||
ErrorCode::Duplicate => {
|
||||
self.send_notification.signal(());
|
||||
return Ok(None);
|
||||
}
|
||||
_ => Err(e)?,
|
||||
// TODO: NoSession, NoExchange and others
|
||||
ErrorCode::NoSpaceSessions => self.evict_session(sts_tx).await?,
|
||||
ErrorCode::NoSpaceExchanges => {
|
||||
self.send_busy(src_rx, sts_tx).await?;
|
||||
return Ok(None);
|
||||
}
|
||||
_ => break Err(e),
|
||||
},
|
||||
};
|
||||
other => break other,
|
||||
}
|
||||
}?;
|
||||
|
||||
let mut exchanges = self.exchanges.borrow_mut();
|
||||
let ctx = &mut exchanges[exchange_index];
|
||||
|
||||
src_rx.log("Got packet");
|
||||
|
||||
|
@ -516,6 +544,8 @@ impl<'a> Matter<'a> {
|
|||
ExchangeState::ExchangeRecv {
|
||||
rx, notification, ..
|
||||
} => {
|
||||
// TODO: Handle Busy status codes
|
||||
|
||||
let rx = unsafe { rx.as_mut() }.unwrap();
|
||||
rx.load(src_rx)?;
|
||||
|
||||
|
@ -572,12 +602,24 @@ impl<'a> Matter<'a> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn pull_tx(&self, dest_tx: &mut Packet<'_>) -> Result<bool, Error> {
|
||||
pub fn pull_tx(&self, dest_tx: &mut Packet) -> Result<bool, Error> {
|
||||
self.purge()?;
|
||||
|
||||
let mut ephemeral = self.ephemeral.borrow_mut();
|
||||
let mut exchanges = self.exchanges.borrow_mut();
|
||||
|
||||
let ctx = exchanges.iter_mut().find(|ctx| {
|
||||
self.pull_tx_exchanges(ephemeral.iter_mut().chain(exchanges.iter_mut()), dest_tx)
|
||||
}
|
||||
|
||||
fn pull_tx_exchanges<'i, I>(
|
||||
&self,
|
||||
mut exchanges: I,
|
||||
dest_tx: &mut Packet,
|
||||
) -> Result<bool, Error>
|
||||
where
|
||||
I: Iterator<Item = &'i mut ExchangeCtx>,
|
||||
{
|
||||
let ctx = exchanges.find(|ctx| {
|
||||
matches!(
|
||||
&ctx.state,
|
||||
ExchangeState::Acknowledge { .. }
|
||||
|
@ -629,10 +671,15 @@ impl<'a> Matter<'a> {
|
|||
let tx = unsafe { tx.as_ref() }.unwrap();
|
||||
dest_tx.load(tx)?;
|
||||
|
||||
if dest_tx.is_reliable() {
|
||||
*state = ExchangeState::CompleteAcknowledge {
|
||||
_tx: tx as *const _,
|
||||
notification: *notification,
|
||||
};
|
||||
} else {
|
||||
unsafe { notification.as_ref() }.unwrap().signal(());
|
||||
ctx.state = ExchangeState::Closed;
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
@ -648,8 +695,6 @@ impl<'a> Matter<'a> {
|
|||
|
||||
if send {
|
||||
dest_tx.log("Sending packet");
|
||||
|
||||
self.pre_send(ctx, dest_tx)?;
|
||||
self.notify_changed();
|
||||
|
||||
return Ok(true);
|
||||
|
@ -675,13 +720,88 @@ impl<'a> Matter<'a> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn post_recv<'r>(
|
||||
&self,
|
||||
exchanges: &'r mut heapless::Vec<ExchangeCtx, MAX_EXCHANGES>,
|
||||
rx: &mut Packet<'_>,
|
||||
) -> Result<(&'r mut ExchangeCtx, bool), Error> {
|
||||
rx.plain_hdr_decode()?;
|
||||
pub(crate) async fn evict_session(&self, tx: &mut Packet<'_>) -> Result<(), Error> {
|
||||
let sess_index = self.session_mgr.borrow().get_session_for_eviction();
|
||||
if let Some(sess_index) = sess_index {
|
||||
let ctx = {
|
||||
create_status_report(
|
||||
tx,
|
||||
GeneralCode::Success,
|
||||
PROTO_ID_SECURE_CHANNEL as _,
|
||||
SCStatusCodes::CloseSession as _,
|
||||
None,
|
||||
)?;
|
||||
|
||||
let mut session_mgr = self.session_mgr.borrow_mut();
|
||||
let session_id = session_mgr.mut_by_index(sess_index).unwrap().id();
|
||||
warn!("Evicting session: {:?}", session_id);
|
||||
|
||||
let ctx = ExchangeCtx::prep_ephemeral(session_id, &mut session_mgr, None, tx)?;
|
||||
|
||||
session_mgr.remove(sess_index);
|
||||
|
||||
ctx
|
||||
};
|
||||
|
||||
self.send_ephemeral(ctx, tx).await
|
||||
} else {
|
||||
Err(ErrorCode::NoSpaceSessions.into())
|
||||
}
|
||||
}
|
||||
|
||||
async fn send_busy(&self, rx: &Packet<'_>, tx: &mut Packet<'_>) -> Result<(), Error> {
|
||||
warn!("Sending Busy as all exchanges are occupied");
|
||||
|
||||
create_status_report(
|
||||
tx,
|
||||
GeneralCode::Busy,
|
||||
rx.get_proto_id() as _,
|
||||
if rx.get_proto_id() == PROTO_ID_SECURE_CHANNEL {
|
||||
SCStatusCodes::Busy as _
|
||||
} else {
|
||||
IMStatusCode::Busy as _
|
||||
},
|
||||
None, // TODO: ms
|
||||
)?;
|
||||
|
||||
let ctx = ExchangeCtx::prep_ephemeral(
|
||||
SessionId::load(rx),
|
||||
&mut self.session_mgr.borrow_mut(),
|
||||
Some(rx),
|
||||
tx,
|
||||
)?;
|
||||
|
||||
self.send_ephemeral(ctx, tx).await
|
||||
}
|
||||
|
||||
async fn send_ephemeral(&self, mut ctx: ExchangeCtx, tx: &mut Packet<'_>) -> Result<(), Error> {
|
||||
let _guard = self.ephemeral_mutex.lock().await;
|
||||
|
||||
let notification = Notification::new();
|
||||
|
||||
let tx: &'static mut Packet<'static> = unsafe { core::mem::transmute(tx) };
|
||||
|
||||
ctx.state = ExchangeState::Complete {
|
||||
tx,
|
||||
notification: ¬ification,
|
||||
};
|
||||
|
||||
*self.ephemeral.borrow_mut() = Some(ctx);
|
||||
|
||||
self.send_notification.signal(());
|
||||
|
||||
notification.wait().await;
|
||||
|
||||
*self.ephemeral.borrow_mut() = None;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn assign_exchange(
|
||||
&self,
|
||||
exchanges: &mut heapless::Vec<ExchangeCtx, MAX_EXCHANGES>,
|
||||
rx: &mut Packet<'_>,
|
||||
) -> Result<(usize, bool), Error> {
|
||||
// Get the session
|
||||
|
||||
let mut session_mgr = self.session_mgr.borrow_mut();
|
||||
|
@ -693,8 +813,7 @@ impl<'a> Matter<'a> {
|
|||
session.recv(self.epoch, rx)?;
|
||||
|
||||
// Get the exchange
|
||||
// TODO: Handle out of space
|
||||
let (exch, new) = Self::register(
|
||||
let (exchange_index, new) = Self::register(
|
||||
exchanges,
|
||||
ExchangeId::load(rx),
|
||||
Role::complementary(rx.proto.is_initiator()),
|
||||
|
@ -703,32 +822,9 @@ impl<'a> Matter<'a> {
|
|||
)?;
|
||||
|
||||
// Message Reliability Protocol
|
||||
exch.mrp.recv(rx, self.epoch)?;
|
||||
exchanges[exchange_index].mrp.recv(rx, self.epoch)?;
|
||||
|
||||
Ok((exch, new))
|
||||
}
|
||||
|
||||
fn pre_send(&self, ctx: &mut ExchangeCtx, tx: &mut Packet) -> Result<(), Error> {
|
||||
let mut session_mgr = self.session_mgr.borrow_mut();
|
||||
let sess_index = session_mgr
|
||||
.get(
|
||||
ctx.id.session_id.id,
|
||||
ctx.id.session_id.peer_addr,
|
||||
ctx.id.session_id.peer_nodeid,
|
||||
ctx.id.session_id.is_encrypted,
|
||||
)
|
||||
.ok_or(ErrorCode::NoSession)?;
|
||||
|
||||
let session = session_mgr.mut_by_index(sess_index).unwrap();
|
||||
|
||||
tx.proto.exch_id = ctx.id.id;
|
||||
if ctx.role == Role::Initiator {
|
||||
tx.proto.set_initiator();
|
||||
}
|
||||
|
||||
session.pre_send(tx)?;
|
||||
ctx.mrp.pre_send(tx)?;
|
||||
session_mgr.send(sess_index, tx)
|
||||
Ok((exchange_index, new))
|
||||
}
|
||||
|
||||
fn register(
|
||||
|
@ -736,7 +832,7 @@ impl<'a> Matter<'a> {
|
|||
id: ExchangeId,
|
||||
role: Role,
|
||||
create_new: bool,
|
||||
) -> Result<(&mut ExchangeCtx, bool), Error> {
|
||||
) -> Result<(usize, bool), Error> {
|
||||
let exchange_index = exchanges
|
||||
.iter_mut()
|
||||
.enumerate()
|
||||
|
@ -745,7 +841,7 @@ impl<'a> Matter<'a> {
|
|||
if let Some(exchange_index) = exchange_index {
|
||||
let exchange = &mut exchanges[exchange_index];
|
||||
if exchange.role == role {
|
||||
Ok((exchange, false))
|
||||
Ok((exchange_index, false))
|
||||
} else {
|
||||
Err(ErrorCode::NoExchange.into())
|
||||
}
|
||||
|
@ -759,9 +855,11 @@ impl<'a> Matter<'a> {
|
|||
state: ExchangeState::Active,
|
||||
};
|
||||
|
||||
exchanges.push(exchange).map_err(|_| ErrorCode::NoSpace)?;
|
||||
exchanges
|
||||
.push(exchange)
|
||||
.map_err(|_| ErrorCode::NoSpaceExchanges)?;
|
||||
|
||||
Ok((exchanges.iter_mut().next_back().unwrap(), true))
|
||||
Ok((exchanges.len() - 1, true))
|
||||
} else {
|
||||
Err(ErrorCode::NoExchange.into())
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use crate::{
|
||||
acl::Accessor,
|
||||
error::{Error, ErrorCode},
|
||||
utils::select::Notification,
|
||||
utils::{epoch::Epoch, select::Notification},
|
||||
Matter,
|
||||
};
|
||||
|
||||
|
@ -9,7 +9,7 @@ use super::{
|
|||
mrp::ReliableMessage,
|
||||
network::Address,
|
||||
packet::Packet,
|
||||
session::{Session, SessionMgr},
|
||||
session::{CloneData, Session, SessionMgr},
|
||||
};
|
||||
|
||||
pub const MAX_EXCHANGES: usize = 8;
|
||||
|
@ -46,6 +46,101 @@ impl ExchangeCtx {
|
|||
) -> Option<&'r mut ExchangeCtx> {
|
||||
exchanges.iter_mut().find(|exchange| exchange.id == *id)
|
||||
}
|
||||
|
||||
pub fn new_ephemeral(session_id: SessionId, reply_to: Option<&Packet<'_>>) -> Self {
|
||||
Self {
|
||||
id: ExchangeId {
|
||||
id: if let Some(rx) = reply_to {
|
||||
rx.proto.exch_id
|
||||
} else {
|
||||
0
|
||||
},
|
||||
session_id: session_id.clone(),
|
||||
},
|
||||
role: if reply_to.is_some() {
|
||||
Role::Responder
|
||||
} else {
|
||||
Role::Initiator
|
||||
},
|
||||
mrp: ReliableMessage::new(),
|
||||
state: ExchangeState::Active,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn prep_ephemeral(
|
||||
session_id: SessionId,
|
||||
session_mgr: &mut SessionMgr,
|
||||
reply_to: Option<&Packet<'_>>,
|
||||
tx: &mut Packet<'_>,
|
||||
) -> Result<ExchangeCtx, Error> {
|
||||
let mut ctx = Self::new_ephemeral(session_id.clone(), reply_to);
|
||||
|
||||
let sess_index = session_mgr.get(
|
||||
session_id.id,
|
||||
session_id.peer_addr,
|
||||
session_id.peer_nodeid,
|
||||
session_id.is_encrypted,
|
||||
);
|
||||
|
||||
let epoch = session_mgr.epoch;
|
||||
let rand = session_mgr.rand;
|
||||
|
||||
if let Some(rx) = reply_to {
|
||||
ctx.mrp.recv(rx, epoch)?;
|
||||
} else {
|
||||
tx.proto.set_initiator();
|
||||
}
|
||||
|
||||
tx.unset_reliable();
|
||||
|
||||
if let Some(sess_index) = sess_index {
|
||||
let session = session_mgr.mut_by_index(sess_index).unwrap();
|
||||
ctx.pre_send_sess(session, tx, epoch)?;
|
||||
} else {
|
||||
let mut session =
|
||||
Session::new(session_id.peer_addr, session_id.peer_nodeid, epoch, rand);
|
||||
ctx.pre_send_sess(&mut session, tx, epoch)?;
|
||||
}
|
||||
|
||||
Ok(ctx)
|
||||
}
|
||||
|
||||
pub(crate) fn pre_send(
|
||||
&mut self,
|
||||
session_mgr: &mut SessionMgr,
|
||||
tx: &mut Packet,
|
||||
) -> Result<(), Error> {
|
||||
let epoch = session_mgr.epoch;
|
||||
|
||||
let sess_index = session_mgr
|
||||
.get(
|
||||
self.id.session_id.id,
|
||||
self.id.session_id.peer_addr,
|
||||
self.id.session_id.peer_nodeid,
|
||||
self.id.session_id.is_encrypted,
|
||||
)
|
||||
.ok_or(ErrorCode::NoSession)?;
|
||||
|
||||
let session = session_mgr.mut_by_index(sess_index).unwrap();
|
||||
|
||||
self.pre_send_sess(session, tx, epoch)
|
||||
}
|
||||
|
||||
pub(crate) fn pre_send_sess(
|
||||
&mut self,
|
||||
session: &mut Session,
|
||||
tx: &mut Packet,
|
||||
epoch: Epoch,
|
||||
) -> Result<(), Error> {
|
||||
tx.proto.exch_id = self.id.id;
|
||||
if self.role == Role::Initiator {
|
||||
tx.proto.set_initiator();
|
||||
}
|
||||
|
||||
session.pre_send(tx)?;
|
||||
self.mrp.pre_send(tx)?;
|
||||
session.send(epoch, tx)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
|
@ -192,15 +287,6 @@ impl<'a> Exchange<'a> {
|
|||
self.with_session_mut(|sess| f(sess))
|
||||
}
|
||||
|
||||
pub fn with_session_mgr_mut<F, T>(&self, f: F) -> Result<T, Error>
|
||||
where
|
||||
F: FnOnce(&mut SessionMgr) -> Result<T, Error>,
|
||||
{
|
||||
let mut session_mgr = self.matter.session_mgr.borrow_mut();
|
||||
|
||||
f(&mut session_mgr)
|
||||
}
|
||||
|
||||
pub async fn acknowledge(&mut self) -> Result<(), Error> {
|
||||
let wait = self.with_ctx_mut(|_self, ctx| {
|
||||
if !matches!(ctx.state, ExchangeState::Active) {
|
||||
|
@ -226,8 +312,12 @@ impl<'a> Exchange<'a> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn exchange(&mut self, tx: &Packet<'_>, rx: &mut Packet<'_>) -> Result<(), Error> {
|
||||
let tx: &Packet<'static> = unsafe { core::mem::transmute(tx) };
|
||||
pub async fn exchange(
|
||||
&mut self,
|
||||
tx: &mut Packet<'_>,
|
||||
rx: &mut Packet<'_>,
|
||||
) -> Result<(), Error> {
|
||||
let tx: &mut Packet<'static> = unsafe { core::mem::transmute(tx) };
|
||||
let rx: &mut Packet<'static> = unsafe { core::mem::transmute(rx) };
|
||||
|
||||
self.with_ctx_mut(|_self, ctx| {
|
||||
|
@ -235,6 +325,9 @@ impl<'a> Exchange<'a> {
|
|||
Err(ErrorCode::NoExchange)?;
|
||||
}
|
||||
|
||||
let mut session_mgr = _self.matter.session_mgr.borrow_mut();
|
||||
ctx.pre_send(&mut session_mgr, tx)?;
|
||||
|
||||
ctx.state = ExchangeState::ExchangeSend {
|
||||
tx: tx as *const _,
|
||||
rx: rx as *mut _,
|
||||
|
@ -250,18 +343,21 @@ impl<'a> Exchange<'a> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn complete(mut self, tx: &Packet<'_>) -> Result<(), Error> {
|
||||
pub async fn complete(mut self, tx: &mut Packet<'_>) -> Result<(), Error> {
|
||||
self.send_complete(tx).await
|
||||
}
|
||||
|
||||
pub async fn send_complete(&mut self, tx: &Packet<'_>) -> Result<(), Error> {
|
||||
let tx: &Packet<'static> = unsafe { core::mem::transmute(tx) };
|
||||
pub async fn send_complete(&mut self, tx: &mut Packet<'_>) -> Result<(), Error> {
|
||||
let tx: &mut Packet<'static> = unsafe { core::mem::transmute(tx) };
|
||||
|
||||
self.with_ctx_mut(|_self, ctx| {
|
||||
if !matches!(ctx.state, ExchangeState::Active) {
|
||||
Err(ErrorCode::NoExchange)?;
|
||||
}
|
||||
|
||||
let mut session_mgr = _self.matter.session_mgr.borrow_mut();
|
||||
ctx.pre_send(&mut session_mgr, tx)?;
|
||||
|
||||
ctx.state = ExchangeState::Complete {
|
||||
tx: tx as *const _,
|
||||
notification: &_self.notification as *const _,
|
||||
|
@ -276,6 +372,31 @@ impl<'a> Exchange<'a> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn get_next_sess_id(&mut self) -> u16 {
|
||||
self.matter.session_mgr.borrow_mut().get_next_sess_id()
|
||||
}
|
||||
|
||||
pub(crate) async fn clone_session(
|
||||
&mut self,
|
||||
tx: &mut Packet<'_>,
|
||||
clone_data: &CloneData,
|
||||
) -> Result<usize, Error> {
|
||||
loop {
|
||||
let result = self
|
||||
.matter
|
||||
.session_mgr
|
||||
.borrow_mut()
|
||||
.clone_session(clone_data);
|
||||
|
||||
match result {
|
||||
Err(err) if err.code() == ErrorCode::NoSpaceSessions => {
|
||||
self.matter.evict_session(tx).await?
|
||||
}
|
||||
other => break other,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn with_ctx<F, T>(&self, f: F) -> Result<T, Error>
|
||||
where
|
||||
F: FnOnce(&Self, &ExchangeCtx) -> Result<T, Error>,
|
||||
|
|
|
@ -19,13 +19,13 @@ use crate::data_model::sdm::noc::NocData;
|
|||
use crate::utils::epoch::Epoch;
|
||||
use crate::utils::rand::Rand;
|
||||
use core::fmt;
|
||||
use core::ops::{Deref, DerefMut};
|
||||
use core::time::Duration;
|
||||
|
||||
use crate::{error::*, transport::plain_hdr};
|
||||
use log::info;
|
||||
|
||||
use super::dedup::RxCtrState;
|
||||
use super::exchange::SessionId;
|
||||
use super::{network::Address, packet::Packet};
|
||||
|
||||
pub const MAX_CAT_IDS_PER_NOC: usize = 3;
|
||||
|
@ -151,6 +151,15 @@ impl Session {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn id(&self) -> SessionId {
|
||||
SessionId {
|
||||
id: self.local_sess_id,
|
||||
peer_addr: self.peer_addr,
|
||||
peer_nodeid: self.peer_nodeid,
|
||||
is_encrypted: self.is_encrypted(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_noc_data(&mut self, data: NocData) {
|
||||
self.data = Some(data);
|
||||
}
|
||||
|
@ -251,7 +260,7 @@ impl Session {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn send(&mut self, epoch: Epoch, tx: &mut Packet) -> Result<(), Error> {
|
||||
pub(crate) fn send(&mut self, epoch: Epoch, tx: &mut Packet) -> Result<(), Error> {
|
||||
self.last_use = epoch();
|
||||
|
||||
tx.proto_encode(
|
||||
|
@ -291,8 +300,8 @@ pub const MAX_SESSIONS: usize = 16;
|
|||
pub struct SessionMgr {
|
||||
next_sess_id: u16,
|
||||
sessions: heapless::Vec<Option<Session>, MAX_SESSIONS>,
|
||||
epoch: Epoch,
|
||||
rand: Rand,
|
||||
pub(crate) epoch: Epoch,
|
||||
pub(crate) rand: Rand,
|
||||
}
|
||||
|
||||
impl SessionMgr {
|
||||
|
@ -327,7 +336,11 @@ impl SessionMgr {
|
|||
}
|
||||
|
||||
// Ensure the currently selected id doesn't match any existing session
|
||||
if self.get_with_id(next_sess_id).is_none() {
|
||||
if self.sessions.iter().all(|sess| {
|
||||
sess.as_ref()
|
||||
.map(|sess| sess.get_local_sess_id() != next_sess_id)
|
||||
.unwrap_or(true)
|
||||
}) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
@ -381,12 +394,12 @@ impl SessionMgr {
|
|||
} else if self.sessions.len() < MAX_SESSIONS {
|
||||
self.sessions
|
||||
.push(Some(session))
|
||||
.map_err(|_| ErrorCode::NoSpace)
|
||||
.map_err(|_| ErrorCode::NoSpaceSessions)
|
||||
.unwrap();
|
||||
|
||||
Ok(self.sessions.len() - 1)
|
||||
} else {
|
||||
Err(ErrorCode::NoSpace.into())
|
||||
Err(ErrorCode::NoSpaceSessions.into())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -419,14 +432,6 @@ impl SessionMgr {
|
|||
})
|
||||
}
|
||||
|
||||
pub fn get_with_id(&mut self, sess_id: u16) -> Option<SessionHandle> {
|
||||
let index = self
|
||||
.sessions
|
||||
.iter_mut()
|
||||
.position(|x| x.as_ref().map(|s| s.local_sess_id) == Some(sess_id))?;
|
||||
Some(self.get_session_handle(index))
|
||||
}
|
||||
|
||||
pub fn get_or_add(
|
||||
&mut self,
|
||||
sess_id: u16,
|
||||
|
@ -472,13 +477,6 @@ impl SessionMgr {
|
|||
.ok_or(ErrorCode::NoSession)?
|
||||
.send(self.epoch, tx)
|
||||
}
|
||||
|
||||
pub fn get_session_handle(&mut self, sess_idx: usize) -> SessionHandle {
|
||||
SessionHandle {
|
||||
sess_mgr: self,
|
||||
sess_idx,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for SessionMgr {
|
||||
|
@ -492,45 +490,6 @@ impl fmt::Display for SessionMgr {
|
|||
}
|
||||
}
|
||||
|
||||
pub struct SessionHandle<'a> {
|
||||
pub(crate) sess_mgr: &'a mut SessionMgr,
|
||||
sess_idx: usize,
|
||||
}
|
||||
|
||||
impl<'a> SessionHandle<'a> {
|
||||
pub fn session(&self) -> &Session {
|
||||
self.sess_mgr.sessions[self.sess_idx].as_ref().unwrap()
|
||||
}
|
||||
|
||||
pub fn session_mut(&mut self) -> &mut Session {
|
||||
self.sess_mgr.sessions[self.sess_idx].as_mut().unwrap()
|
||||
}
|
||||
|
||||
pub fn reserve_new_sess_id(&mut self) -> u16 {
|
||||
self.sess_mgr.get_next_sess_id()
|
||||
}
|
||||
|
||||
pub fn send(&mut self, tx: &mut Packet) -> Result<(), Error> {
|
||||
self.sess_mgr.send(self.sess_idx, tx)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Deref for SessionHandle<'a> {
|
||||
type Target = Session;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
// There is no other option but to panic if this is None
|
||||
self.session()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> DerefMut for SessionHandle<'a> {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
// There is no other option but to panic if this is None
|
||||
self.session_mut()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
|
@ -545,12 +504,12 @@ mod tests {
|
|||
fn test_next_sess_id_doesnt_reuse() {
|
||||
let mut sm = SessionMgr::new(dummy_epoch, dummy_rand);
|
||||
let sess_idx = sm.add(Address::default(), None).unwrap();
|
||||
let mut sess = sm.get_session_handle(sess_idx);
|
||||
let sess = sm.mut_by_index(sess_idx).unwrap();
|
||||
sess.set_local_sess_id(1);
|
||||
assert_eq!(sm.get_next_sess_id(), 2);
|
||||
assert_eq!(sm.get_next_sess_id(), 3);
|
||||
let sess_idx = sm.add(Address::default(), None).unwrap();
|
||||
let mut sess = sm.get_session_handle(sess_idx);
|
||||
let sess = sm.mut_by_index(sess_idx).unwrap();
|
||||
sess.set_local_sess_id(4);
|
||||
assert_eq!(sm.get_next_sess_id(), 5);
|
||||
}
|
||||
|
@ -559,7 +518,7 @@ mod tests {
|
|||
fn test_next_sess_id_overflows() {
|
||||
let mut sm = SessionMgr::new(dummy_epoch, dummy_rand);
|
||||
let sess_idx = sm.add(Address::default(), None).unwrap();
|
||||
let mut sess = sm.get_session_handle(sess_idx);
|
||||
let sess = sm.mut_by_index(sess_idx).unwrap();
|
||||
sess.set_local_sess_id(1);
|
||||
assert_eq!(sm.get_next_sess_id(), 2);
|
||||
sm.next_sess_id = 65534;
|
||||
|
|
Loading…
Add table
Reference in a new issue