Handle out of sessions and out of exchanges

This commit is contained in:
ivmarkov 2023-07-31 17:34:21 +00:00
parent 4c347c0c0b
commit e171e33510
8 changed files with 360 additions and 171 deletions

View file

@ -17,6 +17,8 @@
use core::{borrow::Borrow, cell::RefCell}; use core::{borrow::Borrow, cell::RefCell};
use embassy_sync::{blocking_mutex::raw::NoopRawMutex, mutex::Mutex};
use crate::{ use crate::{
acl::AclMgr, acl::AclMgr,
data_model::{ data_model::{
@ -61,6 +63,8 @@ pub struct Matter<'a> {
dev_att: &'a dyn DevAttDataFetcher, dev_att: &'a dyn DevAttDataFetcher,
pub(crate) port: u16, pub(crate) port: u16,
pub(crate) exchanges: RefCell<heapless::Vec<ExchangeCtx, MAX_EXCHANGES>>, pub(crate) exchanges: RefCell<heapless::Vec<ExchangeCtx, MAX_EXCHANGES>>,
pub(crate) ephemeral: RefCell<Option<ExchangeCtx>>,
pub(crate) ephemeral_mutex: Mutex<NoopRawMutex, ()>,
pub session_mgr: RefCell<SessionMgr>, // Public for tests pub session_mgr: RefCell<SessionMgr>, // Public for tests
} }
@ -108,6 +112,8 @@ impl<'a> Matter<'a> {
dev_att, dev_att,
port, port,
exchanges: RefCell::new(heapless::Vec::new()), exchanges: RefCell::new(heapless::Vec::new()),
ephemeral: RefCell::new(None),
ephemeral_mutex: Mutex::new(()),
session_mgr: RefCell::new(SessionMgr::new(epoch, rand)), session_mgr: RefCell::new(SessionMgr::new(epoch, rand)),
} }
} }

View file

@ -47,6 +47,8 @@ pub enum ErrorCode {
NoMemory, NoMemory,
NoSession, NoSession,
NoSpace, NoSpace,
NoSpaceExchanges,
NoSpaceSessions,
NoSpaceAckTable, NoSpaceAckTable,
NoSpaceRetransTable, NoSpaceRetransTable,
NoTagFound, NoTagFound,

View file

@ -96,7 +96,7 @@ impl<'a> Case<'a> {
) -> Result<(), Error> { ) -> Result<(), Error> {
rx.check_proto_opcode(OpCode::CASESigma3 as _)?; rx.check_proto_opcode(OpCode::CASESigma3 as _)?;
let status = { let result = {
let fabric_mgr = self.fabric_mgr.borrow(); let fabric_mgr = self.fabric_mgr.borrow();
let fabric = fabric_mgr.get_fabric(case_session.local_fabric_idx)?; let fabric = fabric_mgr.get_fabric(case_session.local_fabric_idx)?;
@ -133,7 +133,7 @@ impl<'a> Case<'a> {
if let Err(e) = Case::validate_certs(fabric, &initiator_noc, initiator_icac_mut) { if let Err(e) = Case::validate_certs(fabric, &initiator_noc, initiator_icac_mut) {
error!("Certificate Chain doesn't match: {}", e); error!("Certificate Chain doesn't match: {}", e);
SCStatusCodes::InvalidParameter Err(SCStatusCodes::InvalidParameter)
} else if let Err(e) = Case::validate_sigma3_sign( } else if let Err(e) = Case::validate_sigma3_sign(
d.initiator_noc.0, d.initiator_noc.0,
d.initiator_icac.map(|a| a.0), d.initiator_icac.map(|a| a.0),
@ -142,32 +142,35 @@ impl<'a> Case<'a> {
case_session, case_session,
) { ) {
error!("Sigma3 Signature doesn't match: {}", e); error!("Sigma3 Signature doesn't match: {}", e);
SCStatusCodes::InvalidParameter Err(SCStatusCodes::InvalidParameter)
} else { } else {
// Only now do we add this message to the TT Hash // Only now do we add this message to the TT Hash
let mut peer_catids: NocCatIds = Default::default(); let mut peer_catids: NocCatIds = Default::default();
initiator_noc.get_cat_ids(&mut peer_catids); initiator_noc.get_cat_ids(&mut peer_catids);
case_session.tt_hash.update(rx.as_slice())?; case_session.tt_hash.update(rx.as_slice())?;
let clone_data = Case::get_session_clone_data(
Ok(Case::get_session_clone_data(
fabric.ipk.op_key(), fabric.ipk.op_key(),
fabric.get_node_id(), fabric.get_node_id(),
initiator_noc.get_node_id()?, initiator_noc.get_node_id()?,
exchange.with_session(|sess| Ok(sess.get_peer_addr()))?, exchange.with_session(|sess| Ok(sess.get_peer_addr()))?,
case_session, case_session,
&peer_catids, &peer_catids,
)?; )?)
// TODO: Handle NoSpace
exchange
.with_session_mgr_mut(|sess_mgr| sess_mgr.clone_session(&clone_data))?;
SCStatusCodes::SessionEstablishmentSuccess
} }
} else { } else {
SCStatusCodes::NoSharedTrustRoots Err(SCStatusCodes::NoSharedTrustRoots)
} }
}; };
let status = match result {
Ok(clone_data) => {
exchange.clone_session(tx, &clone_data).await?;
SCStatusCodes::SessionEstablishmentSuccess
}
Err(status) => status,
};
complete_with_status(exchange, tx, status, None).await complete_with_status(exchange, tx, status, None).await
} }
@ -201,7 +204,7 @@ impl<'a> Case<'a> {
return Ok(()); return Ok(());
} }
let local_sessid = exchange.with_session_mgr_mut(|mgr| Ok(mgr.get_next_sess_id()))?; let local_sessid = exchange.get_next_sess_id();
case_session.peer_sessid = r.initiator_sessid; case_session.peer_sessid = r.initiator_sessid;
case_session.local_sessid = local_sessid; case_session.local_sessid = local_sessid;
case_session.tt_hash.update(rx_buf)?; case_session.tt_hash.update(rx_buf)?;

View file

@ -78,8 +78,8 @@ pub fn create_sc_status_report(
// the session will be closed soon // the session will be closed soon
GeneralCode::Success GeneralCode::Success
} }
SCStatusCodes::Busy SCStatusCodes::Busy => GeneralCode::Busy,
| SCStatusCodes::InvalidParameter SCStatusCodes::InvalidParameter
| SCStatusCodes::NoSharedTrustRoots | SCStatusCodes::NoSharedTrustRoots
| SCStatusCodes::SessionNotFound => GeneralCode::Failure, | SCStatusCodes::SessionNotFound => GeneralCode::Failure,
}; };

View file

@ -167,9 +167,9 @@ impl<'a> Pake<'a> {
self.update_timeout(exchange, tx, true).await?; self.update_timeout(exchange, tx, true).await?;
let cA = extract_pasepake_1_or_3_params(rx.as_slice())?; let cA = extract_pasepake_1_or_3_params(rx.as_slice())?;
let (status_code, ke) = spake2p.handle_cA(cA); let (status, ke) = spake2p.handle_cA(cA);
let clone_data = if status_code == SCStatusCodes::SessionEstablishmentSuccess { let result = if status == SCStatusCodes::SessionEstablishmentSuccess {
// Get the keys // Get the keys
let ke = ke.ok_or(ErrorCode::Invalid)?; let ke = ke.ok_or(ErrorCode::Invalid)?;
let mut session_keys: [u8; 48] = [0; 48]; let mut session_keys: [u8; 48] = [0; 48];
@ -194,22 +194,22 @@ impl<'a> Pake<'a> {
.att_challenge .att_challenge
.copy_from_slice(&session_keys[32..48]); .copy_from_slice(&session_keys[32..48]);
// Queue a transport mgr request to add a new session Ok(clone_data)
Some(clone_data)
} else { } else {
None Err(status)
}; };
if let Some(clone_data) = clone_data { let status = match result {
// TODO: Handle NoSpace Ok(clone_data) => {
exchange.with_session_mgr_mut(|sess_mgr| sess_mgr.clone_session(&clone_data))?; exchange.clone_session(tx, &clone_data).await?;
self.pase.borrow_mut().disable_pase_session(mdns)?;
self.pase.borrow_mut().disable_pase_session(mdns)?; SCStatusCodes::SessionEstablishmentSuccess
} }
Err(status) => status,
};
complete_with_status(exchange, tx, status_code, None).await?; complete_with_status(exchange, tx, status, None).await
Ok(())
} }
#[allow(non_snake_case)] #[allow(non_snake_case)]
@ -273,7 +273,7 @@ impl<'a> Pake<'a> {
let mut our_random: [u8; 32] = [0; 32]; let mut our_random: [u8; 32] = [0; 32];
(self.pase.borrow().rand)(&mut our_random); (self.pase.borrow().rand)(&mut our_random);
let local_sessid = exchange.with_session_mgr_mut(|mgr| Ok(mgr.get_next_sess_id()))?; let local_sessid = exchange.get_next_sess_id();
let spake2p_data: u32 = ((local_sessid as u32) << 16) | a.initiator_ssid as u32; let spake2p_data: u32 = ((local_sessid as u32) << 16) | a.initiator_ssid as u32;
spake2p.set_app_data(spake2p_data); spake2p.set_app_data(spake2p_data);

View file

@ -25,6 +25,9 @@ use embassy_time::{Duration, Timer};
use log::{error, info, warn}; use log::{error, info, warn};
use crate::interaction_model::core::IMStatusCode;
use crate::secure_channel::common::SCStatusCodes;
use crate::secure_channel::status_report::{create_status_report, GeneralCode};
use crate::utils::select::Notification; use crate::utils::select::Notification;
use crate::CommissioningData; use crate::CommissioningData;
use crate::{ use crate::{
@ -41,6 +44,7 @@ use crate::{
Matter, Matter,
}; };
use super::exchange::SessionId;
use super::{ use super::{
exchange::{ exchange::{
Exchange, ExchangeCtr, ExchangeCtx, ExchangeId, ExchangeState, Role, MAX_EXCHANGES, Exchange, ExchangeCtr, ExchangeCtx, ExchangeId, ExchangeState, Role, MAX_EXCHANGES,
@ -97,7 +101,7 @@ impl RunBuffers {
pub struct PacketBuffers { pub struct PacketBuffers {
tx: [TxBuf; MAX_EXCHANGES], tx: [TxBuf; MAX_EXCHANGES],
rx: [RxBuf; MAX_EXCHANGES], rx: [RxBuf; MAX_EXCHANGES],
sx: [SxBuf; MAX_EXCHANGES], sx: [SxBuf; MAX_EXCHANGES + 1],
} }
impl PacketBuffers { impl PacketBuffers {
@ -107,7 +111,7 @@ impl PacketBuffers {
const TX_INIT: [TxBuf; MAX_EXCHANGES] = [Self::TX_ELEM; MAX_EXCHANGES]; const TX_INIT: [TxBuf; MAX_EXCHANGES] = [Self::TX_ELEM; MAX_EXCHANGES];
const RX_INIT: [RxBuf; MAX_EXCHANGES] = [Self::RX_ELEM; MAX_EXCHANGES]; const RX_INIT: [RxBuf; MAX_EXCHANGES] = [Self::RX_ELEM; MAX_EXCHANGES];
const SX_INIT: [SxBuf; MAX_EXCHANGES] = [Self::SX_ELEM; MAX_EXCHANGES]; const SX_INIT: [SxBuf; MAX_EXCHANGES + 1] = [Self::SX_ELEM; MAX_EXCHANGES + 1];
#[inline(always)] #[inline(always)]
pub const fn new() -> Self { pub const fn new() -> Self {
@ -266,7 +270,12 @@ impl<'a> Matter<'a> {
.unwrap(); .unwrap();
} }
let mut rx = pin!(self.handle_rx_multiplex(rx_pipe, construction_notification, &channel)); let mut rx = pin!(self.handle_rx_multiplex(
rx_pipe,
unsafe { buffers.sx[MAX_EXCHANGES].assume_init_mut() },
construction_notification,
&channel,
));
let result = select(&mut rx, select_slice(&mut handlers)).await; let result = select(&mut rx, select_slice(&mut handlers)).await;
@ -291,7 +300,7 @@ impl<'a> Matter<'a> {
if data.chunk.is_none() { if data.chunk.is_none() {
let mut tx = alloc!(Packet::new_tx(data.buf)); let mut tx = alloc!(Packet::new_tx(data.buf));
if self.pull_tx(&mut tx).await? { if self.pull_tx(&mut tx)? {
data.chunk = Some(Chunk { data.chunk = Some(Chunk {
start: tx.get_writebuf()?.get_start(), start: tx.get_writebuf()?.get_start(),
end: tx.get_writebuf()?.get_tail(), end: tx.get_writebuf()?.get_tail(),
@ -315,12 +324,15 @@ impl<'a> Matter<'a> {
pub async fn handle_rx_multiplex<'t, 'e, const N: usize>( pub async fn handle_rx_multiplex<'t, 'e, const N: usize>(
&'t self, &'t self,
rx_pipe: &Pipe<'_>, rx_pipe: &Pipe<'_>,
sts_buf: &mut [u8; MAX_RX_STATUS_BUF_SIZE],
construction_notification: &'e Notification, construction_notification: &'e Notification,
channel: &Channel<NoopRawMutex, ExchangeCtr<'e>, N>, channel: &Channel<NoopRawMutex, ExchangeCtr<'e>, N>,
) -> Result<(), Error> ) -> Result<(), Error>
where where
't: 'e, 't: 'e,
{ {
let mut sts_tx = alloc!(Packet::new_tx(sts_buf));
loop { loop {
info!("Transport: waiting for incoming packets"); info!("Transport: waiting for incoming packets");
@ -331,8 +343,9 @@ impl<'a> Matter<'a> {
let mut rx = alloc!(Packet::new_rx(&mut data.buf[chunk.start..chunk.end])); let mut rx = alloc!(Packet::new_rx(&mut data.buf[chunk.start..chunk.end]));
rx.peer = chunk.addr; rx.peer = chunk.addr;
if let Some(exchange_ctr) = if let Some(exchange_ctr) = self
self.process_rx(construction_notification, &mut rx)? .process_rx(construction_notification, &mut rx, &mut sts_tx)
.await?
{ {
let exchange_id = exchange_ctr.id().clone(); let exchange_id = exchange_ctr.id().clone();
@ -444,24 +457,39 @@ impl<'a> Matter<'a> {
self.session_mgr.borrow_mut().reset(); self.session_mgr.borrow_mut().reset();
} }
pub fn process_rx<'r>( pub async fn process_rx<'r>(
&'r self, &'r self,
construction_notification: &'r Notification, construction_notification: &'r Notification,
src_rx: &mut Packet<'_>, src_rx: &mut Packet<'_>,
sts_tx: &mut Packet<'_>,
) -> Result<Option<ExchangeCtr<'r>>, Error> { ) -> Result<Option<ExchangeCtr<'r>>, Error> {
src_rx.plain_hdr_decode()?;
self.purge()?; self.purge()?;
let (exchange_index, new) = loop {
let result = self.assign_exchange(&mut self.exchanges.borrow_mut(), src_rx);
match result {
Err(e) => match e.code() {
ErrorCode::Duplicate => {
self.send_notification.signal(());
return Ok(None);
}
// TODO: NoSession, NoExchange and others
ErrorCode::NoSpaceSessions => self.evict_session(sts_tx).await?,
ErrorCode::NoSpaceExchanges => {
self.send_busy(src_rx, sts_tx).await?;
return Ok(None);
}
_ => break Err(e),
},
other => break other,
}
}?;
let mut exchanges = self.exchanges.borrow_mut(); let mut exchanges = self.exchanges.borrow_mut();
let (ctx, new) = match self.post_recv(&mut exchanges, src_rx) { let ctx = &mut exchanges[exchange_index];
Ok((ctx, new)) => (ctx, new),
Err(e) => match e.code() {
ErrorCode::Duplicate => {
self.send_notification.signal(());
return Ok(None);
}
_ => Err(e)?,
},
};
src_rx.log("Got packet"); src_rx.log("Got packet");
@ -516,6 +544,8 @@ impl<'a> Matter<'a> {
ExchangeState::ExchangeRecv { ExchangeState::ExchangeRecv {
rx, notification, .. rx, notification, ..
} => { } => {
// TODO: Handle Busy status codes
let rx = unsafe { rx.as_mut() }.unwrap(); let rx = unsafe { rx.as_mut() }.unwrap();
rx.load(src_rx)?; rx.load(src_rx)?;
@ -572,12 +602,24 @@ impl<'a> Matter<'a> {
Ok(()) Ok(())
} }
pub async fn pull_tx(&self, dest_tx: &mut Packet<'_>) -> Result<bool, Error> { pub fn pull_tx(&self, dest_tx: &mut Packet) -> Result<bool, Error> {
self.purge()?; self.purge()?;
let mut ephemeral = self.ephemeral.borrow_mut();
let mut exchanges = self.exchanges.borrow_mut(); let mut exchanges = self.exchanges.borrow_mut();
let ctx = exchanges.iter_mut().find(|ctx| { self.pull_tx_exchanges(ephemeral.iter_mut().chain(exchanges.iter_mut()), dest_tx)
}
fn pull_tx_exchanges<'i, I>(
&self,
mut exchanges: I,
dest_tx: &mut Packet,
) -> Result<bool, Error>
where
I: Iterator<Item = &'i mut ExchangeCtx>,
{
let ctx = exchanges.find(|ctx| {
matches!( matches!(
&ctx.state, &ctx.state,
ExchangeState::Acknowledge { .. } ExchangeState::Acknowledge { .. }
@ -629,10 +671,15 @@ impl<'a> Matter<'a> {
let tx = unsafe { tx.as_ref() }.unwrap(); let tx = unsafe { tx.as_ref() }.unwrap();
dest_tx.load(tx)?; dest_tx.load(tx)?;
*state = ExchangeState::CompleteAcknowledge { if dest_tx.is_reliable() {
_tx: tx as *const _, *state = ExchangeState::CompleteAcknowledge {
notification: *notification, _tx: tx as *const _,
}; notification: *notification,
};
} else {
unsafe { notification.as_ref() }.unwrap().signal(());
ctx.state = ExchangeState::Closed;
}
true true
} }
@ -648,8 +695,6 @@ impl<'a> Matter<'a> {
if send { if send {
dest_tx.log("Sending packet"); dest_tx.log("Sending packet");
self.pre_send(ctx, dest_tx)?;
self.notify_changed(); self.notify_changed();
return Ok(true); return Ok(true);
@ -675,13 +720,88 @@ impl<'a> Matter<'a> {
Ok(()) Ok(())
} }
fn post_recv<'r>( pub(crate) async fn evict_session(&self, tx: &mut Packet<'_>) -> Result<(), Error> {
&self, let sess_index = self.session_mgr.borrow().get_session_for_eviction();
exchanges: &'r mut heapless::Vec<ExchangeCtx, MAX_EXCHANGES>, if let Some(sess_index) = sess_index {
rx: &mut Packet<'_>, let ctx = {
) -> Result<(&'r mut ExchangeCtx, bool), Error> { create_status_report(
rx.plain_hdr_decode()?; tx,
GeneralCode::Success,
PROTO_ID_SECURE_CHANNEL as _,
SCStatusCodes::CloseSession as _,
None,
)?;
let mut session_mgr = self.session_mgr.borrow_mut();
let session_id = session_mgr.mut_by_index(sess_index).unwrap().id();
warn!("Evicting session: {:?}", session_id);
let ctx = ExchangeCtx::prep_ephemeral(session_id, &mut session_mgr, None, tx)?;
session_mgr.remove(sess_index);
ctx
};
self.send_ephemeral(ctx, tx).await
} else {
Err(ErrorCode::NoSpaceSessions.into())
}
}
async fn send_busy(&self, rx: &Packet<'_>, tx: &mut Packet<'_>) -> Result<(), Error> {
warn!("Sending Busy as all exchanges are occupied");
create_status_report(
tx,
GeneralCode::Busy,
rx.get_proto_id() as _,
if rx.get_proto_id() == PROTO_ID_SECURE_CHANNEL {
SCStatusCodes::Busy as _
} else {
IMStatusCode::Busy as _
},
None, // TODO: ms
)?;
let ctx = ExchangeCtx::prep_ephemeral(
SessionId::load(rx),
&mut self.session_mgr.borrow_mut(),
Some(rx),
tx,
)?;
self.send_ephemeral(ctx, tx).await
}
async fn send_ephemeral(&self, mut ctx: ExchangeCtx, tx: &mut Packet<'_>) -> Result<(), Error> {
let _guard = self.ephemeral_mutex.lock().await;
let notification = Notification::new();
let tx: &'static mut Packet<'static> = unsafe { core::mem::transmute(tx) };
ctx.state = ExchangeState::Complete {
tx,
notification: &notification,
};
*self.ephemeral.borrow_mut() = Some(ctx);
self.send_notification.signal(());
notification.wait().await;
*self.ephemeral.borrow_mut() = None;
Ok(())
}
fn assign_exchange(
&self,
exchanges: &mut heapless::Vec<ExchangeCtx, MAX_EXCHANGES>,
rx: &mut Packet<'_>,
) -> Result<(usize, bool), Error> {
// Get the session // Get the session
let mut session_mgr = self.session_mgr.borrow_mut(); let mut session_mgr = self.session_mgr.borrow_mut();
@ -693,8 +813,7 @@ impl<'a> Matter<'a> {
session.recv(self.epoch, rx)?; session.recv(self.epoch, rx)?;
// Get the exchange // Get the exchange
// TODO: Handle out of space let (exchange_index, new) = Self::register(
let (exch, new) = Self::register(
exchanges, exchanges,
ExchangeId::load(rx), ExchangeId::load(rx),
Role::complementary(rx.proto.is_initiator()), Role::complementary(rx.proto.is_initiator()),
@ -703,32 +822,9 @@ impl<'a> Matter<'a> {
)?; )?;
// Message Reliability Protocol // Message Reliability Protocol
exch.mrp.recv(rx, self.epoch)?; exchanges[exchange_index].mrp.recv(rx, self.epoch)?;
Ok((exch, new)) Ok((exchange_index, new))
}
fn pre_send(&self, ctx: &mut ExchangeCtx, tx: &mut Packet) -> Result<(), Error> {
let mut session_mgr = self.session_mgr.borrow_mut();
let sess_index = session_mgr
.get(
ctx.id.session_id.id,
ctx.id.session_id.peer_addr,
ctx.id.session_id.peer_nodeid,
ctx.id.session_id.is_encrypted,
)
.ok_or(ErrorCode::NoSession)?;
let session = session_mgr.mut_by_index(sess_index).unwrap();
tx.proto.exch_id = ctx.id.id;
if ctx.role == Role::Initiator {
tx.proto.set_initiator();
}
session.pre_send(tx)?;
ctx.mrp.pre_send(tx)?;
session_mgr.send(sess_index, tx)
} }
fn register( fn register(
@ -736,7 +832,7 @@ impl<'a> Matter<'a> {
id: ExchangeId, id: ExchangeId,
role: Role, role: Role,
create_new: bool, create_new: bool,
) -> Result<(&mut ExchangeCtx, bool), Error> { ) -> Result<(usize, bool), Error> {
let exchange_index = exchanges let exchange_index = exchanges
.iter_mut() .iter_mut()
.enumerate() .enumerate()
@ -745,7 +841,7 @@ impl<'a> Matter<'a> {
if let Some(exchange_index) = exchange_index { if let Some(exchange_index) = exchange_index {
let exchange = &mut exchanges[exchange_index]; let exchange = &mut exchanges[exchange_index];
if exchange.role == role { if exchange.role == role {
Ok((exchange, false)) Ok((exchange_index, false))
} else { } else {
Err(ErrorCode::NoExchange.into()) Err(ErrorCode::NoExchange.into())
} }
@ -759,9 +855,11 @@ impl<'a> Matter<'a> {
state: ExchangeState::Active, state: ExchangeState::Active,
}; };
exchanges.push(exchange).map_err(|_| ErrorCode::NoSpace)?; exchanges
.push(exchange)
.map_err(|_| ErrorCode::NoSpaceExchanges)?;
Ok((exchanges.iter_mut().next_back().unwrap(), true)) Ok((exchanges.len() - 1, true))
} else { } else {
Err(ErrorCode::NoExchange.into()) Err(ErrorCode::NoExchange.into())
} }

View file

@ -1,7 +1,7 @@
use crate::{ use crate::{
acl::Accessor, acl::Accessor,
error::{Error, ErrorCode}, error::{Error, ErrorCode},
utils::select::Notification, utils::{epoch::Epoch, select::Notification},
Matter, Matter,
}; };
@ -9,7 +9,7 @@ use super::{
mrp::ReliableMessage, mrp::ReliableMessage,
network::Address, network::Address,
packet::Packet, packet::Packet,
session::{Session, SessionMgr}, session::{CloneData, Session, SessionMgr},
}; };
pub const MAX_EXCHANGES: usize = 8; pub const MAX_EXCHANGES: usize = 8;
@ -46,6 +46,101 @@ impl ExchangeCtx {
) -> Option<&'r mut ExchangeCtx> { ) -> Option<&'r mut ExchangeCtx> {
exchanges.iter_mut().find(|exchange| exchange.id == *id) exchanges.iter_mut().find(|exchange| exchange.id == *id)
} }
pub fn new_ephemeral(session_id: SessionId, reply_to: Option<&Packet<'_>>) -> Self {
Self {
id: ExchangeId {
id: if let Some(rx) = reply_to {
rx.proto.exch_id
} else {
0
},
session_id: session_id.clone(),
},
role: if reply_to.is_some() {
Role::Responder
} else {
Role::Initiator
},
mrp: ReliableMessage::new(),
state: ExchangeState::Active,
}
}
pub(crate) fn prep_ephemeral(
session_id: SessionId,
session_mgr: &mut SessionMgr,
reply_to: Option<&Packet<'_>>,
tx: &mut Packet<'_>,
) -> Result<ExchangeCtx, Error> {
let mut ctx = Self::new_ephemeral(session_id.clone(), reply_to);
let sess_index = session_mgr.get(
session_id.id,
session_id.peer_addr,
session_id.peer_nodeid,
session_id.is_encrypted,
);
let epoch = session_mgr.epoch;
let rand = session_mgr.rand;
if let Some(rx) = reply_to {
ctx.mrp.recv(rx, epoch)?;
} else {
tx.proto.set_initiator();
}
tx.unset_reliable();
if let Some(sess_index) = sess_index {
let session = session_mgr.mut_by_index(sess_index).unwrap();
ctx.pre_send_sess(session, tx, epoch)?;
} else {
let mut session =
Session::new(session_id.peer_addr, session_id.peer_nodeid, epoch, rand);
ctx.pre_send_sess(&mut session, tx, epoch)?;
}
Ok(ctx)
}
pub(crate) fn pre_send(
&mut self,
session_mgr: &mut SessionMgr,
tx: &mut Packet,
) -> Result<(), Error> {
let epoch = session_mgr.epoch;
let sess_index = session_mgr
.get(
self.id.session_id.id,
self.id.session_id.peer_addr,
self.id.session_id.peer_nodeid,
self.id.session_id.is_encrypted,
)
.ok_or(ErrorCode::NoSession)?;
let session = session_mgr.mut_by_index(sess_index).unwrap();
self.pre_send_sess(session, tx, epoch)
}
pub(crate) fn pre_send_sess(
&mut self,
session: &mut Session,
tx: &mut Packet,
epoch: Epoch,
) -> Result<(), Error> {
tx.proto.exch_id = self.id.id;
if self.role == Role::Initiator {
tx.proto.set_initiator();
}
session.pre_send(tx)?;
self.mrp.pre_send(tx)?;
session.send(epoch, tx)
}
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -192,15 +287,6 @@ impl<'a> Exchange<'a> {
self.with_session_mut(|sess| f(sess)) self.with_session_mut(|sess| f(sess))
} }
pub fn with_session_mgr_mut<F, T>(&self, f: F) -> Result<T, Error>
where
F: FnOnce(&mut SessionMgr) -> Result<T, Error>,
{
let mut session_mgr = self.matter.session_mgr.borrow_mut();
f(&mut session_mgr)
}
pub async fn acknowledge(&mut self) -> Result<(), Error> { pub async fn acknowledge(&mut self) -> Result<(), Error> {
let wait = self.with_ctx_mut(|_self, ctx| { let wait = self.with_ctx_mut(|_self, ctx| {
if !matches!(ctx.state, ExchangeState::Active) { if !matches!(ctx.state, ExchangeState::Active) {
@ -226,8 +312,12 @@ impl<'a> Exchange<'a> {
Ok(()) Ok(())
} }
pub async fn exchange(&mut self, tx: &Packet<'_>, rx: &mut Packet<'_>) -> Result<(), Error> { pub async fn exchange(
let tx: &Packet<'static> = unsafe { core::mem::transmute(tx) }; &mut self,
tx: &mut Packet<'_>,
rx: &mut Packet<'_>,
) -> Result<(), Error> {
let tx: &mut Packet<'static> = unsafe { core::mem::transmute(tx) };
let rx: &mut Packet<'static> = unsafe { core::mem::transmute(rx) }; let rx: &mut Packet<'static> = unsafe { core::mem::transmute(rx) };
self.with_ctx_mut(|_self, ctx| { self.with_ctx_mut(|_self, ctx| {
@ -235,6 +325,9 @@ impl<'a> Exchange<'a> {
Err(ErrorCode::NoExchange)?; Err(ErrorCode::NoExchange)?;
} }
let mut session_mgr = _self.matter.session_mgr.borrow_mut();
ctx.pre_send(&mut session_mgr, tx)?;
ctx.state = ExchangeState::ExchangeSend { ctx.state = ExchangeState::ExchangeSend {
tx: tx as *const _, tx: tx as *const _,
rx: rx as *mut _, rx: rx as *mut _,
@ -250,18 +343,21 @@ impl<'a> Exchange<'a> {
Ok(()) Ok(())
} }
pub async fn complete(mut self, tx: &Packet<'_>) -> Result<(), Error> { pub async fn complete(mut self, tx: &mut Packet<'_>) -> Result<(), Error> {
self.send_complete(tx).await self.send_complete(tx).await
} }
pub async fn send_complete(&mut self, tx: &Packet<'_>) -> Result<(), Error> { pub async fn send_complete(&mut self, tx: &mut Packet<'_>) -> Result<(), Error> {
let tx: &Packet<'static> = unsafe { core::mem::transmute(tx) }; let tx: &mut Packet<'static> = unsafe { core::mem::transmute(tx) };
self.with_ctx_mut(|_self, ctx| { self.with_ctx_mut(|_self, ctx| {
if !matches!(ctx.state, ExchangeState::Active) { if !matches!(ctx.state, ExchangeState::Active) {
Err(ErrorCode::NoExchange)?; Err(ErrorCode::NoExchange)?;
} }
let mut session_mgr = _self.matter.session_mgr.borrow_mut();
ctx.pre_send(&mut session_mgr, tx)?;
ctx.state = ExchangeState::Complete { ctx.state = ExchangeState::Complete {
tx: tx as *const _, tx: tx as *const _,
notification: &_self.notification as *const _, notification: &_self.notification as *const _,
@ -276,6 +372,31 @@ impl<'a> Exchange<'a> {
Ok(()) Ok(())
} }
pub(crate) fn get_next_sess_id(&mut self) -> u16 {
self.matter.session_mgr.borrow_mut().get_next_sess_id()
}
pub(crate) async fn clone_session(
&mut self,
tx: &mut Packet<'_>,
clone_data: &CloneData,
) -> Result<usize, Error> {
loop {
let result = self
.matter
.session_mgr
.borrow_mut()
.clone_session(clone_data);
match result {
Err(err) if err.code() == ErrorCode::NoSpaceSessions => {
self.matter.evict_session(tx).await?
}
other => break other,
}
}
}
fn with_ctx<F, T>(&self, f: F) -> Result<T, Error> fn with_ctx<F, T>(&self, f: F) -> Result<T, Error>
where where
F: FnOnce(&Self, &ExchangeCtx) -> Result<T, Error>, F: FnOnce(&Self, &ExchangeCtx) -> Result<T, Error>,

View file

@ -19,13 +19,13 @@ use crate::data_model::sdm::noc::NocData;
use crate::utils::epoch::Epoch; use crate::utils::epoch::Epoch;
use crate::utils::rand::Rand; use crate::utils::rand::Rand;
use core::fmt; use core::fmt;
use core::ops::{Deref, DerefMut};
use core::time::Duration; use core::time::Duration;
use crate::{error::*, transport::plain_hdr}; use crate::{error::*, transport::plain_hdr};
use log::info; use log::info;
use super::dedup::RxCtrState; use super::dedup::RxCtrState;
use super::exchange::SessionId;
use super::{network::Address, packet::Packet}; use super::{network::Address, packet::Packet};
pub const MAX_CAT_IDS_PER_NOC: usize = 3; pub const MAX_CAT_IDS_PER_NOC: usize = 3;
@ -151,6 +151,15 @@ impl Session {
} }
} }
pub fn id(&self) -> SessionId {
SessionId {
id: self.local_sess_id,
peer_addr: self.peer_addr,
peer_nodeid: self.peer_nodeid,
is_encrypted: self.is_encrypted(),
}
}
pub fn set_noc_data(&mut self, data: NocData) { pub fn set_noc_data(&mut self, data: NocData) {
self.data = Some(data); self.data = Some(data);
} }
@ -251,7 +260,7 @@ impl Session {
Ok(()) Ok(())
} }
fn send(&mut self, epoch: Epoch, tx: &mut Packet) -> Result<(), Error> { pub(crate) fn send(&mut self, epoch: Epoch, tx: &mut Packet) -> Result<(), Error> {
self.last_use = epoch(); self.last_use = epoch();
tx.proto_encode( tx.proto_encode(
@ -291,8 +300,8 @@ pub const MAX_SESSIONS: usize = 16;
pub struct SessionMgr { pub struct SessionMgr {
next_sess_id: u16, next_sess_id: u16,
sessions: heapless::Vec<Option<Session>, MAX_SESSIONS>, sessions: heapless::Vec<Option<Session>, MAX_SESSIONS>,
epoch: Epoch, pub(crate) epoch: Epoch,
rand: Rand, pub(crate) rand: Rand,
} }
impl SessionMgr { impl SessionMgr {
@ -327,7 +336,11 @@ impl SessionMgr {
} }
// Ensure the currently selected id doesn't match any existing session // Ensure the currently selected id doesn't match any existing session
if self.get_with_id(next_sess_id).is_none() { if self.sessions.iter().all(|sess| {
sess.as_ref()
.map(|sess| sess.get_local_sess_id() != next_sess_id)
.unwrap_or(true)
}) {
break; break;
} }
} }
@ -381,12 +394,12 @@ impl SessionMgr {
} else if self.sessions.len() < MAX_SESSIONS { } else if self.sessions.len() < MAX_SESSIONS {
self.sessions self.sessions
.push(Some(session)) .push(Some(session))
.map_err(|_| ErrorCode::NoSpace) .map_err(|_| ErrorCode::NoSpaceSessions)
.unwrap(); .unwrap();
Ok(self.sessions.len() - 1) Ok(self.sessions.len() - 1)
} else { } else {
Err(ErrorCode::NoSpace.into()) Err(ErrorCode::NoSpaceSessions.into())
} }
} }
@ -419,14 +432,6 @@ impl SessionMgr {
}) })
} }
pub fn get_with_id(&mut self, sess_id: u16) -> Option<SessionHandle> {
let index = self
.sessions
.iter_mut()
.position(|x| x.as_ref().map(|s| s.local_sess_id) == Some(sess_id))?;
Some(self.get_session_handle(index))
}
pub fn get_or_add( pub fn get_or_add(
&mut self, &mut self,
sess_id: u16, sess_id: u16,
@ -472,13 +477,6 @@ impl SessionMgr {
.ok_or(ErrorCode::NoSession)? .ok_or(ErrorCode::NoSession)?
.send(self.epoch, tx) .send(self.epoch, tx)
} }
pub fn get_session_handle(&mut self, sess_idx: usize) -> SessionHandle {
SessionHandle {
sess_mgr: self,
sess_idx,
}
}
} }
impl fmt::Display for SessionMgr { impl fmt::Display for SessionMgr {
@ -492,45 +490,6 @@ impl fmt::Display for SessionMgr {
} }
} }
pub struct SessionHandle<'a> {
pub(crate) sess_mgr: &'a mut SessionMgr,
sess_idx: usize,
}
impl<'a> SessionHandle<'a> {
pub fn session(&self) -> &Session {
self.sess_mgr.sessions[self.sess_idx].as_ref().unwrap()
}
pub fn session_mut(&mut self) -> &mut Session {
self.sess_mgr.sessions[self.sess_idx].as_mut().unwrap()
}
pub fn reserve_new_sess_id(&mut self) -> u16 {
self.sess_mgr.get_next_sess_id()
}
pub fn send(&mut self, tx: &mut Packet) -> Result<(), Error> {
self.sess_mgr.send(self.sess_idx, tx)
}
}
impl<'a> Deref for SessionHandle<'a> {
type Target = Session;
fn deref(&self) -> &Self::Target {
// There is no other option but to panic if this is None
self.session()
}
}
impl<'a> DerefMut for SessionHandle<'a> {
fn deref_mut(&mut self) -> &mut Self::Target {
// There is no other option but to panic if this is None
self.session_mut()
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
@ -545,12 +504,12 @@ mod tests {
fn test_next_sess_id_doesnt_reuse() { fn test_next_sess_id_doesnt_reuse() {
let mut sm = SessionMgr::new(dummy_epoch, dummy_rand); let mut sm = SessionMgr::new(dummy_epoch, dummy_rand);
let sess_idx = sm.add(Address::default(), None).unwrap(); let sess_idx = sm.add(Address::default(), None).unwrap();
let mut sess = sm.get_session_handle(sess_idx); let sess = sm.mut_by_index(sess_idx).unwrap();
sess.set_local_sess_id(1); sess.set_local_sess_id(1);
assert_eq!(sm.get_next_sess_id(), 2); assert_eq!(sm.get_next_sess_id(), 2);
assert_eq!(sm.get_next_sess_id(), 3); assert_eq!(sm.get_next_sess_id(), 3);
let sess_idx = sm.add(Address::default(), None).unwrap(); let sess_idx = sm.add(Address::default(), None).unwrap();
let mut sess = sm.get_session_handle(sess_idx); let sess = sm.mut_by_index(sess_idx).unwrap();
sess.set_local_sess_id(4); sess.set_local_sess_id(4);
assert_eq!(sm.get_next_sess_id(), 5); assert_eq!(sm.get_next_sess_id(), 5);
} }
@ -559,7 +518,7 @@ mod tests {
fn test_next_sess_id_overflows() { fn test_next_sess_id_overflows() {
let mut sm = SessionMgr::new(dummy_epoch, dummy_rand); let mut sm = SessionMgr::new(dummy_epoch, dummy_rand);
let sess_idx = sm.add(Address::default(), None).unwrap(); let sess_idx = sm.add(Address::default(), None).unwrap();
let mut sess = sm.get_session_handle(sess_idx); let sess = sm.mut_by_index(sess_idx).unwrap();
sess.set_local_sess_id(1); sess.set_local_sess_id(1);
assert_eq!(sm.get_next_sess_id(), 2); assert_eq!(sm.get_next_sess_id(), 2);
sm.next_sess_id = 65534; sm.next_sess_id = 65534;