Handle out of sessions and out of exchanges
This commit is contained in:
parent
4c347c0c0b
commit
e171e33510
8 changed files with 360 additions and 171 deletions
|
@ -17,6 +17,8 @@
|
||||||
|
|
||||||
use core::{borrow::Borrow, cell::RefCell};
|
use core::{borrow::Borrow, cell::RefCell};
|
||||||
|
|
||||||
|
use embassy_sync::{blocking_mutex::raw::NoopRawMutex, mutex::Mutex};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
acl::AclMgr,
|
acl::AclMgr,
|
||||||
data_model::{
|
data_model::{
|
||||||
|
@ -61,6 +63,8 @@ pub struct Matter<'a> {
|
||||||
dev_att: &'a dyn DevAttDataFetcher,
|
dev_att: &'a dyn DevAttDataFetcher,
|
||||||
pub(crate) port: u16,
|
pub(crate) port: u16,
|
||||||
pub(crate) exchanges: RefCell<heapless::Vec<ExchangeCtx, MAX_EXCHANGES>>,
|
pub(crate) exchanges: RefCell<heapless::Vec<ExchangeCtx, MAX_EXCHANGES>>,
|
||||||
|
pub(crate) ephemeral: RefCell<Option<ExchangeCtx>>,
|
||||||
|
pub(crate) ephemeral_mutex: Mutex<NoopRawMutex, ()>,
|
||||||
pub session_mgr: RefCell<SessionMgr>, // Public for tests
|
pub session_mgr: RefCell<SessionMgr>, // Public for tests
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -108,6 +112,8 @@ impl<'a> Matter<'a> {
|
||||||
dev_att,
|
dev_att,
|
||||||
port,
|
port,
|
||||||
exchanges: RefCell::new(heapless::Vec::new()),
|
exchanges: RefCell::new(heapless::Vec::new()),
|
||||||
|
ephemeral: RefCell::new(None),
|
||||||
|
ephemeral_mutex: Mutex::new(()),
|
||||||
session_mgr: RefCell::new(SessionMgr::new(epoch, rand)),
|
session_mgr: RefCell::new(SessionMgr::new(epoch, rand)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -47,6 +47,8 @@ pub enum ErrorCode {
|
||||||
NoMemory,
|
NoMemory,
|
||||||
NoSession,
|
NoSession,
|
||||||
NoSpace,
|
NoSpace,
|
||||||
|
NoSpaceExchanges,
|
||||||
|
NoSpaceSessions,
|
||||||
NoSpaceAckTable,
|
NoSpaceAckTable,
|
||||||
NoSpaceRetransTable,
|
NoSpaceRetransTable,
|
||||||
NoTagFound,
|
NoTagFound,
|
||||||
|
|
|
@ -96,7 +96,7 @@ impl<'a> Case<'a> {
|
||||||
) -> Result<(), Error> {
|
) -> Result<(), Error> {
|
||||||
rx.check_proto_opcode(OpCode::CASESigma3 as _)?;
|
rx.check_proto_opcode(OpCode::CASESigma3 as _)?;
|
||||||
|
|
||||||
let status = {
|
let result = {
|
||||||
let fabric_mgr = self.fabric_mgr.borrow();
|
let fabric_mgr = self.fabric_mgr.borrow();
|
||||||
|
|
||||||
let fabric = fabric_mgr.get_fabric(case_session.local_fabric_idx)?;
|
let fabric = fabric_mgr.get_fabric(case_session.local_fabric_idx)?;
|
||||||
|
@ -133,7 +133,7 @@ impl<'a> Case<'a> {
|
||||||
|
|
||||||
if let Err(e) = Case::validate_certs(fabric, &initiator_noc, initiator_icac_mut) {
|
if let Err(e) = Case::validate_certs(fabric, &initiator_noc, initiator_icac_mut) {
|
||||||
error!("Certificate Chain doesn't match: {}", e);
|
error!("Certificate Chain doesn't match: {}", e);
|
||||||
SCStatusCodes::InvalidParameter
|
Err(SCStatusCodes::InvalidParameter)
|
||||||
} else if let Err(e) = Case::validate_sigma3_sign(
|
} else if let Err(e) = Case::validate_sigma3_sign(
|
||||||
d.initiator_noc.0,
|
d.initiator_noc.0,
|
||||||
d.initiator_icac.map(|a| a.0),
|
d.initiator_icac.map(|a| a.0),
|
||||||
|
@ -142,32 +142,35 @@ impl<'a> Case<'a> {
|
||||||
case_session,
|
case_session,
|
||||||
) {
|
) {
|
||||||
error!("Sigma3 Signature doesn't match: {}", e);
|
error!("Sigma3 Signature doesn't match: {}", e);
|
||||||
SCStatusCodes::InvalidParameter
|
Err(SCStatusCodes::InvalidParameter)
|
||||||
} else {
|
} else {
|
||||||
// Only now do we add this message to the TT Hash
|
// Only now do we add this message to the TT Hash
|
||||||
let mut peer_catids: NocCatIds = Default::default();
|
let mut peer_catids: NocCatIds = Default::default();
|
||||||
initiator_noc.get_cat_ids(&mut peer_catids);
|
initiator_noc.get_cat_ids(&mut peer_catids);
|
||||||
case_session.tt_hash.update(rx.as_slice())?;
|
case_session.tt_hash.update(rx.as_slice())?;
|
||||||
let clone_data = Case::get_session_clone_data(
|
|
||||||
|
Ok(Case::get_session_clone_data(
|
||||||
fabric.ipk.op_key(),
|
fabric.ipk.op_key(),
|
||||||
fabric.get_node_id(),
|
fabric.get_node_id(),
|
||||||
initiator_noc.get_node_id()?,
|
initiator_noc.get_node_id()?,
|
||||||
exchange.with_session(|sess| Ok(sess.get_peer_addr()))?,
|
exchange.with_session(|sess| Ok(sess.get_peer_addr()))?,
|
||||||
case_session,
|
case_session,
|
||||||
&peer_catids,
|
&peer_catids,
|
||||||
)?;
|
)?)
|
||||||
|
|
||||||
// TODO: Handle NoSpace
|
|
||||||
exchange
|
|
||||||
.with_session_mgr_mut(|sess_mgr| sess_mgr.clone_session(&clone_data))?;
|
|
||||||
|
|
||||||
SCStatusCodes::SessionEstablishmentSuccess
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
SCStatusCodes::NoSharedTrustRoots
|
Err(SCStatusCodes::NoSharedTrustRoots)
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let status = match result {
|
||||||
|
Ok(clone_data) => {
|
||||||
|
exchange.clone_session(tx, &clone_data).await?;
|
||||||
|
SCStatusCodes::SessionEstablishmentSuccess
|
||||||
|
}
|
||||||
|
Err(status) => status,
|
||||||
|
};
|
||||||
|
|
||||||
complete_with_status(exchange, tx, status, None).await
|
complete_with_status(exchange, tx, status, None).await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -201,7 +204,7 @@ impl<'a> Case<'a> {
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
let local_sessid = exchange.with_session_mgr_mut(|mgr| Ok(mgr.get_next_sess_id()))?;
|
let local_sessid = exchange.get_next_sess_id();
|
||||||
case_session.peer_sessid = r.initiator_sessid;
|
case_session.peer_sessid = r.initiator_sessid;
|
||||||
case_session.local_sessid = local_sessid;
|
case_session.local_sessid = local_sessid;
|
||||||
case_session.tt_hash.update(rx_buf)?;
|
case_session.tt_hash.update(rx_buf)?;
|
||||||
|
|
|
@ -78,8 +78,8 @@ pub fn create_sc_status_report(
|
||||||
// the session will be closed soon
|
// the session will be closed soon
|
||||||
GeneralCode::Success
|
GeneralCode::Success
|
||||||
}
|
}
|
||||||
SCStatusCodes::Busy
|
SCStatusCodes::Busy => GeneralCode::Busy,
|
||||||
| SCStatusCodes::InvalidParameter
|
SCStatusCodes::InvalidParameter
|
||||||
| SCStatusCodes::NoSharedTrustRoots
|
| SCStatusCodes::NoSharedTrustRoots
|
||||||
| SCStatusCodes::SessionNotFound => GeneralCode::Failure,
|
| SCStatusCodes::SessionNotFound => GeneralCode::Failure,
|
||||||
};
|
};
|
||||||
|
|
|
@ -167,9 +167,9 @@ impl<'a> Pake<'a> {
|
||||||
self.update_timeout(exchange, tx, true).await?;
|
self.update_timeout(exchange, tx, true).await?;
|
||||||
|
|
||||||
let cA = extract_pasepake_1_or_3_params(rx.as_slice())?;
|
let cA = extract_pasepake_1_or_3_params(rx.as_slice())?;
|
||||||
let (status_code, ke) = spake2p.handle_cA(cA);
|
let (status, ke) = spake2p.handle_cA(cA);
|
||||||
|
|
||||||
let clone_data = if status_code == SCStatusCodes::SessionEstablishmentSuccess {
|
let result = if status == SCStatusCodes::SessionEstablishmentSuccess {
|
||||||
// Get the keys
|
// Get the keys
|
||||||
let ke = ke.ok_or(ErrorCode::Invalid)?;
|
let ke = ke.ok_or(ErrorCode::Invalid)?;
|
||||||
let mut session_keys: [u8; 48] = [0; 48];
|
let mut session_keys: [u8; 48] = [0; 48];
|
||||||
|
@ -194,22 +194,22 @@ impl<'a> Pake<'a> {
|
||||||
.att_challenge
|
.att_challenge
|
||||||
.copy_from_slice(&session_keys[32..48]);
|
.copy_from_slice(&session_keys[32..48]);
|
||||||
|
|
||||||
// Queue a transport mgr request to add a new session
|
Ok(clone_data)
|
||||||
Some(clone_data)
|
|
||||||
} else {
|
} else {
|
||||||
None
|
Err(status)
|
||||||
};
|
};
|
||||||
|
|
||||||
if let Some(clone_data) = clone_data {
|
let status = match result {
|
||||||
// TODO: Handle NoSpace
|
Ok(clone_data) => {
|
||||||
exchange.with_session_mgr_mut(|sess_mgr| sess_mgr.clone_session(&clone_data))?;
|
exchange.clone_session(tx, &clone_data).await?;
|
||||||
|
self.pase.borrow_mut().disable_pase_session(mdns)?;
|
||||||
|
|
||||||
self.pase.borrow_mut().disable_pase_session(mdns)?;
|
SCStatusCodes::SessionEstablishmentSuccess
|
||||||
}
|
}
|
||||||
|
Err(status) => status,
|
||||||
|
};
|
||||||
|
|
||||||
complete_with_status(exchange, tx, status_code, None).await?;
|
complete_with_status(exchange, tx, status, None).await
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(non_snake_case)]
|
#[allow(non_snake_case)]
|
||||||
|
@ -273,7 +273,7 @@ impl<'a> Pake<'a> {
|
||||||
let mut our_random: [u8; 32] = [0; 32];
|
let mut our_random: [u8; 32] = [0; 32];
|
||||||
(self.pase.borrow().rand)(&mut our_random);
|
(self.pase.borrow().rand)(&mut our_random);
|
||||||
|
|
||||||
let local_sessid = exchange.with_session_mgr_mut(|mgr| Ok(mgr.get_next_sess_id()))?;
|
let local_sessid = exchange.get_next_sess_id();
|
||||||
let spake2p_data: u32 = ((local_sessid as u32) << 16) | a.initiator_ssid as u32;
|
let spake2p_data: u32 = ((local_sessid as u32) << 16) | a.initiator_ssid as u32;
|
||||||
spake2p.set_app_data(spake2p_data);
|
spake2p.set_app_data(spake2p_data);
|
||||||
|
|
||||||
|
|
|
@ -25,6 +25,9 @@ use embassy_time::{Duration, Timer};
|
||||||
|
|
||||||
use log::{error, info, warn};
|
use log::{error, info, warn};
|
||||||
|
|
||||||
|
use crate::interaction_model::core::IMStatusCode;
|
||||||
|
use crate::secure_channel::common::SCStatusCodes;
|
||||||
|
use crate::secure_channel::status_report::{create_status_report, GeneralCode};
|
||||||
use crate::utils::select::Notification;
|
use crate::utils::select::Notification;
|
||||||
use crate::CommissioningData;
|
use crate::CommissioningData;
|
||||||
use crate::{
|
use crate::{
|
||||||
|
@ -41,6 +44,7 @@ use crate::{
|
||||||
Matter,
|
Matter,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
use super::exchange::SessionId;
|
||||||
use super::{
|
use super::{
|
||||||
exchange::{
|
exchange::{
|
||||||
Exchange, ExchangeCtr, ExchangeCtx, ExchangeId, ExchangeState, Role, MAX_EXCHANGES,
|
Exchange, ExchangeCtr, ExchangeCtx, ExchangeId, ExchangeState, Role, MAX_EXCHANGES,
|
||||||
|
@ -97,7 +101,7 @@ impl RunBuffers {
|
||||||
pub struct PacketBuffers {
|
pub struct PacketBuffers {
|
||||||
tx: [TxBuf; MAX_EXCHANGES],
|
tx: [TxBuf; MAX_EXCHANGES],
|
||||||
rx: [RxBuf; MAX_EXCHANGES],
|
rx: [RxBuf; MAX_EXCHANGES],
|
||||||
sx: [SxBuf; MAX_EXCHANGES],
|
sx: [SxBuf; MAX_EXCHANGES + 1],
|
||||||
}
|
}
|
||||||
|
|
||||||
impl PacketBuffers {
|
impl PacketBuffers {
|
||||||
|
@ -107,7 +111,7 @@ impl PacketBuffers {
|
||||||
|
|
||||||
const TX_INIT: [TxBuf; MAX_EXCHANGES] = [Self::TX_ELEM; MAX_EXCHANGES];
|
const TX_INIT: [TxBuf; MAX_EXCHANGES] = [Self::TX_ELEM; MAX_EXCHANGES];
|
||||||
const RX_INIT: [RxBuf; MAX_EXCHANGES] = [Self::RX_ELEM; MAX_EXCHANGES];
|
const RX_INIT: [RxBuf; MAX_EXCHANGES] = [Self::RX_ELEM; MAX_EXCHANGES];
|
||||||
const SX_INIT: [SxBuf; MAX_EXCHANGES] = [Self::SX_ELEM; MAX_EXCHANGES];
|
const SX_INIT: [SxBuf; MAX_EXCHANGES + 1] = [Self::SX_ELEM; MAX_EXCHANGES + 1];
|
||||||
|
|
||||||
#[inline(always)]
|
#[inline(always)]
|
||||||
pub const fn new() -> Self {
|
pub const fn new() -> Self {
|
||||||
|
@ -266,7 +270,12 @@ impl<'a> Matter<'a> {
|
||||||
.unwrap();
|
.unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut rx = pin!(self.handle_rx_multiplex(rx_pipe, construction_notification, &channel));
|
let mut rx = pin!(self.handle_rx_multiplex(
|
||||||
|
rx_pipe,
|
||||||
|
unsafe { buffers.sx[MAX_EXCHANGES].assume_init_mut() },
|
||||||
|
construction_notification,
|
||||||
|
&channel,
|
||||||
|
));
|
||||||
|
|
||||||
let result = select(&mut rx, select_slice(&mut handlers)).await;
|
let result = select(&mut rx, select_slice(&mut handlers)).await;
|
||||||
|
|
||||||
|
@ -291,7 +300,7 @@ impl<'a> Matter<'a> {
|
||||||
if data.chunk.is_none() {
|
if data.chunk.is_none() {
|
||||||
let mut tx = alloc!(Packet::new_tx(data.buf));
|
let mut tx = alloc!(Packet::new_tx(data.buf));
|
||||||
|
|
||||||
if self.pull_tx(&mut tx).await? {
|
if self.pull_tx(&mut tx)? {
|
||||||
data.chunk = Some(Chunk {
|
data.chunk = Some(Chunk {
|
||||||
start: tx.get_writebuf()?.get_start(),
|
start: tx.get_writebuf()?.get_start(),
|
||||||
end: tx.get_writebuf()?.get_tail(),
|
end: tx.get_writebuf()?.get_tail(),
|
||||||
|
@ -315,12 +324,15 @@ impl<'a> Matter<'a> {
|
||||||
pub async fn handle_rx_multiplex<'t, 'e, const N: usize>(
|
pub async fn handle_rx_multiplex<'t, 'e, const N: usize>(
|
||||||
&'t self,
|
&'t self,
|
||||||
rx_pipe: &Pipe<'_>,
|
rx_pipe: &Pipe<'_>,
|
||||||
|
sts_buf: &mut [u8; MAX_RX_STATUS_BUF_SIZE],
|
||||||
construction_notification: &'e Notification,
|
construction_notification: &'e Notification,
|
||||||
channel: &Channel<NoopRawMutex, ExchangeCtr<'e>, N>,
|
channel: &Channel<NoopRawMutex, ExchangeCtr<'e>, N>,
|
||||||
) -> Result<(), Error>
|
) -> Result<(), Error>
|
||||||
where
|
where
|
||||||
't: 'e,
|
't: 'e,
|
||||||
{
|
{
|
||||||
|
let mut sts_tx = alloc!(Packet::new_tx(sts_buf));
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
info!("Transport: waiting for incoming packets");
|
info!("Transport: waiting for incoming packets");
|
||||||
|
|
||||||
|
@ -331,8 +343,9 @@ impl<'a> Matter<'a> {
|
||||||
let mut rx = alloc!(Packet::new_rx(&mut data.buf[chunk.start..chunk.end]));
|
let mut rx = alloc!(Packet::new_rx(&mut data.buf[chunk.start..chunk.end]));
|
||||||
rx.peer = chunk.addr;
|
rx.peer = chunk.addr;
|
||||||
|
|
||||||
if let Some(exchange_ctr) =
|
if let Some(exchange_ctr) = self
|
||||||
self.process_rx(construction_notification, &mut rx)?
|
.process_rx(construction_notification, &mut rx, &mut sts_tx)
|
||||||
|
.await?
|
||||||
{
|
{
|
||||||
let exchange_id = exchange_ctr.id().clone();
|
let exchange_id = exchange_ctr.id().clone();
|
||||||
|
|
||||||
|
@ -444,24 +457,39 @@ impl<'a> Matter<'a> {
|
||||||
self.session_mgr.borrow_mut().reset();
|
self.session_mgr.borrow_mut().reset();
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn process_rx<'r>(
|
pub async fn process_rx<'r>(
|
||||||
&'r self,
|
&'r self,
|
||||||
construction_notification: &'r Notification,
|
construction_notification: &'r Notification,
|
||||||
src_rx: &mut Packet<'_>,
|
src_rx: &mut Packet<'_>,
|
||||||
|
sts_tx: &mut Packet<'_>,
|
||||||
) -> Result<Option<ExchangeCtr<'r>>, Error> {
|
) -> Result<Option<ExchangeCtr<'r>>, Error> {
|
||||||
|
src_rx.plain_hdr_decode()?;
|
||||||
|
|
||||||
self.purge()?;
|
self.purge()?;
|
||||||
|
|
||||||
|
let (exchange_index, new) = loop {
|
||||||
|
let result = self.assign_exchange(&mut self.exchanges.borrow_mut(), src_rx);
|
||||||
|
|
||||||
|
match result {
|
||||||
|
Err(e) => match e.code() {
|
||||||
|
ErrorCode::Duplicate => {
|
||||||
|
self.send_notification.signal(());
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
// TODO: NoSession, NoExchange and others
|
||||||
|
ErrorCode::NoSpaceSessions => self.evict_session(sts_tx).await?,
|
||||||
|
ErrorCode::NoSpaceExchanges => {
|
||||||
|
self.send_busy(src_rx, sts_tx).await?;
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
_ => break Err(e),
|
||||||
|
},
|
||||||
|
other => break other,
|
||||||
|
}
|
||||||
|
}?;
|
||||||
|
|
||||||
let mut exchanges = self.exchanges.borrow_mut();
|
let mut exchanges = self.exchanges.borrow_mut();
|
||||||
let (ctx, new) = match self.post_recv(&mut exchanges, src_rx) {
|
let ctx = &mut exchanges[exchange_index];
|
||||||
Ok((ctx, new)) => (ctx, new),
|
|
||||||
Err(e) => match e.code() {
|
|
||||||
ErrorCode::Duplicate => {
|
|
||||||
self.send_notification.signal(());
|
|
||||||
return Ok(None);
|
|
||||||
}
|
|
||||||
_ => Err(e)?,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
src_rx.log("Got packet");
|
src_rx.log("Got packet");
|
||||||
|
|
||||||
|
@ -516,6 +544,8 @@ impl<'a> Matter<'a> {
|
||||||
ExchangeState::ExchangeRecv {
|
ExchangeState::ExchangeRecv {
|
||||||
rx, notification, ..
|
rx, notification, ..
|
||||||
} => {
|
} => {
|
||||||
|
// TODO: Handle Busy status codes
|
||||||
|
|
||||||
let rx = unsafe { rx.as_mut() }.unwrap();
|
let rx = unsafe { rx.as_mut() }.unwrap();
|
||||||
rx.load(src_rx)?;
|
rx.load(src_rx)?;
|
||||||
|
|
||||||
|
@ -572,12 +602,24 @@ impl<'a> Matter<'a> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn pull_tx(&self, dest_tx: &mut Packet<'_>) -> Result<bool, Error> {
|
pub fn pull_tx(&self, dest_tx: &mut Packet) -> Result<bool, Error> {
|
||||||
self.purge()?;
|
self.purge()?;
|
||||||
|
|
||||||
|
let mut ephemeral = self.ephemeral.borrow_mut();
|
||||||
let mut exchanges = self.exchanges.borrow_mut();
|
let mut exchanges = self.exchanges.borrow_mut();
|
||||||
|
|
||||||
let ctx = exchanges.iter_mut().find(|ctx| {
|
self.pull_tx_exchanges(ephemeral.iter_mut().chain(exchanges.iter_mut()), dest_tx)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn pull_tx_exchanges<'i, I>(
|
||||||
|
&self,
|
||||||
|
mut exchanges: I,
|
||||||
|
dest_tx: &mut Packet,
|
||||||
|
) -> Result<bool, Error>
|
||||||
|
where
|
||||||
|
I: Iterator<Item = &'i mut ExchangeCtx>,
|
||||||
|
{
|
||||||
|
let ctx = exchanges.find(|ctx| {
|
||||||
matches!(
|
matches!(
|
||||||
&ctx.state,
|
&ctx.state,
|
||||||
ExchangeState::Acknowledge { .. }
|
ExchangeState::Acknowledge { .. }
|
||||||
|
@ -629,10 +671,15 @@ impl<'a> Matter<'a> {
|
||||||
let tx = unsafe { tx.as_ref() }.unwrap();
|
let tx = unsafe { tx.as_ref() }.unwrap();
|
||||||
dest_tx.load(tx)?;
|
dest_tx.load(tx)?;
|
||||||
|
|
||||||
*state = ExchangeState::CompleteAcknowledge {
|
if dest_tx.is_reliable() {
|
||||||
_tx: tx as *const _,
|
*state = ExchangeState::CompleteAcknowledge {
|
||||||
notification: *notification,
|
_tx: tx as *const _,
|
||||||
};
|
notification: *notification,
|
||||||
|
};
|
||||||
|
} else {
|
||||||
|
unsafe { notification.as_ref() }.unwrap().signal(());
|
||||||
|
ctx.state = ExchangeState::Closed;
|
||||||
|
}
|
||||||
|
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
|
@ -648,8 +695,6 @@ impl<'a> Matter<'a> {
|
||||||
|
|
||||||
if send {
|
if send {
|
||||||
dest_tx.log("Sending packet");
|
dest_tx.log("Sending packet");
|
||||||
|
|
||||||
self.pre_send(ctx, dest_tx)?;
|
|
||||||
self.notify_changed();
|
self.notify_changed();
|
||||||
|
|
||||||
return Ok(true);
|
return Ok(true);
|
||||||
|
@ -675,13 +720,88 @@ impl<'a> Matter<'a> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn post_recv<'r>(
|
pub(crate) async fn evict_session(&self, tx: &mut Packet<'_>) -> Result<(), Error> {
|
||||||
&self,
|
let sess_index = self.session_mgr.borrow().get_session_for_eviction();
|
||||||
exchanges: &'r mut heapless::Vec<ExchangeCtx, MAX_EXCHANGES>,
|
if let Some(sess_index) = sess_index {
|
||||||
rx: &mut Packet<'_>,
|
let ctx = {
|
||||||
) -> Result<(&'r mut ExchangeCtx, bool), Error> {
|
create_status_report(
|
||||||
rx.plain_hdr_decode()?;
|
tx,
|
||||||
|
GeneralCode::Success,
|
||||||
|
PROTO_ID_SECURE_CHANNEL as _,
|
||||||
|
SCStatusCodes::CloseSession as _,
|
||||||
|
None,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let mut session_mgr = self.session_mgr.borrow_mut();
|
||||||
|
let session_id = session_mgr.mut_by_index(sess_index).unwrap().id();
|
||||||
|
warn!("Evicting session: {:?}", session_id);
|
||||||
|
|
||||||
|
let ctx = ExchangeCtx::prep_ephemeral(session_id, &mut session_mgr, None, tx)?;
|
||||||
|
|
||||||
|
session_mgr.remove(sess_index);
|
||||||
|
|
||||||
|
ctx
|
||||||
|
};
|
||||||
|
|
||||||
|
self.send_ephemeral(ctx, tx).await
|
||||||
|
} else {
|
||||||
|
Err(ErrorCode::NoSpaceSessions.into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn send_busy(&self, rx: &Packet<'_>, tx: &mut Packet<'_>) -> Result<(), Error> {
|
||||||
|
warn!("Sending Busy as all exchanges are occupied");
|
||||||
|
|
||||||
|
create_status_report(
|
||||||
|
tx,
|
||||||
|
GeneralCode::Busy,
|
||||||
|
rx.get_proto_id() as _,
|
||||||
|
if rx.get_proto_id() == PROTO_ID_SECURE_CHANNEL {
|
||||||
|
SCStatusCodes::Busy as _
|
||||||
|
} else {
|
||||||
|
IMStatusCode::Busy as _
|
||||||
|
},
|
||||||
|
None, // TODO: ms
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let ctx = ExchangeCtx::prep_ephemeral(
|
||||||
|
SessionId::load(rx),
|
||||||
|
&mut self.session_mgr.borrow_mut(),
|
||||||
|
Some(rx),
|
||||||
|
tx,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
self.send_ephemeral(ctx, tx).await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn send_ephemeral(&self, mut ctx: ExchangeCtx, tx: &mut Packet<'_>) -> Result<(), Error> {
|
||||||
|
let _guard = self.ephemeral_mutex.lock().await;
|
||||||
|
|
||||||
|
let notification = Notification::new();
|
||||||
|
|
||||||
|
let tx: &'static mut Packet<'static> = unsafe { core::mem::transmute(tx) };
|
||||||
|
|
||||||
|
ctx.state = ExchangeState::Complete {
|
||||||
|
tx,
|
||||||
|
notification: ¬ification,
|
||||||
|
};
|
||||||
|
|
||||||
|
*self.ephemeral.borrow_mut() = Some(ctx);
|
||||||
|
|
||||||
|
self.send_notification.signal(());
|
||||||
|
|
||||||
|
notification.wait().await;
|
||||||
|
|
||||||
|
*self.ephemeral.borrow_mut() = None;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn assign_exchange(
|
||||||
|
&self,
|
||||||
|
exchanges: &mut heapless::Vec<ExchangeCtx, MAX_EXCHANGES>,
|
||||||
|
rx: &mut Packet<'_>,
|
||||||
|
) -> Result<(usize, bool), Error> {
|
||||||
// Get the session
|
// Get the session
|
||||||
|
|
||||||
let mut session_mgr = self.session_mgr.borrow_mut();
|
let mut session_mgr = self.session_mgr.borrow_mut();
|
||||||
|
@ -693,8 +813,7 @@ impl<'a> Matter<'a> {
|
||||||
session.recv(self.epoch, rx)?;
|
session.recv(self.epoch, rx)?;
|
||||||
|
|
||||||
// Get the exchange
|
// Get the exchange
|
||||||
// TODO: Handle out of space
|
let (exchange_index, new) = Self::register(
|
||||||
let (exch, new) = Self::register(
|
|
||||||
exchanges,
|
exchanges,
|
||||||
ExchangeId::load(rx),
|
ExchangeId::load(rx),
|
||||||
Role::complementary(rx.proto.is_initiator()),
|
Role::complementary(rx.proto.is_initiator()),
|
||||||
|
@ -703,32 +822,9 @@ impl<'a> Matter<'a> {
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
// Message Reliability Protocol
|
// Message Reliability Protocol
|
||||||
exch.mrp.recv(rx, self.epoch)?;
|
exchanges[exchange_index].mrp.recv(rx, self.epoch)?;
|
||||||
|
|
||||||
Ok((exch, new))
|
Ok((exchange_index, new))
|
||||||
}
|
|
||||||
|
|
||||||
fn pre_send(&self, ctx: &mut ExchangeCtx, tx: &mut Packet) -> Result<(), Error> {
|
|
||||||
let mut session_mgr = self.session_mgr.borrow_mut();
|
|
||||||
let sess_index = session_mgr
|
|
||||||
.get(
|
|
||||||
ctx.id.session_id.id,
|
|
||||||
ctx.id.session_id.peer_addr,
|
|
||||||
ctx.id.session_id.peer_nodeid,
|
|
||||||
ctx.id.session_id.is_encrypted,
|
|
||||||
)
|
|
||||||
.ok_or(ErrorCode::NoSession)?;
|
|
||||||
|
|
||||||
let session = session_mgr.mut_by_index(sess_index).unwrap();
|
|
||||||
|
|
||||||
tx.proto.exch_id = ctx.id.id;
|
|
||||||
if ctx.role == Role::Initiator {
|
|
||||||
tx.proto.set_initiator();
|
|
||||||
}
|
|
||||||
|
|
||||||
session.pre_send(tx)?;
|
|
||||||
ctx.mrp.pre_send(tx)?;
|
|
||||||
session_mgr.send(sess_index, tx)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn register(
|
fn register(
|
||||||
|
@ -736,7 +832,7 @@ impl<'a> Matter<'a> {
|
||||||
id: ExchangeId,
|
id: ExchangeId,
|
||||||
role: Role,
|
role: Role,
|
||||||
create_new: bool,
|
create_new: bool,
|
||||||
) -> Result<(&mut ExchangeCtx, bool), Error> {
|
) -> Result<(usize, bool), Error> {
|
||||||
let exchange_index = exchanges
|
let exchange_index = exchanges
|
||||||
.iter_mut()
|
.iter_mut()
|
||||||
.enumerate()
|
.enumerate()
|
||||||
|
@ -745,7 +841,7 @@ impl<'a> Matter<'a> {
|
||||||
if let Some(exchange_index) = exchange_index {
|
if let Some(exchange_index) = exchange_index {
|
||||||
let exchange = &mut exchanges[exchange_index];
|
let exchange = &mut exchanges[exchange_index];
|
||||||
if exchange.role == role {
|
if exchange.role == role {
|
||||||
Ok((exchange, false))
|
Ok((exchange_index, false))
|
||||||
} else {
|
} else {
|
||||||
Err(ErrorCode::NoExchange.into())
|
Err(ErrorCode::NoExchange.into())
|
||||||
}
|
}
|
||||||
|
@ -759,9 +855,11 @@ impl<'a> Matter<'a> {
|
||||||
state: ExchangeState::Active,
|
state: ExchangeState::Active,
|
||||||
};
|
};
|
||||||
|
|
||||||
exchanges.push(exchange).map_err(|_| ErrorCode::NoSpace)?;
|
exchanges
|
||||||
|
.push(exchange)
|
||||||
|
.map_err(|_| ErrorCode::NoSpaceExchanges)?;
|
||||||
|
|
||||||
Ok((exchanges.iter_mut().next_back().unwrap(), true))
|
Ok((exchanges.len() - 1, true))
|
||||||
} else {
|
} else {
|
||||||
Err(ErrorCode::NoExchange.into())
|
Err(ErrorCode::NoExchange.into())
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
use crate::{
|
use crate::{
|
||||||
acl::Accessor,
|
acl::Accessor,
|
||||||
error::{Error, ErrorCode},
|
error::{Error, ErrorCode},
|
||||||
utils::select::Notification,
|
utils::{epoch::Epoch, select::Notification},
|
||||||
Matter,
|
Matter,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -9,7 +9,7 @@ use super::{
|
||||||
mrp::ReliableMessage,
|
mrp::ReliableMessage,
|
||||||
network::Address,
|
network::Address,
|
||||||
packet::Packet,
|
packet::Packet,
|
||||||
session::{Session, SessionMgr},
|
session::{CloneData, Session, SessionMgr},
|
||||||
};
|
};
|
||||||
|
|
||||||
pub const MAX_EXCHANGES: usize = 8;
|
pub const MAX_EXCHANGES: usize = 8;
|
||||||
|
@ -46,6 +46,101 @@ impl ExchangeCtx {
|
||||||
) -> Option<&'r mut ExchangeCtx> {
|
) -> Option<&'r mut ExchangeCtx> {
|
||||||
exchanges.iter_mut().find(|exchange| exchange.id == *id)
|
exchanges.iter_mut().find(|exchange| exchange.id == *id)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn new_ephemeral(session_id: SessionId, reply_to: Option<&Packet<'_>>) -> Self {
|
||||||
|
Self {
|
||||||
|
id: ExchangeId {
|
||||||
|
id: if let Some(rx) = reply_to {
|
||||||
|
rx.proto.exch_id
|
||||||
|
} else {
|
||||||
|
0
|
||||||
|
},
|
||||||
|
session_id: session_id.clone(),
|
||||||
|
},
|
||||||
|
role: if reply_to.is_some() {
|
||||||
|
Role::Responder
|
||||||
|
} else {
|
||||||
|
Role::Initiator
|
||||||
|
},
|
||||||
|
mrp: ReliableMessage::new(),
|
||||||
|
state: ExchangeState::Active,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn prep_ephemeral(
|
||||||
|
session_id: SessionId,
|
||||||
|
session_mgr: &mut SessionMgr,
|
||||||
|
reply_to: Option<&Packet<'_>>,
|
||||||
|
tx: &mut Packet<'_>,
|
||||||
|
) -> Result<ExchangeCtx, Error> {
|
||||||
|
let mut ctx = Self::new_ephemeral(session_id.clone(), reply_to);
|
||||||
|
|
||||||
|
let sess_index = session_mgr.get(
|
||||||
|
session_id.id,
|
||||||
|
session_id.peer_addr,
|
||||||
|
session_id.peer_nodeid,
|
||||||
|
session_id.is_encrypted,
|
||||||
|
);
|
||||||
|
|
||||||
|
let epoch = session_mgr.epoch;
|
||||||
|
let rand = session_mgr.rand;
|
||||||
|
|
||||||
|
if let Some(rx) = reply_to {
|
||||||
|
ctx.mrp.recv(rx, epoch)?;
|
||||||
|
} else {
|
||||||
|
tx.proto.set_initiator();
|
||||||
|
}
|
||||||
|
|
||||||
|
tx.unset_reliable();
|
||||||
|
|
||||||
|
if let Some(sess_index) = sess_index {
|
||||||
|
let session = session_mgr.mut_by_index(sess_index).unwrap();
|
||||||
|
ctx.pre_send_sess(session, tx, epoch)?;
|
||||||
|
} else {
|
||||||
|
let mut session =
|
||||||
|
Session::new(session_id.peer_addr, session_id.peer_nodeid, epoch, rand);
|
||||||
|
ctx.pre_send_sess(&mut session, tx, epoch)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn pre_send(
|
||||||
|
&mut self,
|
||||||
|
session_mgr: &mut SessionMgr,
|
||||||
|
tx: &mut Packet,
|
||||||
|
) -> Result<(), Error> {
|
||||||
|
let epoch = session_mgr.epoch;
|
||||||
|
|
||||||
|
let sess_index = session_mgr
|
||||||
|
.get(
|
||||||
|
self.id.session_id.id,
|
||||||
|
self.id.session_id.peer_addr,
|
||||||
|
self.id.session_id.peer_nodeid,
|
||||||
|
self.id.session_id.is_encrypted,
|
||||||
|
)
|
||||||
|
.ok_or(ErrorCode::NoSession)?;
|
||||||
|
|
||||||
|
let session = session_mgr.mut_by_index(sess_index).unwrap();
|
||||||
|
|
||||||
|
self.pre_send_sess(session, tx, epoch)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn pre_send_sess(
|
||||||
|
&mut self,
|
||||||
|
session: &mut Session,
|
||||||
|
tx: &mut Packet,
|
||||||
|
epoch: Epoch,
|
||||||
|
) -> Result<(), Error> {
|
||||||
|
tx.proto.exch_id = self.id.id;
|
||||||
|
if self.role == Role::Initiator {
|
||||||
|
tx.proto.set_initiator();
|
||||||
|
}
|
||||||
|
|
||||||
|
session.pre_send(tx)?;
|
||||||
|
self.mrp.pre_send(tx)?;
|
||||||
|
session.send(epoch, tx)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
|
@ -192,15 +287,6 @@ impl<'a> Exchange<'a> {
|
||||||
self.with_session_mut(|sess| f(sess))
|
self.with_session_mut(|sess| f(sess))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn with_session_mgr_mut<F, T>(&self, f: F) -> Result<T, Error>
|
|
||||||
where
|
|
||||||
F: FnOnce(&mut SessionMgr) -> Result<T, Error>,
|
|
||||||
{
|
|
||||||
let mut session_mgr = self.matter.session_mgr.borrow_mut();
|
|
||||||
|
|
||||||
f(&mut session_mgr)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn acknowledge(&mut self) -> Result<(), Error> {
|
pub async fn acknowledge(&mut self) -> Result<(), Error> {
|
||||||
let wait = self.with_ctx_mut(|_self, ctx| {
|
let wait = self.with_ctx_mut(|_self, ctx| {
|
||||||
if !matches!(ctx.state, ExchangeState::Active) {
|
if !matches!(ctx.state, ExchangeState::Active) {
|
||||||
|
@ -226,8 +312,12 @@ impl<'a> Exchange<'a> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn exchange(&mut self, tx: &Packet<'_>, rx: &mut Packet<'_>) -> Result<(), Error> {
|
pub async fn exchange(
|
||||||
let tx: &Packet<'static> = unsafe { core::mem::transmute(tx) };
|
&mut self,
|
||||||
|
tx: &mut Packet<'_>,
|
||||||
|
rx: &mut Packet<'_>,
|
||||||
|
) -> Result<(), Error> {
|
||||||
|
let tx: &mut Packet<'static> = unsafe { core::mem::transmute(tx) };
|
||||||
let rx: &mut Packet<'static> = unsafe { core::mem::transmute(rx) };
|
let rx: &mut Packet<'static> = unsafe { core::mem::transmute(rx) };
|
||||||
|
|
||||||
self.with_ctx_mut(|_self, ctx| {
|
self.with_ctx_mut(|_self, ctx| {
|
||||||
|
@ -235,6 +325,9 @@ impl<'a> Exchange<'a> {
|
||||||
Err(ErrorCode::NoExchange)?;
|
Err(ErrorCode::NoExchange)?;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let mut session_mgr = _self.matter.session_mgr.borrow_mut();
|
||||||
|
ctx.pre_send(&mut session_mgr, tx)?;
|
||||||
|
|
||||||
ctx.state = ExchangeState::ExchangeSend {
|
ctx.state = ExchangeState::ExchangeSend {
|
||||||
tx: tx as *const _,
|
tx: tx as *const _,
|
||||||
rx: rx as *mut _,
|
rx: rx as *mut _,
|
||||||
|
@ -250,18 +343,21 @@ impl<'a> Exchange<'a> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn complete(mut self, tx: &Packet<'_>) -> Result<(), Error> {
|
pub async fn complete(mut self, tx: &mut Packet<'_>) -> Result<(), Error> {
|
||||||
self.send_complete(tx).await
|
self.send_complete(tx).await
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn send_complete(&mut self, tx: &Packet<'_>) -> Result<(), Error> {
|
pub async fn send_complete(&mut self, tx: &mut Packet<'_>) -> Result<(), Error> {
|
||||||
let tx: &Packet<'static> = unsafe { core::mem::transmute(tx) };
|
let tx: &mut Packet<'static> = unsafe { core::mem::transmute(tx) };
|
||||||
|
|
||||||
self.with_ctx_mut(|_self, ctx| {
|
self.with_ctx_mut(|_self, ctx| {
|
||||||
if !matches!(ctx.state, ExchangeState::Active) {
|
if !matches!(ctx.state, ExchangeState::Active) {
|
||||||
Err(ErrorCode::NoExchange)?;
|
Err(ErrorCode::NoExchange)?;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let mut session_mgr = _self.matter.session_mgr.borrow_mut();
|
||||||
|
ctx.pre_send(&mut session_mgr, tx)?;
|
||||||
|
|
||||||
ctx.state = ExchangeState::Complete {
|
ctx.state = ExchangeState::Complete {
|
||||||
tx: tx as *const _,
|
tx: tx as *const _,
|
||||||
notification: &_self.notification as *const _,
|
notification: &_self.notification as *const _,
|
||||||
|
@ -276,6 +372,31 @@ impl<'a> Exchange<'a> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn get_next_sess_id(&mut self) -> u16 {
|
||||||
|
self.matter.session_mgr.borrow_mut().get_next_sess_id()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) async fn clone_session(
|
||||||
|
&mut self,
|
||||||
|
tx: &mut Packet<'_>,
|
||||||
|
clone_data: &CloneData,
|
||||||
|
) -> Result<usize, Error> {
|
||||||
|
loop {
|
||||||
|
let result = self
|
||||||
|
.matter
|
||||||
|
.session_mgr
|
||||||
|
.borrow_mut()
|
||||||
|
.clone_session(clone_data);
|
||||||
|
|
||||||
|
match result {
|
||||||
|
Err(err) if err.code() == ErrorCode::NoSpaceSessions => {
|
||||||
|
self.matter.evict_session(tx).await?
|
||||||
|
}
|
||||||
|
other => break other,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn with_ctx<F, T>(&self, f: F) -> Result<T, Error>
|
fn with_ctx<F, T>(&self, f: F) -> Result<T, Error>
|
||||||
where
|
where
|
||||||
F: FnOnce(&Self, &ExchangeCtx) -> Result<T, Error>,
|
F: FnOnce(&Self, &ExchangeCtx) -> Result<T, Error>,
|
||||||
|
|
|
@ -19,13 +19,13 @@ use crate::data_model::sdm::noc::NocData;
|
||||||
use crate::utils::epoch::Epoch;
|
use crate::utils::epoch::Epoch;
|
||||||
use crate::utils::rand::Rand;
|
use crate::utils::rand::Rand;
|
||||||
use core::fmt;
|
use core::fmt;
|
||||||
use core::ops::{Deref, DerefMut};
|
|
||||||
use core::time::Duration;
|
use core::time::Duration;
|
||||||
|
|
||||||
use crate::{error::*, transport::plain_hdr};
|
use crate::{error::*, transport::plain_hdr};
|
||||||
use log::info;
|
use log::info;
|
||||||
|
|
||||||
use super::dedup::RxCtrState;
|
use super::dedup::RxCtrState;
|
||||||
|
use super::exchange::SessionId;
|
||||||
use super::{network::Address, packet::Packet};
|
use super::{network::Address, packet::Packet};
|
||||||
|
|
||||||
pub const MAX_CAT_IDS_PER_NOC: usize = 3;
|
pub const MAX_CAT_IDS_PER_NOC: usize = 3;
|
||||||
|
@ -151,6 +151,15 @@ impl Session {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn id(&self) -> SessionId {
|
||||||
|
SessionId {
|
||||||
|
id: self.local_sess_id,
|
||||||
|
peer_addr: self.peer_addr,
|
||||||
|
peer_nodeid: self.peer_nodeid,
|
||||||
|
is_encrypted: self.is_encrypted(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn set_noc_data(&mut self, data: NocData) {
|
pub fn set_noc_data(&mut self, data: NocData) {
|
||||||
self.data = Some(data);
|
self.data = Some(data);
|
||||||
}
|
}
|
||||||
|
@ -251,7 +260,7 @@ impl Session {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn send(&mut self, epoch: Epoch, tx: &mut Packet) -> Result<(), Error> {
|
pub(crate) fn send(&mut self, epoch: Epoch, tx: &mut Packet) -> Result<(), Error> {
|
||||||
self.last_use = epoch();
|
self.last_use = epoch();
|
||||||
|
|
||||||
tx.proto_encode(
|
tx.proto_encode(
|
||||||
|
@ -291,8 +300,8 @@ pub const MAX_SESSIONS: usize = 16;
|
||||||
pub struct SessionMgr {
|
pub struct SessionMgr {
|
||||||
next_sess_id: u16,
|
next_sess_id: u16,
|
||||||
sessions: heapless::Vec<Option<Session>, MAX_SESSIONS>,
|
sessions: heapless::Vec<Option<Session>, MAX_SESSIONS>,
|
||||||
epoch: Epoch,
|
pub(crate) epoch: Epoch,
|
||||||
rand: Rand,
|
pub(crate) rand: Rand,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl SessionMgr {
|
impl SessionMgr {
|
||||||
|
@ -327,7 +336,11 @@ impl SessionMgr {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure the currently selected id doesn't match any existing session
|
// Ensure the currently selected id doesn't match any existing session
|
||||||
if self.get_with_id(next_sess_id).is_none() {
|
if self.sessions.iter().all(|sess| {
|
||||||
|
sess.as_ref()
|
||||||
|
.map(|sess| sess.get_local_sess_id() != next_sess_id)
|
||||||
|
.unwrap_or(true)
|
||||||
|
}) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -381,12 +394,12 @@ impl SessionMgr {
|
||||||
} else if self.sessions.len() < MAX_SESSIONS {
|
} else if self.sessions.len() < MAX_SESSIONS {
|
||||||
self.sessions
|
self.sessions
|
||||||
.push(Some(session))
|
.push(Some(session))
|
||||||
.map_err(|_| ErrorCode::NoSpace)
|
.map_err(|_| ErrorCode::NoSpaceSessions)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
Ok(self.sessions.len() - 1)
|
Ok(self.sessions.len() - 1)
|
||||||
} else {
|
} else {
|
||||||
Err(ErrorCode::NoSpace.into())
|
Err(ErrorCode::NoSpaceSessions.into())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -419,14 +432,6 @@ impl SessionMgr {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_with_id(&mut self, sess_id: u16) -> Option<SessionHandle> {
|
|
||||||
let index = self
|
|
||||||
.sessions
|
|
||||||
.iter_mut()
|
|
||||||
.position(|x| x.as_ref().map(|s| s.local_sess_id) == Some(sess_id))?;
|
|
||||||
Some(self.get_session_handle(index))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn get_or_add(
|
pub fn get_or_add(
|
||||||
&mut self,
|
&mut self,
|
||||||
sess_id: u16,
|
sess_id: u16,
|
||||||
|
@ -472,13 +477,6 @@ impl SessionMgr {
|
||||||
.ok_or(ErrorCode::NoSession)?
|
.ok_or(ErrorCode::NoSession)?
|
||||||
.send(self.epoch, tx)
|
.send(self.epoch, tx)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_session_handle(&mut self, sess_idx: usize) -> SessionHandle {
|
|
||||||
SessionHandle {
|
|
||||||
sess_mgr: self,
|
|
||||||
sess_idx,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl fmt::Display for SessionMgr {
|
impl fmt::Display for SessionMgr {
|
||||||
|
@ -492,45 +490,6 @@ impl fmt::Display for SessionMgr {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct SessionHandle<'a> {
|
|
||||||
pub(crate) sess_mgr: &'a mut SessionMgr,
|
|
||||||
sess_idx: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'a> SessionHandle<'a> {
|
|
||||||
pub fn session(&self) -> &Session {
|
|
||||||
self.sess_mgr.sessions[self.sess_idx].as_ref().unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn session_mut(&mut self) -> &mut Session {
|
|
||||||
self.sess_mgr.sessions[self.sess_idx].as_mut().unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn reserve_new_sess_id(&mut self) -> u16 {
|
|
||||||
self.sess_mgr.get_next_sess_id()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn send(&mut self, tx: &mut Packet) -> Result<(), Error> {
|
|
||||||
self.sess_mgr.send(self.sess_idx, tx)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'a> Deref for SessionHandle<'a> {
|
|
||||||
type Target = Session;
|
|
||||||
|
|
||||||
fn deref(&self) -> &Self::Target {
|
|
||||||
// There is no other option but to panic if this is None
|
|
||||||
self.session()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'a> DerefMut for SessionHandle<'a> {
|
|
||||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
|
||||||
// There is no other option but to panic if this is None
|
|
||||||
self.session_mut()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
|
||||||
|
@ -545,12 +504,12 @@ mod tests {
|
||||||
fn test_next_sess_id_doesnt_reuse() {
|
fn test_next_sess_id_doesnt_reuse() {
|
||||||
let mut sm = SessionMgr::new(dummy_epoch, dummy_rand);
|
let mut sm = SessionMgr::new(dummy_epoch, dummy_rand);
|
||||||
let sess_idx = sm.add(Address::default(), None).unwrap();
|
let sess_idx = sm.add(Address::default(), None).unwrap();
|
||||||
let mut sess = sm.get_session_handle(sess_idx);
|
let sess = sm.mut_by_index(sess_idx).unwrap();
|
||||||
sess.set_local_sess_id(1);
|
sess.set_local_sess_id(1);
|
||||||
assert_eq!(sm.get_next_sess_id(), 2);
|
assert_eq!(sm.get_next_sess_id(), 2);
|
||||||
assert_eq!(sm.get_next_sess_id(), 3);
|
assert_eq!(sm.get_next_sess_id(), 3);
|
||||||
let sess_idx = sm.add(Address::default(), None).unwrap();
|
let sess_idx = sm.add(Address::default(), None).unwrap();
|
||||||
let mut sess = sm.get_session_handle(sess_idx);
|
let sess = sm.mut_by_index(sess_idx).unwrap();
|
||||||
sess.set_local_sess_id(4);
|
sess.set_local_sess_id(4);
|
||||||
assert_eq!(sm.get_next_sess_id(), 5);
|
assert_eq!(sm.get_next_sess_id(), 5);
|
||||||
}
|
}
|
||||||
|
@ -559,7 +518,7 @@ mod tests {
|
||||||
fn test_next_sess_id_overflows() {
|
fn test_next_sess_id_overflows() {
|
||||||
let mut sm = SessionMgr::new(dummy_epoch, dummy_rand);
|
let mut sm = SessionMgr::new(dummy_epoch, dummy_rand);
|
||||||
let sess_idx = sm.add(Address::default(), None).unwrap();
|
let sess_idx = sm.add(Address::default(), None).unwrap();
|
||||||
let mut sess = sm.get_session_handle(sess_idx);
|
let sess = sm.mut_by_index(sess_idx).unwrap();
|
||||||
sess.set_local_sess_id(1);
|
sess.set_local_sess_id(1);
|
||||||
assert_eq!(sm.get_next_sess_id(), 2);
|
assert_eq!(sm.get_next_sess_id(), 2);
|
||||||
sm.next_sess_id = 65534;
|
sm.next_sess_id = 65534;
|
||||||
|
|
Loading…
Add table
Reference in a new issue