Handle out of sessions and out of exchanges

This commit is contained in:
ivmarkov 2023-07-31 17:34:21 +00:00
parent 4c347c0c0b
commit e171e33510
8 changed files with 360 additions and 171 deletions

View file

@ -17,6 +17,8 @@
use core::{borrow::Borrow, cell::RefCell};
use embassy_sync::{blocking_mutex::raw::NoopRawMutex, mutex::Mutex};
use crate::{
acl::AclMgr,
data_model::{
@ -61,6 +63,8 @@ pub struct Matter<'a> {
dev_att: &'a dyn DevAttDataFetcher,
pub(crate) port: u16,
pub(crate) exchanges: RefCell<heapless::Vec<ExchangeCtx, MAX_EXCHANGES>>,
pub(crate) ephemeral: RefCell<Option<ExchangeCtx>>,
pub(crate) ephemeral_mutex: Mutex<NoopRawMutex, ()>,
pub session_mgr: RefCell<SessionMgr>, // Public for tests
}
@ -108,6 +112,8 @@ impl<'a> Matter<'a> {
dev_att,
port,
exchanges: RefCell::new(heapless::Vec::new()),
ephemeral: RefCell::new(None),
ephemeral_mutex: Mutex::new(()),
session_mgr: RefCell::new(SessionMgr::new(epoch, rand)),
}
}

View file

@ -47,6 +47,8 @@ pub enum ErrorCode {
NoMemory,
NoSession,
NoSpace,
NoSpaceExchanges,
NoSpaceSessions,
NoSpaceAckTable,
NoSpaceRetransTable,
NoTagFound,

View file

@ -96,7 +96,7 @@ impl<'a> Case<'a> {
) -> Result<(), Error> {
rx.check_proto_opcode(OpCode::CASESigma3 as _)?;
let status = {
let result = {
let fabric_mgr = self.fabric_mgr.borrow();
let fabric = fabric_mgr.get_fabric(case_session.local_fabric_idx)?;
@ -133,7 +133,7 @@ impl<'a> Case<'a> {
if let Err(e) = Case::validate_certs(fabric, &initiator_noc, initiator_icac_mut) {
error!("Certificate Chain doesn't match: {}", e);
SCStatusCodes::InvalidParameter
Err(SCStatusCodes::InvalidParameter)
} else if let Err(e) = Case::validate_sigma3_sign(
d.initiator_noc.0,
d.initiator_icac.map(|a| a.0),
@ -142,32 +142,35 @@ impl<'a> Case<'a> {
case_session,
) {
error!("Sigma3 Signature doesn't match: {}", e);
SCStatusCodes::InvalidParameter
Err(SCStatusCodes::InvalidParameter)
} else {
// Only now do we add this message to the TT Hash
let mut peer_catids: NocCatIds = Default::default();
initiator_noc.get_cat_ids(&mut peer_catids);
case_session.tt_hash.update(rx.as_slice())?;
let clone_data = Case::get_session_clone_data(
Ok(Case::get_session_clone_data(
fabric.ipk.op_key(),
fabric.get_node_id(),
initiator_noc.get_node_id()?,
exchange.with_session(|sess| Ok(sess.get_peer_addr()))?,
case_session,
&peer_catids,
)?;
// TODO: Handle NoSpace
exchange
.with_session_mgr_mut(|sess_mgr| sess_mgr.clone_session(&clone_data))?;
SCStatusCodes::SessionEstablishmentSuccess
)?)
}
} else {
SCStatusCodes::NoSharedTrustRoots
Err(SCStatusCodes::NoSharedTrustRoots)
}
};
let status = match result {
Ok(clone_data) => {
exchange.clone_session(tx, &clone_data).await?;
SCStatusCodes::SessionEstablishmentSuccess
}
Err(status) => status,
};
complete_with_status(exchange, tx, status, None).await
}
@ -201,7 +204,7 @@ impl<'a> Case<'a> {
return Ok(());
}
let local_sessid = exchange.with_session_mgr_mut(|mgr| Ok(mgr.get_next_sess_id()))?;
let local_sessid = exchange.get_next_sess_id();
case_session.peer_sessid = r.initiator_sessid;
case_session.local_sessid = local_sessid;
case_session.tt_hash.update(rx_buf)?;

View file

@ -78,8 +78,8 @@ pub fn create_sc_status_report(
// the session will be closed soon
GeneralCode::Success
}
SCStatusCodes::Busy
| SCStatusCodes::InvalidParameter
SCStatusCodes::Busy => GeneralCode::Busy,
SCStatusCodes::InvalidParameter
| SCStatusCodes::NoSharedTrustRoots
| SCStatusCodes::SessionNotFound => GeneralCode::Failure,
};

View file

@ -167,9 +167,9 @@ impl<'a> Pake<'a> {
self.update_timeout(exchange, tx, true).await?;
let cA = extract_pasepake_1_or_3_params(rx.as_slice())?;
let (status_code, ke) = spake2p.handle_cA(cA);
let (status, ke) = spake2p.handle_cA(cA);
let clone_data = if status_code == SCStatusCodes::SessionEstablishmentSuccess {
let result = if status == SCStatusCodes::SessionEstablishmentSuccess {
// Get the keys
let ke = ke.ok_or(ErrorCode::Invalid)?;
let mut session_keys: [u8; 48] = [0; 48];
@ -194,22 +194,22 @@ impl<'a> Pake<'a> {
.att_challenge
.copy_from_slice(&session_keys[32..48]);
// Queue a transport mgr request to add a new session
Some(clone_data)
Ok(clone_data)
} else {
None
Err(status)
};
if let Some(clone_data) = clone_data {
// TODO: Handle NoSpace
exchange.with_session_mgr_mut(|sess_mgr| sess_mgr.clone_session(&clone_data))?;
let status = match result {
Ok(clone_data) => {
exchange.clone_session(tx, &clone_data).await?;
self.pase.borrow_mut().disable_pase_session(mdns)?;
self.pase.borrow_mut().disable_pase_session(mdns)?;
}
SCStatusCodes::SessionEstablishmentSuccess
}
Err(status) => status,
};
complete_with_status(exchange, tx, status_code, None).await?;
Ok(())
complete_with_status(exchange, tx, status, None).await
}
#[allow(non_snake_case)]
@ -273,7 +273,7 @@ impl<'a> Pake<'a> {
let mut our_random: [u8; 32] = [0; 32];
(self.pase.borrow().rand)(&mut our_random);
let local_sessid = exchange.with_session_mgr_mut(|mgr| Ok(mgr.get_next_sess_id()))?;
let local_sessid = exchange.get_next_sess_id();
let spake2p_data: u32 = ((local_sessid as u32) << 16) | a.initiator_ssid as u32;
spake2p.set_app_data(spake2p_data);

View file

@ -25,6 +25,9 @@ use embassy_time::{Duration, Timer};
use log::{error, info, warn};
use crate::interaction_model::core::IMStatusCode;
use crate::secure_channel::common::SCStatusCodes;
use crate::secure_channel::status_report::{create_status_report, GeneralCode};
use crate::utils::select::Notification;
use crate::CommissioningData;
use crate::{
@ -41,6 +44,7 @@ use crate::{
Matter,
};
use super::exchange::SessionId;
use super::{
exchange::{
Exchange, ExchangeCtr, ExchangeCtx, ExchangeId, ExchangeState, Role, MAX_EXCHANGES,
@ -97,7 +101,7 @@ impl RunBuffers {
pub struct PacketBuffers {
tx: [TxBuf; MAX_EXCHANGES],
rx: [RxBuf; MAX_EXCHANGES],
sx: [SxBuf; MAX_EXCHANGES],
sx: [SxBuf; MAX_EXCHANGES + 1],
}
impl PacketBuffers {
@ -107,7 +111,7 @@ impl PacketBuffers {
const TX_INIT: [TxBuf; MAX_EXCHANGES] = [Self::TX_ELEM; MAX_EXCHANGES];
const RX_INIT: [RxBuf; MAX_EXCHANGES] = [Self::RX_ELEM; MAX_EXCHANGES];
const SX_INIT: [SxBuf; MAX_EXCHANGES] = [Self::SX_ELEM; MAX_EXCHANGES];
const SX_INIT: [SxBuf; MAX_EXCHANGES + 1] = [Self::SX_ELEM; MAX_EXCHANGES + 1];
#[inline(always)]
pub const fn new() -> Self {
@ -266,7 +270,12 @@ impl<'a> Matter<'a> {
.unwrap();
}
let mut rx = pin!(self.handle_rx_multiplex(rx_pipe, construction_notification, &channel));
let mut rx = pin!(self.handle_rx_multiplex(
rx_pipe,
unsafe { buffers.sx[MAX_EXCHANGES].assume_init_mut() },
construction_notification,
&channel,
));
let result = select(&mut rx, select_slice(&mut handlers)).await;
@ -291,7 +300,7 @@ impl<'a> Matter<'a> {
if data.chunk.is_none() {
let mut tx = alloc!(Packet::new_tx(data.buf));
if self.pull_tx(&mut tx).await? {
if self.pull_tx(&mut tx)? {
data.chunk = Some(Chunk {
start: tx.get_writebuf()?.get_start(),
end: tx.get_writebuf()?.get_tail(),
@ -315,12 +324,15 @@ impl<'a> Matter<'a> {
pub async fn handle_rx_multiplex<'t, 'e, const N: usize>(
&'t self,
rx_pipe: &Pipe<'_>,
sts_buf: &mut [u8; MAX_RX_STATUS_BUF_SIZE],
construction_notification: &'e Notification,
channel: &Channel<NoopRawMutex, ExchangeCtr<'e>, N>,
) -> Result<(), Error>
where
't: 'e,
{
let mut sts_tx = alloc!(Packet::new_tx(sts_buf));
loop {
info!("Transport: waiting for incoming packets");
@ -331,8 +343,9 @@ impl<'a> Matter<'a> {
let mut rx = alloc!(Packet::new_rx(&mut data.buf[chunk.start..chunk.end]));
rx.peer = chunk.addr;
if let Some(exchange_ctr) =
self.process_rx(construction_notification, &mut rx)?
if let Some(exchange_ctr) = self
.process_rx(construction_notification, &mut rx, &mut sts_tx)
.await?
{
let exchange_id = exchange_ctr.id().clone();
@ -444,24 +457,39 @@ impl<'a> Matter<'a> {
self.session_mgr.borrow_mut().reset();
}
pub fn process_rx<'r>(
pub async fn process_rx<'r>(
&'r self,
construction_notification: &'r Notification,
src_rx: &mut Packet<'_>,
sts_tx: &mut Packet<'_>,
) -> Result<Option<ExchangeCtr<'r>>, Error> {
src_rx.plain_hdr_decode()?;
self.purge()?;
let (exchange_index, new) = loop {
let result = self.assign_exchange(&mut self.exchanges.borrow_mut(), src_rx);
match result {
Err(e) => match e.code() {
ErrorCode::Duplicate => {
self.send_notification.signal(());
return Ok(None);
}
// TODO: NoSession, NoExchange and others
ErrorCode::NoSpaceSessions => self.evict_session(sts_tx).await?,
ErrorCode::NoSpaceExchanges => {
self.send_busy(src_rx, sts_tx).await?;
return Ok(None);
}
_ => break Err(e),
},
other => break other,
}
}?;
let mut exchanges = self.exchanges.borrow_mut();
let (ctx, new) = match self.post_recv(&mut exchanges, src_rx) {
Ok((ctx, new)) => (ctx, new),
Err(e) => match e.code() {
ErrorCode::Duplicate => {
self.send_notification.signal(());
return Ok(None);
}
_ => Err(e)?,
},
};
let ctx = &mut exchanges[exchange_index];
src_rx.log("Got packet");
@ -516,6 +544,8 @@ impl<'a> Matter<'a> {
ExchangeState::ExchangeRecv {
rx, notification, ..
} => {
// TODO: Handle Busy status codes
let rx = unsafe { rx.as_mut() }.unwrap();
rx.load(src_rx)?;
@ -572,12 +602,24 @@ impl<'a> Matter<'a> {
Ok(())
}
pub async fn pull_tx(&self, dest_tx: &mut Packet<'_>) -> Result<bool, Error> {
pub fn pull_tx(&self, dest_tx: &mut Packet) -> Result<bool, Error> {
self.purge()?;
let mut ephemeral = self.ephemeral.borrow_mut();
let mut exchanges = self.exchanges.borrow_mut();
let ctx = exchanges.iter_mut().find(|ctx| {
self.pull_tx_exchanges(ephemeral.iter_mut().chain(exchanges.iter_mut()), dest_tx)
}
fn pull_tx_exchanges<'i, I>(
&self,
mut exchanges: I,
dest_tx: &mut Packet,
) -> Result<bool, Error>
where
I: Iterator<Item = &'i mut ExchangeCtx>,
{
let ctx = exchanges.find(|ctx| {
matches!(
&ctx.state,
ExchangeState::Acknowledge { .. }
@ -629,10 +671,15 @@ impl<'a> Matter<'a> {
let tx = unsafe { tx.as_ref() }.unwrap();
dest_tx.load(tx)?;
*state = ExchangeState::CompleteAcknowledge {
_tx: tx as *const _,
notification: *notification,
};
if dest_tx.is_reliable() {
*state = ExchangeState::CompleteAcknowledge {
_tx: tx as *const _,
notification: *notification,
};
} else {
unsafe { notification.as_ref() }.unwrap().signal(());
ctx.state = ExchangeState::Closed;
}
true
}
@ -648,8 +695,6 @@ impl<'a> Matter<'a> {
if send {
dest_tx.log("Sending packet");
self.pre_send(ctx, dest_tx)?;
self.notify_changed();
return Ok(true);
@ -675,13 +720,88 @@ impl<'a> Matter<'a> {
Ok(())
}
fn post_recv<'r>(
&self,
exchanges: &'r mut heapless::Vec<ExchangeCtx, MAX_EXCHANGES>,
rx: &mut Packet<'_>,
) -> Result<(&'r mut ExchangeCtx, bool), Error> {
rx.plain_hdr_decode()?;
pub(crate) async fn evict_session(&self, tx: &mut Packet<'_>) -> Result<(), Error> {
let sess_index = self.session_mgr.borrow().get_session_for_eviction();
if let Some(sess_index) = sess_index {
let ctx = {
create_status_report(
tx,
GeneralCode::Success,
PROTO_ID_SECURE_CHANNEL as _,
SCStatusCodes::CloseSession as _,
None,
)?;
let mut session_mgr = self.session_mgr.borrow_mut();
let session_id = session_mgr.mut_by_index(sess_index).unwrap().id();
warn!("Evicting session: {:?}", session_id);
let ctx = ExchangeCtx::prep_ephemeral(session_id, &mut session_mgr, None, tx)?;
session_mgr.remove(sess_index);
ctx
};
self.send_ephemeral(ctx, tx).await
} else {
Err(ErrorCode::NoSpaceSessions.into())
}
}
async fn send_busy(&self, rx: &Packet<'_>, tx: &mut Packet<'_>) -> Result<(), Error> {
warn!("Sending Busy as all exchanges are occupied");
create_status_report(
tx,
GeneralCode::Busy,
rx.get_proto_id() as _,
if rx.get_proto_id() == PROTO_ID_SECURE_CHANNEL {
SCStatusCodes::Busy as _
} else {
IMStatusCode::Busy as _
},
None, // TODO: ms
)?;
let ctx = ExchangeCtx::prep_ephemeral(
SessionId::load(rx),
&mut self.session_mgr.borrow_mut(),
Some(rx),
tx,
)?;
self.send_ephemeral(ctx, tx).await
}
async fn send_ephemeral(&self, mut ctx: ExchangeCtx, tx: &mut Packet<'_>) -> Result<(), Error> {
let _guard = self.ephemeral_mutex.lock().await;
let notification = Notification::new();
let tx: &'static mut Packet<'static> = unsafe { core::mem::transmute(tx) };
ctx.state = ExchangeState::Complete {
tx,
notification: &notification,
};
*self.ephemeral.borrow_mut() = Some(ctx);
self.send_notification.signal(());
notification.wait().await;
*self.ephemeral.borrow_mut() = None;
Ok(())
}
fn assign_exchange(
&self,
exchanges: &mut heapless::Vec<ExchangeCtx, MAX_EXCHANGES>,
rx: &mut Packet<'_>,
) -> Result<(usize, bool), Error> {
// Get the session
let mut session_mgr = self.session_mgr.borrow_mut();
@ -693,8 +813,7 @@ impl<'a> Matter<'a> {
session.recv(self.epoch, rx)?;
// Get the exchange
// TODO: Handle out of space
let (exch, new) = Self::register(
let (exchange_index, new) = Self::register(
exchanges,
ExchangeId::load(rx),
Role::complementary(rx.proto.is_initiator()),
@ -703,32 +822,9 @@ impl<'a> Matter<'a> {
)?;
// Message Reliability Protocol
exch.mrp.recv(rx, self.epoch)?;
exchanges[exchange_index].mrp.recv(rx, self.epoch)?;
Ok((exch, new))
}
fn pre_send(&self, ctx: &mut ExchangeCtx, tx: &mut Packet) -> Result<(), Error> {
let mut session_mgr = self.session_mgr.borrow_mut();
let sess_index = session_mgr
.get(
ctx.id.session_id.id,
ctx.id.session_id.peer_addr,
ctx.id.session_id.peer_nodeid,
ctx.id.session_id.is_encrypted,
)
.ok_or(ErrorCode::NoSession)?;
let session = session_mgr.mut_by_index(sess_index).unwrap();
tx.proto.exch_id = ctx.id.id;
if ctx.role == Role::Initiator {
tx.proto.set_initiator();
}
session.pre_send(tx)?;
ctx.mrp.pre_send(tx)?;
session_mgr.send(sess_index, tx)
Ok((exchange_index, new))
}
fn register(
@ -736,7 +832,7 @@ impl<'a> Matter<'a> {
id: ExchangeId,
role: Role,
create_new: bool,
) -> Result<(&mut ExchangeCtx, bool), Error> {
) -> Result<(usize, bool), Error> {
let exchange_index = exchanges
.iter_mut()
.enumerate()
@ -745,7 +841,7 @@ impl<'a> Matter<'a> {
if let Some(exchange_index) = exchange_index {
let exchange = &mut exchanges[exchange_index];
if exchange.role == role {
Ok((exchange, false))
Ok((exchange_index, false))
} else {
Err(ErrorCode::NoExchange.into())
}
@ -759,9 +855,11 @@ impl<'a> Matter<'a> {
state: ExchangeState::Active,
};
exchanges.push(exchange).map_err(|_| ErrorCode::NoSpace)?;
exchanges
.push(exchange)
.map_err(|_| ErrorCode::NoSpaceExchanges)?;
Ok((exchanges.iter_mut().next_back().unwrap(), true))
Ok((exchanges.len() - 1, true))
} else {
Err(ErrorCode::NoExchange.into())
}

View file

@ -1,7 +1,7 @@
use crate::{
acl::Accessor,
error::{Error, ErrorCode},
utils::select::Notification,
utils::{epoch::Epoch, select::Notification},
Matter,
};
@ -9,7 +9,7 @@ use super::{
mrp::ReliableMessage,
network::Address,
packet::Packet,
session::{Session, SessionMgr},
session::{CloneData, Session, SessionMgr},
};
pub const MAX_EXCHANGES: usize = 8;
@ -46,6 +46,101 @@ impl ExchangeCtx {
) -> Option<&'r mut ExchangeCtx> {
exchanges.iter_mut().find(|exchange| exchange.id == *id)
}
pub fn new_ephemeral(session_id: SessionId, reply_to: Option<&Packet<'_>>) -> Self {
Self {
id: ExchangeId {
id: if let Some(rx) = reply_to {
rx.proto.exch_id
} else {
0
},
session_id: session_id.clone(),
},
role: if reply_to.is_some() {
Role::Responder
} else {
Role::Initiator
},
mrp: ReliableMessage::new(),
state: ExchangeState::Active,
}
}
pub(crate) fn prep_ephemeral(
session_id: SessionId,
session_mgr: &mut SessionMgr,
reply_to: Option<&Packet<'_>>,
tx: &mut Packet<'_>,
) -> Result<ExchangeCtx, Error> {
let mut ctx = Self::new_ephemeral(session_id.clone(), reply_to);
let sess_index = session_mgr.get(
session_id.id,
session_id.peer_addr,
session_id.peer_nodeid,
session_id.is_encrypted,
);
let epoch = session_mgr.epoch;
let rand = session_mgr.rand;
if let Some(rx) = reply_to {
ctx.mrp.recv(rx, epoch)?;
} else {
tx.proto.set_initiator();
}
tx.unset_reliable();
if let Some(sess_index) = sess_index {
let session = session_mgr.mut_by_index(sess_index).unwrap();
ctx.pre_send_sess(session, tx, epoch)?;
} else {
let mut session =
Session::new(session_id.peer_addr, session_id.peer_nodeid, epoch, rand);
ctx.pre_send_sess(&mut session, tx, epoch)?;
}
Ok(ctx)
}
pub(crate) fn pre_send(
&mut self,
session_mgr: &mut SessionMgr,
tx: &mut Packet,
) -> Result<(), Error> {
let epoch = session_mgr.epoch;
let sess_index = session_mgr
.get(
self.id.session_id.id,
self.id.session_id.peer_addr,
self.id.session_id.peer_nodeid,
self.id.session_id.is_encrypted,
)
.ok_or(ErrorCode::NoSession)?;
let session = session_mgr.mut_by_index(sess_index).unwrap();
self.pre_send_sess(session, tx, epoch)
}
pub(crate) fn pre_send_sess(
&mut self,
session: &mut Session,
tx: &mut Packet,
epoch: Epoch,
) -> Result<(), Error> {
tx.proto.exch_id = self.id.id;
if self.role == Role::Initiator {
tx.proto.set_initiator();
}
session.pre_send(tx)?;
self.mrp.pre_send(tx)?;
session.send(epoch, tx)
}
}
#[derive(Debug, Clone)]
@ -192,15 +287,6 @@ impl<'a> Exchange<'a> {
self.with_session_mut(|sess| f(sess))
}
pub fn with_session_mgr_mut<F, T>(&self, f: F) -> Result<T, Error>
where
F: FnOnce(&mut SessionMgr) -> Result<T, Error>,
{
let mut session_mgr = self.matter.session_mgr.borrow_mut();
f(&mut session_mgr)
}
pub async fn acknowledge(&mut self) -> Result<(), Error> {
let wait = self.with_ctx_mut(|_self, ctx| {
if !matches!(ctx.state, ExchangeState::Active) {
@ -226,8 +312,12 @@ impl<'a> Exchange<'a> {
Ok(())
}
pub async fn exchange(&mut self, tx: &Packet<'_>, rx: &mut Packet<'_>) -> Result<(), Error> {
let tx: &Packet<'static> = unsafe { core::mem::transmute(tx) };
pub async fn exchange(
&mut self,
tx: &mut Packet<'_>,
rx: &mut Packet<'_>,
) -> Result<(), Error> {
let tx: &mut Packet<'static> = unsafe { core::mem::transmute(tx) };
let rx: &mut Packet<'static> = unsafe { core::mem::transmute(rx) };
self.with_ctx_mut(|_self, ctx| {
@ -235,6 +325,9 @@ impl<'a> Exchange<'a> {
Err(ErrorCode::NoExchange)?;
}
let mut session_mgr = _self.matter.session_mgr.borrow_mut();
ctx.pre_send(&mut session_mgr, tx)?;
ctx.state = ExchangeState::ExchangeSend {
tx: tx as *const _,
rx: rx as *mut _,
@ -250,18 +343,21 @@ impl<'a> Exchange<'a> {
Ok(())
}
pub async fn complete(mut self, tx: &Packet<'_>) -> Result<(), Error> {
pub async fn complete(mut self, tx: &mut Packet<'_>) -> Result<(), Error> {
self.send_complete(tx).await
}
pub async fn send_complete(&mut self, tx: &Packet<'_>) -> Result<(), Error> {
let tx: &Packet<'static> = unsafe { core::mem::transmute(tx) };
pub async fn send_complete(&mut self, tx: &mut Packet<'_>) -> Result<(), Error> {
let tx: &mut Packet<'static> = unsafe { core::mem::transmute(tx) };
self.with_ctx_mut(|_self, ctx| {
if !matches!(ctx.state, ExchangeState::Active) {
Err(ErrorCode::NoExchange)?;
}
let mut session_mgr = _self.matter.session_mgr.borrow_mut();
ctx.pre_send(&mut session_mgr, tx)?;
ctx.state = ExchangeState::Complete {
tx: tx as *const _,
notification: &_self.notification as *const _,
@ -276,6 +372,31 @@ impl<'a> Exchange<'a> {
Ok(())
}
pub(crate) fn get_next_sess_id(&mut self) -> u16 {
self.matter.session_mgr.borrow_mut().get_next_sess_id()
}
pub(crate) async fn clone_session(
&mut self,
tx: &mut Packet<'_>,
clone_data: &CloneData,
) -> Result<usize, Error> {
loop {
let result = self
.matter
.session_mgr
.borrow_mut()
.clone_session(clone_data);
match result {
Err(err) if err.code() == ErrorCode::NoSpaceSessions => {
self.matter.evict_session(tx).await?
}
other => break other,
}
}
}
fn with_ctx<F, T>(&self, f: F) -> Result<T, Error>
where
F: FnOnce(&Self, &ExchangeCtx) -> Result<T, Error>,

View file

@ -19,13 +19,13 @@ use crate::data_model::sdm::noc::NocData;
use crate::utils::epoch::Epoch;
use crate::utils::rand::Rand;
use core::fmt;
use core::ops::{Deref, DerefMut};
use core::time::Duration;
use crate::{error::*, transport::plain_hdr};
use log::info;
use super::dedup::RxCtrState;
use super::exchange::SessionId;
use super::{network::Address, packet::Packet};
pub const MAX_CAT_IDS_PER_NOC: usize = 3;
@ -151,6 +151,15 @@ impl Session {
}
}
pub fn id(&self) -> SessionId {
SessionId {
id: self.local_sess_id,
peer_addr: self.peer_addr,
peer_nodeid: self.peer_nodeid,
is_encrypted: self.is_encrypted(),
}
}
pub fn set_noc_data(&mut self, data: NocData) {
self.data = Some(data);
}
@ -251,7 +260,7 @@ impl Session {
Ok(())
}
fn send(&mut self, epoch: Epoch, tx: &mut Packet) -> Result<(), Error> {
pub(crate) fn send(&mut self, epoch: Epoch, tx: &mut Packet) -> Result<(), Error> {
self.last_use = epoch();
tx.proto_encode(
@ -291,8 +300,8 @@ pub const MAX_SESSIONS: usize = 16;
pub struct SessionMgr {
next_sess_id: u16,
sessions: heapless::Vec<Option<Session>, MAX_SESSIONS>,
epoch: Epoch,
rand: Rand,
pub(crate) epoch: Epoch,
pub(crate) rand: Rand,
}
impl SessionMgr {
@ -327,7 +336,11 @@ impl SessionMgr {
}
// Ensure the currently selected id doesn't match any existing session
if self.get_with_id(next_sess_id).is_none() {
if self.sessions.iter().all(|sess| {
sess.as_ref()
.map(|sess| sess.get_local_sess_id() != next_sess_id)
.unwrap_or(true)
}) {
break;
}
}
@ -381,12 +394,12 @@ impl SessionMgr {
} else if self.sessions.len() < MAX_SESSIONS {
self.sessions
.push(Some(session))
.map_err(|_| ErrorCode::NoSpace)
.map_err(|_| ErrorCode::NoSpaceSessions)
.unwrap();
Ok(self.sessions.len() - 1)
} else {
Err(ErrorCode::NoSpace.into())
Err(ErrorCode::NoSpaceSessions.into())
}
}
@ -419,14 +432,6 @@ impl SessionMgr {
})
}
pub fn get_with_id(&mut self, sess_id: u16) -> Option<SessionHandle> {
let index = self
.sessions
.iter_mut()
.position(|x| x.as_ref().map(|s| s.local_sess_id) == Some(sess_id))?;
Some(self.get_session_handle(index))
}
pub fn get_or_add(
&mut self,
sess_id: u16,
@ -472,13 +477,6 @@ impl SessionMgr {
.ok_or(ErrorCode::NoSession)?
.send(self.epoch, tx)
}
pub fn get_session_handle(&mut self, sess_idx: usize) -> SessionHandle {
SessionHandle {
sess_mgr: self,
sess_idx,
}
}
}
impl fmt::Display for SessionMgr {
@ -492,45 +490,6 @@ impl fmt::Display for SessionMgr {
}
}
pub struct SessionHandle<'a> {
pub(crate) sess_mgr: &'a mut SessionMgr,
sess_idx: usize,
}
impl<'a> SessionHandle<'a> {
pub fn session(&self) -> &Session {
self.sess_mgr.sessions[self.sess_idx].as_ref().unwrap()
}
pub fn session_mut(&mut self) -> &mut Session {
self.sess_mgr.sessions[self.sess_idx].as_mut().unwrap()
}
pub fn reserve_new_sess_id(&mut self) -> u16 {
self.sess_mgr.get_next_sess_id()
}
pub fn send(&mut self, tx: &mut Packet) -> Result<(), Error> {
self.sess_mgr.send(self.sess_idx, tx)
}
}
impl<'a> Deref for SessionHandle<'a> {
type Target = Session;
fn deref(&self) -> &Self::Target {
// There is no other option but to panic if this is None
self.session()
}
}
impl<'a> DerefMut for SessionHandle<'a> {
fn deref_mut(&mut self) -> &mut Self::Target {
// There is no other option but to panic if this is None
self.session_mut()
}
}
#[cfg(test)]
mod tests {
@ -545,12 +504,12 @@ mod tests {
fn test_next_sess_id_doesnt_reuse() {
let mut sm = SessionMgr::new(dummy_epoch, dummy_rand);
let sess_idx = sm.add(Address::default(), None).unwrap();
let mut sess = sm.get_session_handle(sess_idx);
let sess = sm.mut_by_index(sess_idx).unwrap();
sess.set_local_sess_id(1);
assert_eq!(sm.get_next_sess_id(), 2);
assert_eq!(sm.get_next_sess_id(), 3);
let sess_idx = sm.add(Address::default(), None).unwrap();
let mut sess = sm.get_session_handle(sess_idx);
let sess = sm.mut_by_index(sess_idx).unwrap();
sess.set_local_sess_id(4);
assert_eq!(sm.get_next_sess_id(), 5);
}
@ -559,7 +518,7 @@ mod tests {
fn test_next_sess_id_overflows() {
let mut sm = SessionMgr::new(dummy_epoch, dummy_rand);
let sess_idx = sm.add(Address::default(), None).unwrap();
let mut sess = sm.get_session_handle(sess_idx);
let sess = sm.mut_by_index(sess_idx).unwrap();
sess.set_local_sess_id(1);
assert_eq!(sm.get_next_sess_id(), 2);
sm.next_sess_id = 65534;