Sequential Exchange API

This commit is contained in:
ivmarkov 2023-06-10 14:01:35 +00:00
parent b40a0afbd0
commit a2a5691ade
36 changed files with 3290 additions and 3079 deletions

View file

@ -23,20 +23,15 @@ use log::info;
use matter::core::{CommissioningData, Matter}; use matter::core::{CommissioningData, Matter};
use matter::data_model::cluster_basic_information::BasicInfoConfig; use matter::data_model::cluster_basic_information::BasicInfoConfig;
use matter::data_model::cluster_on_off; use matter::data_model::cluster_on_off;
use matter::data_model::core::DataModel;
use matter::data_model::device_types::DEV_TYPE_ON_OFF_LIGHT; use matter::data_model::device_types::DEV_TYPE_ON_OFF_LIGHT;
use matter::data_model::objects::*; use matter::data_model::objects::*;
use matter::data_model::root_endpoint; use matter::data_model::root_endpoint;
use matter::data_model::system_model::descriptor; use matter::data_model::system_model::descriptor;
use matter::error::Error; use matter::error::Error;
use matter::interaction_model::core::InteractionModel;
use matter::mdns::{DefaultMdns, DefaultMdnsRunner}; use matter::mdns::{DefaultMdns, DefaultMdnsRunner};
use matter::secure_channel::spake2p::VerifierData; use matter::secure_channel::spake2p::VerifierData;
use matter::transport::network::{Address, IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; use matter::transport::network::{Ipv4Addr, Ipv6Addr};
use matter::transport::{ use matter::transport::runner::{RxBuf, TransportRunner, TxBuf};
core::RecvAction, core::Transport, packet::MAX_RX_BUF_SIZE, packet::MAX_TX_BUF_SIZE,
udp::UdpListener,
};
use matter::utils::select::EitherUnwrap; use matter::utils::select::EitherUnwrap;
mod dev_att; mod dev_att;
@ -44,7 +39,7 @@ mod dev_att;
#[cfg(feature = "std")] #[cfg(feature = "std")]
fn main() -> Result<(), Error> { fn main() -> Result<(), Error> {
let thread = std::thread::Builder::new() let thread = std::thread::Builder::new()
.stack_size(120 * 1024) .stack_size(140 * 1024)
.spawn(run) .spawn(run)
.unwrap(); .unwrap();
@ -62,10 +57,10 @@ fn run() -> Result<(), Error> {
initialize_logger(); initialize_logger();
info!( info!(
"Matter memory: mDNS={}, Matter={}, Transport={}", "Matter memory: mDNS={}, Matter={}, TransportRunner={}",
core::mem::size_of::<DefaultMdns>(), core::mem::size_of::<DefaultMdns>(),
core::mem::size_of::<Matter>(), core::mem::size_of::<Matter>(),
core::mem::size_of::<Transport>(), core::mem::size_of::<TransportRunner>(),
); );
let dev_det = BasicInfoConfig { let dev_det = BasicInfoConfig {
@ -92,6 +87,8 @@ fn run() -> Result<(), Error> {
let mut mdns_runner = DefaultMdnsRunner::new(&mdns); let mut mdns_runner = DefaultMdnsRunner::new(&mdns);
info!("mDNS initialized: {:p}, {:p}", &mdns, &mdns_runner);
let dev_att = dev_att::HardCodedDevAtt::new(); let dev_att = dev_att::HardCodedDevAtt::new();
#[cfg(feature = "std")] #[cfg(feature = "std")]
@ -118,36 +115,25 @@ fn run() -> Result<(), Error> {
matter::MATTER_PORT, matter::MATTER_PORT,
); );
let psm_path = std::env::temp_dir().join("matter-iot"); info!("Matter initialized: {:p}", &matter);
info!("Persisting from/to {}", psm_path.display());
#[cfg(all(feature = "std", not(target_os = "espidf")))] let mut runner = TransportRunner::new(&matter);
let psm = matter::persist::FilePsm::new(psm_path)?;
let mut buf = [0; 4096]; info!("Transport Runner initialized: {:p}", &runner);
let buf = &mut buf;
#[cfg(all(feature = "std", not(target_os = "espidf")))] let mut tx_buf = TxBuf::uninit();
{ let mut rx_buf = RxBuf::uninit();
if let Some(data) = psm.load("acls", buf)? {
matter.load_acls(data)?;
}
if let Some(data) = psm.load("fabrics", buf)? { // #[cfg(all(feature = "std", not(target_os = "espidf")))]
matter.load_fabrics(data)?; // {
} // if let Some(data) = psm.load("acls", buf)? {
} // matter.load_acls(data)?;
// }
let mut transport = Transport::new(&matter); // if let Some(data) = psm.load("fabrics", buf)? {
// matter.load_fabrics(data)?;
transport.start( // }
CommissioningData { // }
// TODO: Hard-coded for now
verifier: VerifierData::new_with_pw(123456, *matter.borrow()),
discriminator: 250,
},
buf,
)?;
let node = Node { let node = Node {
id: 0, id: 0,
@ -161,69 +147,48 @@ fn run() -> Result<(), Error> {
], ],
}; };
let mut handler = handler(&matter); let handler = HandlerCompat(handler(&matter));
let mut im = InteractionModel(DataModel::new(matter.borrow(), &node, &mut handler)); let matter = &matter;
let node = &node;
let mut rx_buf = [0; MAX_RX_BUF_SIZE]; let handler = &handler;
let mut tx_buf = [0; MAX_TX_BUF_SIZE]; let runner = &mut runner;
let im = &mut im;
let mdns_runner = &mut mdns_runner;
let transport = &mut transport;
let rx_buf = &mut rx_buf;
let tx_buf = &mut tx_buf; let tx_buf = &mut tx_buf;
let rx_buf = &mut rx_buf;
let mut io_fut = pin!(async move { info!(
// NOTE (no_std): On no_std, the `UdpListener` implementation is a no-op so you might want to "About to run wth node {:p}, handler {:p}, transport runner {:p}, mdns_runner {:p}",
// replace it with your own UDP stack node, handler, runner, &mdns_runner
let udp = UdpListener::new(SocketAddr::new( );
IpAddr::V6(Ipv6Addr::UNSPECIFIED),
matter::MATTER_PORT,
))
.await?;
loop { let mut fut = pin!(async move {
let (len, addr) = udp.recv(rx_buf).await?; // NOTE (no_std): On no_std, the `run_udp` is a no-op so you might want to replace it with `run` and
// connect the pipes of the `run` method with your own UDP stack
let mut completion = transport.recv(Address::Udp(addr), &mut rx_buf[..len], tx_buf); let mut transport = pin!(runner.run_udp(
tx_buf,
while let Some(action) = completion.next_action()? { rx_buf,
match action { CommissioningData {
RecvAction::Send(addr, buf) => { // TODO: Hard-coded for now
udp.send(addr.unwrap_udp(), buf).await?; verifier: VerifierData::new_with_pw(123456, *matter.borrow()),
} discriminator: 250,
RecvAction::Interact(mut ctx) => { },
if im.handle(&mut ctx)? && ctx.send()? { &handler,
udp.send(ctx.tx.peer.unwrap_udp(), ctx.tx.as_slice()) ));
.await?;
}
}
}
}
#[cfg(all(feature = "std", not(target_os = "espidf")))]
{
if let Some(data) = transport.matter().store_fabrics(buf)? {
psm.store("fabrics", data)?;
}
if let Some(data) = transport.matter().store_acls(buf)? {
psm.store("acls", data)?;
}
}
}
#[allow(unreachable_code)]
Ok::<_, matter::error::Error>(())
});
// NOTE (no_std): On no_std, the `run_udp` is a no-op so you might want to replace it with `run` and // NOTE (no_std): On no_std, the `run_udp` is a no-op so you might want to replace it with `run` and
// connect the pipes of the `run` method with your own UDP stack // connect the pipes of the `run` method with your own UDP stack
let mut mdns_fut = pin!(async move { mdns_runner.run_udp().await }); let mut mdns = pin!(mdns_runner.run_udp());
let mut fut = pin!(async move { select(&mut io_fut, &mut mdns_fut).await.unwrap() }); select(
&mut transport,
&mut mdns,
//save(transport, &psm),
)
.await
.unwrap()
});
// NOTE: For no_std, replace with your own no_std way of polling the future
#[cfg(feature = "std")] #[cfg(feature = "std")]
smol::block_on(&mut fut)?; smol::block_on(&mut fut)?;
@ -235,7 +200,21 @@ fn run() -> Result<(), Error> {
Ok(()) Ok(())
} }
fn handler<'a>(matter: &'a Matter<'a>) -> impl Handler + 'a { const NODE: Node<'static> = Node {
id: 0,
endpoints: &[
root_endpoint::endpoint(0),
Endpoint {
id: 1,
device_type: DEV_TYPE_ON_OFF_LIGHT,
clusters: &[descriptor::CLUSTER, cluster_on_off::CLUSTER],
},
],
};
fn handler<'a>(matter: &'a Matter<'a>) -> impl Metadata + NonBlockingHandler + 'a {
(
NODE,
root_endpoint::handler(0, matter) root_endpoint::handler(0, matter)
.chain( .chain(
1, 1,
@ -246,6 +225,7 @@ fn handler<'a>(matter: &'a Matter<'a>) -> impl Handler + 'a {
1, 1,
cluster_on_off::ID, cluster_on_off::ID,
cluster_on_off::OnOffCluster::new(*matter.borrow()), cluster_on_off::OnOffCluster::new(*matter.borrow()),
),
) )
} }

View file

@ -15,12 +15,12 @@
* limitations under the License. * limitations under the License.
*/ */
use core::convert::TryInto; use core::{cell::Cell, convert::TryInto};
use super::objects::*; use super::objects::*;
use crate::{ use crate::{
attribute_enum, cmd_enter, command_enum, error::Error, interaction_model::core::Transaction, attribute_enum, cmd_enter, command_enum, error::Error, tlv::TLVElement,
tlv::TLVElement, utils::rand::Rand, transport::exchange::Exchange, utils::rand::Rand,
}; };
use log::info; use log::info;
use strum::{EnumDiscriminants, FromRepr}; use strum::{EnumDiscriminants, FromRepr};
@ -66,20 +66,20 @@ pub const CLUSTER: Cluster<'static> = Cluster {
pub struct OnOffCluster { pub struct OnOffCluster {
data_ver: Dataver, data_ver: Dataver,
on: bool, on: Cell<bool>,
} }
impl OnOffCluster { impl OnOffCluster {
pub fn new(rand: Rand) -> Self { pub fn new(rand: Rand) -> Self {
Self { Self {
data_ver: Dataver::new(rand), data_ver: Dataver::new(rand),
on: false, on: Cell::new(false),
} }
} }
pub fn set(&mut self, on: bool) { pub fn set(&self, on: bool) {
if self.on != on { if self.on.get() != on {
self.on = on; self.on.set(on);
self.data_ver.changed(); self.data_ver.changed();
} }
} }
@ -90,7 +90,7 @@ impl OnOffCluster {
CLUSTER.read(attr.attr_id, writer) CLUSTER.read(attr.attr_id, writer)
} else { } else {
match attr.attr_id.try_into()? { match attr.attr_id.try_into()? {
Attributes::OnOff(codec) => codec.encode(writer, self.on), Attributes::OnOff(codec) => codec.encode(writer, self.on.get()),
} }
} }
} else { } else {
@ -98,7 +98,7 @@ impl OnOffCluster {
} }
} }
pub fn write(&mut self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { pub fn write(&self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> {
let data = data.with_dataver(self.data_ver.get())?; let data = data.with_dataver(self.data_ver.get())?;
match attr.attr_id.try_into()? { match attr.attr_id.try_into()? {
@ -111,8 +111,8 @@ impl OnOffCluster {
} }
pub fn invoke( pub fn invoke(
&mut self, &self,
transaction: &mut Transaction, _exchange: &Exchange,
cmd: &CmdDetails, cmd: &CmdDetails,
_data: &TLVElement, _data: &TLVElement,
_encoder: CmdDataEncoder, _encoder: CmdDataEncoder,
@ -128,12 +128,10 @@ impl OnOffCluster {
} }
Commands::Toggle => { Commands::Toggle => {
cmd_enter!("Toggle"); cmd_enter!("Toggle");
self.set(!self.on); self.set(!self.on.get());
} }
} }
transaction.complete();
self.data_ver.changed(); self.data_ver.changed();
Ok(()) Ok(())
@ -145,18 +143,18 @@ impl Handler for OnOffCluster {
OnOffCluster::read(self, attr, encoder) OnOffCluster::read(self, attr, encoder)
} }
fn write(&mut self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { fn write(&self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> {
OnOffCluster::write(self, attr, data) OnOffCluster::write(self, attr, data)
} }
fn invoke( fn invoke(
&mut self, &self,
transaction: &mut Transaction, exchange: &Exchange,
cmd: &CmdDetails, cmd: &CmdDetails,
data: &TLVElement, data: &TLVElement,
encoder: CmdDataEncoder, encoder: CmdDataEncoder,
) -> Result<(), Error> { ) -> Result<(), Error> {
OnOffCluster::invoke(self, transaction, cmd, data, encoder) OnOffCluster::invoke(self, exchange, cmd, data, encoder)
} }
} }

View file

@ -15,287 +15,127 @@
* limitations under the License. * limitations under the License.
*/ */
use core::cell::RefCell; use core::sync::atomic::{AtomicU32, Ordering};
use super::objects::*; use super::objects::*;
use crate::{ use crate::{
acl::{Accessor, AclMgr}, alloc,
error::*, error::*,
interaction_model::core::{Interaction, Transaction}, interaction_model::core::Interaction,
tlv::TLVWriter, transport::{exchange::Exchange, packet::Packet},
transport::packet::Packet,
}; };
pub struct DataModel<'a, T> { // TODO: For now...
pub acl_mgr: &'a RefCell<AclMgr>, static SUBS_ID: AtomicU32 = AtomicU32::new(1);
pub node: &'a Node<'a>,
pub handler: T,
}
impl<'a, T> DataModel<'a, T> { pub struct DataModel<T>(T);
pub const fn new(acl_mgr: &'a RefCell<AclMgr>, node: &'a Node<'a>, handler: T) -> Self {
Self { impl<T> DataModel<T> {
acl_mgr, pub fn new(handler: T) -> Self {
node, Self(handler)
handler,
}
} }
pub fn handle( pub async fn handle<'r, 'p>(
&mut self, &self,
interaction: Interaction, exchange: &'r mut Exchange<'_>,
tx: &mut Packet, rx: &'r mut Packet<'p>,
transaction: &mut Transaction, tx: &'r mut Packet<'p>,
) -> Result<bool, Error> rx_status: &'r mut Packet<'p>,
) -> Result<(), Error>
where where
T: Handler, T: DataModelHandler,
{ {
let accessor = Accessor::for_session(transaction.session(), self.acl_mgr); let timeout = Interaction::timeout(exchange, rx, tx).await?;
let mut tw = TLVWriter::new(tx.get_writebuf()?);
match interaction { let mut interaction = alloc!(Interaction::new(
Interaction::Read(req) => { exchange,
let mut resume_path = None; rx,
tx,
rx_status,
|| SUBS_ID.fetch_add(1, Ordering::SeqCst),
timeout,
)?);
for item in self.node.read(&req, &accessor) { #[cfg(feature = "alloc")]
if let Some(path) = AttrDataEncoder::handle_read(item, &self.handler, &mut tw)? let interaction = &mut *interaction;
{
resume_path = Some(path);
break;
}
}
req.complete(tx, transaction, resume_path) #[cfg(not(feature = "alloc"))]
} let interaction = &mut interaction;
Interaction::Write(req) => {
for item in self.node.write(&req, &accessor) {
AttrDataEncoder::handle_write(item, &mut self.handler, &mut tw)?;
}
req.complete(tx, transaction)
}
Interaction::Invoke(req) => {
for item in self.node.invoke(&req, &accessor) {
CmdDataEncoder::handle(item, &mut self.handler, transaction, &mut tw)?;
}
req.complete(tx, transaction)
}
Interaction::Subscribe(req) => {
let mut resume_path = None;
for item in self.node.subscribing_read(&req, &accessor) {
if let Some(path) = AttrDataEncoder::handle_read(item, &self.handler, &mut tw)?
{
resume_path = Some(path);
break;
}
}
req.complete(tx, transaction, resume_path)
}
Interaction::Timed(_) => Ok(false),
Interaction::ResumeRead(req) => {
let mut resume_path = None;
for item in self.node.resume_read(&req, &accessor) {
if let Some(path) = AttrDataEncoder::handle_read(item, &self.handler, &mut tw)?
{
resume_path = Some(path);
break;
}
}
req.complete(tx, transaction, resume_path)
}
Interaction::ResumeSubscribe(req) => {
let mut resume_path = None;
for item in self.node.resume_subscribing_read(&req, &accessor) {
if let Some(path) = AttrDataEncoder::handle_read(item, &self.handler, &mut tw)?
{
resume_path = Some(path);
break;
}
}
req.complete(tx, transaction, resume_path)
}
}
}
#[cfg(feature = "nightly")] #[cfg(feature = "nightly")]
pub async fn handle_async<'p>( let metadata = self.0.lock().await;
&mut self,
interaction: Interaction<'_>,
tx: &'p mut Packet<'_>,
transaction: &mut Transaction<'_, '_>,
) -> Result<bool, Error>
where
T: super::objects::asynch::AsyncHandler,
{
let accessor = Accessor::for_session(transaction.session(), self.acl_mgr);
let mut tw = TLVWriter::new(tx.get_writebuf()?);
#[cfg(not(feature = "nightly"))]
let metadata = self.0.lock();
if interaction.start().await? {
match interaction { match interaction {
Interaction::Read(req) => { Interaction::Read {
let mut resume_path = None; req,
ref mut driver,
} => {
let accessor = driver.accessor()?;
for item in self.node.read(&req, &accessor) { 'outer: for item in metadata.node().read(req, None, &accessor) {
if let Some(path) = while !AttrDataEncoder::handle_read(&item, &self.0, &mut driver.writer()?)
AttrDataEncoder::handle_read_async(item, &self.handler, &mut tw).await? .await?
{ {
resume_path = Some(path); if !driver.send_chunk(req).await? {
break; break 'outer;
}
} }
} }
req.complete(tx, transaction, resume_path) driver.complete(req).await?;
}
Interaction::Write(req) => {
for item in self.node.write(&req, &accessor) {
AttrDataEncoder::handle_write_async(item, &mut self.handler, &mut tw).await?;
} }
Interaction::Write {
req,
ref mut driver,
} => {
let accessor = driver.accessor()?;
req.complete(tx, transaction) for item in metadata.node().write(req, &accessor) {
} AttrDataEncoder::handle_write(&item, &self.0, &mut driver.writer()?)
Interaction::Invoke(req) => {
for item in self.node.invoke(&req, &accessor) {
CmdDataEncoder::handle_async(item, &mut self.handler, transaction, &mut tw)
.await?; .await?;
} }
req.complete(tx, transaction) driver.complete(req).await?;
} }
Interaction::Subscribe(req) => { Interaction::Invoke {
let mut resume_path = None; req,
ref mut driver,
} => {
let accessor = driver.accessor()?;
for item in self.node.subscribing_read(&req, &accessor) { for item in metadata.node().invoke(req, &accessor) {
if let Some(path) = let (mut tw, exchange) = driver.writer_exchange()?;
AttrDataEncoder::handle_read_async(item, &self.handler, &mut tw).await?
CmdDataEncoder::handle(&item, &self.0, &mut tw, exchange).await?;
}
driver.complete(req).await?;
}
Interaction::Subscribe {
req,
ref mut driver,
} => {
let accessor = driver.accessor()?;
'outer: for item in metadata.node().subscribing_read(req, None, &accessor) {
while !AttrDataEncoder::handle_read(&item, &self.0, &mut driver.writer()?)
.await?
{ {
resume_path = Some(path); if !driver.send_chunk(req).await? {
break; break 'outer;
}
} }
} }
req.complete(tx, transaction, resume_path) driver.complete(req).await?;
} }
Interaction::Timed(_) => Ok(false),
Interaction::ResumeRead(req) => {
let mut resume_path = None;
for item in self.node.resume_read(&req, &accessor) {
if let Some(path) =
AttrDataEncoder::handle_read_async(item, &self.handler, &mut tw).await?
{
resume_path = Some(path);
break;
} }
} }
req.complete(tx, transaction, resume_path) Ok(())
}
Interaction::ResumeSubscribe(req) => {
let mut resume_path = None;
for item in self.node.resume_subscribing_read(&req, &accessor) {
if let Some(path) =
AttrDataEncoder::handle_read_async(item, &self.handler, &mut tw).await?
{
resume_path = Some(path);
break;
}
}
req.complete(tx, transaction, resume_path)
}
}
}
}
pub trait DataHandler {
fn handle(
&mut self,
interaction: Interaction,
tx: &mut Packet,
transaction: &mut Transaction,
) -> Result<bool, Error>;
}
impl<T> DataHandler for &mut T
where
T: DataHandler,
{
fn handle(
&mut self,
interaction: Interaction,
tx: &mut Packet,
transaction: &mut Transaction,
) -> Result<bool, Error> {
(**self).handle(interaction, tx, transaction)
}
}
impl<'a, T> DataHandler for DataModel<'a, T>
where
T: Handler,
{
fn handle(
&mut self,
interaction: Interaction,
tx: &mut Packet,
transaction: &mut Transaction,
) -> Result<bool, Error> {
DataModel::handle(self, interaction, tx, transaction)
}
}
#[cfg(feature = "nightly")]
pub mod asynch {
use crate::{
data_model::objects::asynch::AsyncHandler,
error::Error,
interaction_model::core::{Interaction, Transaction},
transport::packet::Packet,
};
use super::DataModel;
pub trait AsyncDataHandler {
async fn handle(
&mut self,
interaction: Interaction<'_>,
tx: &mut Packet,
transaction: &mut Transaction,
) -> Result<bool, Error>;
}
impl<T> AsyncDataHandler for &mut T
where
T: AsyncDataHandler,
{
async fn handle(
&mut self,
interaction: Interaction<'_>,
tx: &mut Packet<'_>,
transaction: &mut Transaction<'_, '_>,
) -> Result<bool, Error> {
(**self).handle(interaction, tx, transaction).await
}
}
impl<'a, T> AsyncDataHandler for DataModel<'a, T>
where
T: AsyncHandler,
{
async fn handle(
&mut self,
interaction: Interaction<'_>,
tx: &mut Packet<'_>,
transaction: &mut Transaction<'_, '_>,
) -> Result<bool, Error> {
DataModel::handle_async(self, interaction, tx, transaction).await
}
} }
} }

View file

@ -15,11 +15,13 @@
* limitations under the License. * limitations under the License.
*/ */
use core::cell::Cell;
use crate::utils::rand::Rand; use crate::utils::rand::Rand;
pub struct Dataver { pub struct Dataver {
ver: u32, ver: Cell<u32>,
changed: bool, changed: Cell<bool>,
} }
impl Dataver { impl Dataver {
@ -28,25 +30,25 @@ impl Dataver {
rand(&mut buf); rand(&mut buf);
Self { Self {
ver: u32::from_be_bytes(buf), ver: Cell::new(u32::from_be_bytes(buf)),
changed: false, changed: Cell::new(false),
} }
} }
pub fn get(&self) -> u32 { pub fn get(&self) -> u32 {
self.ver self.ver.get()
} }
pub fn changed(&mut self) -> u32 { pub fn changed(&self) -> u32 {
(self.ver, _) = self.ver.overflowing_add(1); self.ver.set(self.ver.get().overflowing_add(1).0);
self.changed = true; self.changed.set(true);
self.get() self.get()
} }
pub fn consume_change<T>(&mut self, change: T) -> Option<T> { pub fn consume_change<T>(&self, change: T) -> Option<T> {
if self.changed { if self.changed.get() {
self.changed = false; self.changed.set(false);
Some(change) Some(change)
} else { } else {
None None

View file

@ -19,12 +19,12 @@ use core::fmt::{Debug, Formatter};
use core::marker::PhantomData; use core::marker::PhantomData;
use core::ops::{Deref, DerefMut}; use core::ops::{Deref, DerefMut};
use crate::interaction_model::core::{IMStatusCode, Transaction}; use crate::interaction_model::core::IMStatusCode;
use crate::interaction_model::messages::ib::{ use crate::interaction_model::messages::ib::{
AttrPath, AttrResp, AttrStatus, CmdDataTag, CmdPath, CmdStatus, InvResp, InvRespTag, AttrPath, AttrResp, AttrStatus, CmdDataTag, CmdPath, CmdStatus, InvResp, InvRespTag,
}; };
use crate::interaction_model::messages::GenericPath;
use crate::tlv::UtfStr; use crate::tlv::UtfStr;
use crate::transport::exchange::Exchange;
use crate::{ use crate::{
error::{Error, ErrorCode}, error::{Error, ErrorCode},
interaction_model::messages::ib::{AttrDataTag, AttrRespTag}, interaction_model::messages::ib::{AttrDataTag, AttrRespTag},
@ -32,7 +32,7 @@ use crate::{
}; };
use log::error; use log::error;
use super::{AttrDetails, CmdDetails, Handler}; use super::{AttrDetails, CmdDetails, DataModelHandler};
// TODO: Should this return an IMStatusCode Error? But if yes, the higher layer // TODO: Should this return an IMStatusCode Error? But if yes, the higher layer
// may have already started encoding the 'success' headers, we might not want to manage // may have already started encoding the 'success' headers, we might not want to manage
@ -124,47 +124,75 @@ pub struct AttrDataEncoder<'a, 'b, 'c> {
} }
impl<'a, 'b, 'c> AttrDataEncoder<'a, 'b, 'c> { impl<'a, 'b, 'c> AttrDataEncoder<'a, 'b, 'c> {
pub fn handle_read<T: Handler>( pub async fn handle_read<T: DataModelHandler>(
item: Result<AttrDetails, AttrStatus>, item: &Result<AttrDetails<'_>, AttrStatus>,
handler: &T, handler: &T,
tw: &mut TLVWriter, tw: &mut TLVWriter<'_, '_>,
) -> Result<Option<GenericPath>, Error> { ) -> Result<bool, Error> {
let status = match item { let status = match item {
Ok(attr) => { Ok(attr) => {
let encoder = AttrDataEncoder::new(&attr, tw); let encoder = AttrDataEncoder::new(attr, tw);
match handler.read(&attr, encoder) { let result = {
#[cfg(not(feature = "nightly"))]
{
handler.read(attr, encoder)
}
#[cfg(feature = "nightly")]
{
handler.read(&attr, encoder).await
}
};
match result {
Ok(()) => None, Ok(()) => None,
Err(e) => { Err(e) => {
if e.code() == ErrorCode::NoSpace { if e.code() == ErrorCode::NoSpace {
return Ok(Some(attr.path().to_gp())); return Ok(false);
} else { } else {
attr.status(e.into())? attr.status(e.into())?
} }
} }
} }
} }
Err(status) => Some(status), Err(status) => Some(status.clone()),
}; };
if let Some(status) = status { if let Some(status) = status {
AttrResp::Status(status).to_tlv(tw, TagType::Anonymous)?; AttrResp::Status(status).to_tlv(tw, TagType::Anonymous)?;
} }
Ok(None) Ok(true)
} }
pub fn handle_write<T: Handler>( pub async fn handle_write<T: DataModelHandler>(
item: Result<(AttrDetails, TLVElement), AttrStatus>, item: &Result<(AttrDetails<'_>, TLVElement<'_>), AttrStatus>,
handler: &mut T, handler: &T,
tw: &mut TLVWriter, tw: &mut TLVWriter<'_, '_>,
) -> Result<(), Error> { ) -> Result<(), Error> {
let status = match item { let status = match item {
Ok((attr, data)) => match handler.write(&attr, AttrData::new(attr.dataver, &data)) { Ok((attr, data)) => {
let result = {
#[cfg(not(feature = "nightly"))]
{
handler.write(attr, AttrData::new(attr.dataver, data))
}
#[cfg(feature = "nightly")]
{
handler
.write(&attr, AttrData::new(attr.dataver, &data))
.await
}
};
match result {
Ok(()) => attr.status(IMStatusCode::Success)?, Ok(()) => attr.status(IMStatusCode::Success)?,
Err(error) => attr.status(error.into())?, Err(error) => attr.status(error.into())?,
}, }
Err(status) => Some(status), }
Err(status) => Some(status.clone()),
}; };
if let Some(status) = status { if let Some(status) = status {
@ -174,61 +202,6 @@ impl<'a, 'b, 'c> AttrDataEncoder<'a, 'b, 'c> {
Ok(()) Ok(())
} }
#[cfg(feature = "nightly")]
pub async fn handle_read_async<T: super::asynch::AsyncHandler>(
item: Result<AttrDetails<'_>, AttrStatus>,
handler: &T,
tw: &mut TLVWriter<'_, '_>,
) -> Result<Option<GenericPath>, Error> {
let status = match item {
Ok(attr) => {
let encoder = AttrDataEncoder::new(&attr, tw);
match handler.read(&attr, encoder).await {
Ok(()) => None,
Err(e) => {
if e.code() == ErrorCode::NoSpace {
return Ok(Some(attr.path().to_gp()));
} else {
attr.status(e.into())?
}
}
}
}
Err(status) => Some(status),
};
if let Some(status) = status {
AttrResp::Status(status).to_tlv(tw, TagType::Anonymous)?;
}
Ok(None)
}
#[cfg(feature = "nightly")]
pub async fn handle_write_async<T: super::asynch::AsyncHandler>(
item: Result<(AttrDetails<'_>, TLVElement<'_>), AttrStatus>,
handler: &mut T,
tw: &mut TLVWriter<'_, '_>,
) -> Result<(), Error> {
let status = match item {
Ok((attr, data)) => match handler
.write(&attr, AttrData::new(attr.dataver, &data))
.await
{
Ok(()) => attr.status(IMStatusCode::Success)?,
Err(error) => attr.status(error.into())?,
},
Err(status) => Some(status),
};
if let Some(status) = status {
AttrResp::Status(status).to_tlv(tw, TagType::Anonymous)?;
}
Ok(())
}
pub fn new(attr: &AttrDetails, tw: &'a mut TLVWriter<'b, 'c>) -> Self { pub fn new(attr: &AttrDetails, tw: &'a mut TLVWriter<'b, 'c>) -> Self {
Self { Self {
dataver_filter: attr.dataver, dataver_filter: attr.dataver,
@ -365,18 +338,30 @@ pub struct CmdDataEncoder<'a, 'b, 'c> {
} }
impl<'a, 'b, 'c> CmdDataEncoder<'a, 'b, 'c> { impl<'a, 'b, 'c> CmdDataEncoder<'a, 'b, 'c> {
pub fn handle<T: Handler>( pub async fn handle<T: DataModelHandler>(
item: Result<(CmdDetails, TLVElement), CmdStatus>, item: &Result<(CmdDetails<'_>, TLVElement<'_>), CmdStatus>,
handler: &mut T, handler: &T,
transaction: &mut Transaction, tw: &mut TLVWriter<'_, '_>,
tw: &mut TLVWriter, exchange: &Exchange<'_>,
) -> Result<(), Error> { ) -> Result<(), Error> {
let status = match item { let status = match item {
Ok((cmd, data)) => { Ok((cmd, data)) => {
let mut tracker = CmdDataTracker::new(); let mut tracker = CmdDataTracker::new();
let encoder = CmdDataEncoder::new(&cmd, &mut tracker, tw); let encoder = CmdDataEncoder::new(cmd, &mut tracker, tw);
match handler.invoke(transaction, &cmd, &data, encoder) { let result = {
#[cfg(not(feature = "nightly"))]
{
handler.invoke(exchange, cmd, data, encoder)
}
#[cfg(feature = "nightly")]
{
handler.invoke(exchange, &cmd, &data, encoder).await
}
};
match result {
Ok(()) => cmd.success(&tracker), Ok(()) => cmd.success(&tracker),
Err(error) => { Err(error) => {
error!("Error invoking command: {}", error); error!("Error invoking command: {}", error);
@ -386,7 +371,7 @@ impl<'a, 'b, 'c> CmdDataEncoder<'a, 'b, 'c> {
} }
Err(status) => { Err(status) => {
error!("Error invoking command: {:?}", status); error!("Error invoking command: {:?}", status);
Some(status) Some(status.clone())
} }
}; };
@ -397,33 +382,6 @@ impl<'a, 'b, 'c> CmdDataEncoder<'a, 'b, 'c> {
Ok(()) Ok(())
} }
#[cfg(feature = "nightly")]
pub async fn handle_async<T: super::asynch::AsyncHandler>(
item: Result<(CmdDetails<'_>, TLVElement<'_>), CmdStatus>,
handler: &mut T,
transaction: &mut Transaction<'_, '_>,
tw: &mut TLVWriter<'_, '_>,
) -> Result<(), Error> {
let status = match item {
Ok((cmd, data)) => {
let mut tracker = CmdDataTracker::new();
let encoder = CmdDataEncoder::new(&cmd, &mut tracker, tw);
match handler.invoke(transaction, &cmd, &data, encoder).await {
Ok(()) => cmd.success(&tracker),
Err(error) => cmd.status(error.into()),
}
}
Err(status) => Some(status),
};
if let Some(status) = status {
InvResp::Status(status).to_tlv(tw, TagType::Anonymous)?;
}
Ok(())
}
pub fn new( pub fn new(
cmd: &CmdDetails, cmd: &CmdDetails,
tracker: &'a mut CmdDataTracker, tracker: &'a mut CmdDataTracker,

View file

@ -17,12 +17,25 @@
use crate::{ use crate::{
error::{Error, ErrorCode}, error::{Error, ErrorCode},
interaction_model::core::Transaction,
tlv::TLVElement, tlv::TLVElement,
transport::exchange::Exchange,
}; };
use super::{AttrData, AttrDataEncoder, AttrDetails, CmdDataEncoder, CmdDetails}; use super::{AttrData, AttrDataEncoder, AttrDetails, CmdDataEncoder, CmdDetails};
#[cfg(feature = "nightly")]
pub use asynch::*;
#[cfg(not(feature = "nightly"))]
pub trait DataModelHandler: super::Metadata + Handler {}
#[cfg(not(feature = "nightly"))]
impl<T> DataModelHandler for T where T: super::Metadata + Handler {}
#[cfg(feature = "nightly")]
pub trait DataModelHandler: super::asynch::AsyncMetadata + asynch::AsyncHandler {}
#[cfg(feature = "nightly")]
impl<T> DataModelHandler for T where T: super::asynch::AsyncMetadata + asynch::AsyncHandler {}
pub trait ChangeNotifier<T> { pub trait ChangeNotifier<T> {
fn consume_change(&mut self) -> Option<T>; fn consume_change(&mut self) -> Option<T>;
} }
@ -30,13 +43,13 @@ pub trait ChangeNotifier<T> {
pub trait Handler { pub trait Handler {
fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error>; fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error>;
fn write(&mut self, _attr: &AttrDetails, _data: AttrData) -> Result<(), Error> { fn write(&self, _attr: &AttrDetails, _data: AttrData) -> Result<(), Error> {
Err(ErrorCode::AttributeNotFound.into()) Err(ErrorCode::AttributeNotFound.into())
} }
fn invoke( fn invoke(
&mut self, &self,
_transaction: &mut Transaction, _exchange: &Exchange,
_cmd: &CmdDetails, _cmd: &CmdDetails,
_data: &TLVElement, _data: &TLVElement,
_encoder: CmdDataEncoder, _encoder: CmdDataEncoder,
@ -45,6 +58,29 @@ pub trait Handler {
} }
} }
impl<T> Handler for &T
where
T: Handler,
{
fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> {
(**self).read(attr, encoder)
}
fn write(&self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> {
(**self).write(attr, data)
}
fn invoke(
&self,
exchange: &Exchange,
cmd: &CmdDetails,
data: &TLVElement,
encoder: CmdDataEncoder,
) -> Result<(), Error> {
(**self).invoke(exchange, cmd, data, encoder)
}
}
impl<T> Handler for &mut T impl<T> Handler for &mut T
where where
T: Handler, T: Handler,
@ -53,25 +89,52 @@ where
(**self).read(attr, encoder) (**self).read(attr, encoder)
} }
fn write(&mut self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { fn write(&self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> {
(**self).write(attr, data) (**self).write(attr, data)
} }
fn invoke( fn invoke(
&mut self, &self,
transaction: &mut Transaction, exchange: &Exchange,
cmd: &CmdDetails, cmd: &CmdDetails,
data: &TLVElement, data: &TLVElement,
encoder: CmdDataEncoder, encoder: CmdDataEncoder,
) -> Result<(), Error> { ) -> Result<(), Error> {
(**self).invoke(transaction, cmd, data, encoder) (**self).invoke(exchange, cmd, data, encoder)
} }
} }
pub trait NonBlockingHandler: Handler {} pub trait NonBlockingHandler: Handler {}
impl<T> NonBlockingHandler for &T where T: NonBlockingHandler {}
impl<T> NonBlockingHandler for &mut T where T: NonBlockingHandler {} impl<T> NonBlockingHandler for &mut T where T: NonBlockingHandler {}
impl<M, H> Handler for (M, H)
where
H: Handler,
{
fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> {
self.1.read(attr, encoder)
}
fn write(&self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> {
self.1.write(attr, data)
}
fn invoke(
&self,
exchange: &Exchange,
cmd: &CmdDetails,
data: &TLVElement,
encoder: CmdDataEncoder,
) -> Result<(), Error> {
self.1.invoke(exchange, cmd, data, encoder)
}
}
impl<M, H> NonBlockingHandler for (M, H) where H: NonBlockingHandler {}
pub struct EmptyHandler; pub struct EmptyHandler;
impl EmptyHandler { impl EmptyHandler {
@ -140,7 +203,7 @@ where
} }
} }
fn write(&mut self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { fn write(&self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> {
if self.handler_endpoint == attr.endpoint_id && self.handler_cluster == attr.cluster_id { if self.handler_endpoint == attr.endpoint_id && self.handler_cluster == attr.cluster_id {
self.handler.write(attr, data) self.handler.write(attr, data)
} else { } else {
@ -149,16 +212,16 @@ where
} }
fn invoke( fn invoke(
&mut self, &self,
transaction: &mut Transaction, exchange: &Exchange,
cmd: &CmdDetails, cmd: &CmdDetails,
data: &TLVElement, data: &TLVElement,
encoder: CmdDataEncoder, encoder: CmdDataEncoder,
) -> Result<(), Error> { ) -> Result<(), Error> {
if self.handler_endpoint == cmd.endpoint_id && self.handler_cluster == cmd.cluster_id { if self.handler_endpoint == cmd.endpoint_id && self.handler_cluster == cmd.cluster_id {
self.handler.invoke(transaction, cmd, data, encoder) self.handler.invoke(exchange, cmd, data, encoder)
} else { } else {
self.next.invoke(transaction, cmd, data, encoder) self.next.invoke(exchange, cmd, data, encoder)
} }
} }
} }
@ -184,6 +247,35 @@ where
} }
} }
/// Wrap your `NonBlockingHandler` or `AsyncHandler` implementation in this struct
/// to get your code compilable with and without the `nightly` feature
pub struct HandlerCompat<T>(pub T);
impl<T> Handler for HandlerCompat<T>
where
T: Handler,
{
fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> {
self.0.read(attr, encoder)
}
fn write(&self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> {
self.0.write(attr, data)
}
fn invoke(
&self,
exchange: &Exchange,
cmd: &CmdDetails,
data: &TLVElement,
encoder: CmdDataEncoder,
) -> Result<(), Error> {
self.0.invoke(exchange, cmd, data, encoder)
}
}
impl<T> NonBlockingHandler for HandlerCompat<T> where T: NonBlockingHandler {}
#[allow(unused_macros)] #[allow(unused_macros)]
#[macro_export] #[macro_export]
macro_rules! handler_chain_type { macro_rules! handler_chain_type {
@ -203,15 +295,15 @@ macro_rules! handler_chain_type {
} }
#[cfg(feature = "nightly")] #[cfg(feature = "nightly")]
pub mod asynch { mod asynch {
use crate::{ use crate::{
data_model::objects::{AttrData, AttrDataEncoder, AttrDetails, CmdDataEncoder, CmdDetails}, data_model::objects::{AttrData, AttrDataEncoder, AttrDetails, CmdDataEncoder, CmdDetails},
error::{Error, ErrorCode}, error::{Error, ErrorCode},
interaction_model::core::Transaction,
tlv::TLVElement, tlv::TLVElement,
transport::exchange::Exchange,
}; };
use super::{ChainedHandler, EmptyHandler, Handler, NonBlockingHandler}; use super::{ChainedHandler, EmptyHandler, Handler, HandlerCompat, NonBlockingHandler};
pub trait AsyncHandler { pub trait AsyncHandler {
async fn read<'a>( async fn read<'a>(
@ -221,7 +313,7 @@ pub mod asynch {
) -> Result<(), Error>; ) -> Result<(), Error>;
async fn write<'a>( async fn write<'a>(
&'a mut self, &'a self,
_attr: &'a AttrDetails<'_>, _attr: &'a AttrDetails<'_>,
_data: AttrData<'a>, _data: AttrData<'a>,
) -> Result<(), Error> { ) -> Result<(), Error> {
@ -229,8 +321,8 @@ pub mod asynch {
} }
async fn invoke<'a>( async fn invoke<'a>(
&'a mut self, &'a self,
_transaction: &'a mut Transaction<'_, '_>, _exchange: &'a Exchange<'_>,
_cmd: &'a CmdDetails<'_>, _cmd: &'a CmdDetails<'_>,
_data: &'a TLVElement<'_>, _data: &'a TLVElement<'_>,
_encoder: CmdDataEncoder<'a, '_, '_>, _encoder: CmdDataEncoder<'a, '_, '_>,
@ -252,7 +344,7 @@ pub mod asynch {
} }
async fn write<'a>( async fn write<'a>(
&'a mut self, &'a self,
attr: &'a AttrDetails<'_>, attr: &'a AttrDetails<'_>,
data: AttrData<'a>, data: AttrData<'a>,
) -> Result<(), Error> { ) -> Result<(), Error> {
@ -260,19 +352,79 @@ pub mod asynch {
} }
async fn invoke<'a>( async fn invoke<'a>(
&'a mut self, &'a self,
transaction: &'a mut Transaction<'_, '_>, exchange: &'a Exchange<'_>,
cmd: &'a CmdDetails<'_>, cmd: &'a CmdDetails<'_>,
data: &'a TLVElement<'_>, data: &'a TLVElement<'_>,
encoder: CmdDataEncoder<'a, '_, '_>, encoder: CmdDataEncoder<'a, '_, '_>,
) -> Result<(), Error> { ) -> Result<(), Error> {
(**self).invoke(transaction, cmd, data, encoder).await (**self).invoke(exchange, cmd, data, encoder).await
} }
} }
pub struct Asyncify<T>(pub T); impl<T> AsyncHandler for &T
where
T: AsyncHandler,
{
async fn read<'a>(
&'a self,
attr: &'a AttrDetails<'_>,
encoder: AttrDataEncoder<'a, '_, '_>,
) -> Result<(), Error> {
(**self).read(attr, encoder).await
}
impl<T> AsyncHandler for Asyncify<T> async fn write<'a>(
&'a self,
attr: &'a AttrDetails<'_>,
data: AttrData<'a>,
) -> Result<(), Error> {
(**self).write(attr, data).await
}
async fn invoke<'a>(
&'a self,
exchange: &'a Exchange<'_>,
cmd: &'a CmdDetails<'_>,
data: &'a TLVElement<'_>,
encoder: CmdDataEncoder<'a, '_, '_>,
) -> Result<(), Error> {
(**self).invoke(exchange, cmd, data, encoder).await
}
}
impl<M, H> AsyncHandler for (M, H)
where
H: AsyncHandler,
{
async fn read<'a>(
&'a self,
attr: &'a AttrDetails<'_>,
encoder: AttrDataEncoder<'a, '_, '_>,
) -> Result<(), Error> {
self.1.read(attr, encoder).await
}
async fn write<'a>(
&'a self,
attr: &'a AttrDetails<'_>,
data: AttrData<'a>,
) -> Result<(), Error> {
self.1.write(attr, data).await
}
async fn invoke<'a>(
&'a self,
exchange: &'a Exchange<'_>,
cmd: &'a CmdDetails<'_>,
data: &'a TLVElement<'_>,
encoder: CmdDataEncoder<'a, '_, '_>,
) -> Result<(), Error> {
self.1.invoke(exchange, cmd, data, encoder).await
}
}
impl<T> AsyncHandler for HandlerCompat<T>
where where
T: NonBlockingHandler, T: NonBlockingHandler,
{ {
@ -285,21 +437,21 @@ pub mod asynch {
} }
async fn write<'a>( async fn write<'a>(
&'a mut self, &'a self,
attr: &'a AttrDetails<'_>, attr: &'a AttrDetails<'_>,
data: AttrData<'a>, data: AttrData<'a>,
) -> Result<(), Error> { ) -> Result<(), Error> {
Handler::write(&mut self.0, attr, data) Handler::write(&self.0, attr, data)
} }
async fn invoke<'a>( async fn invoke<'a>(
&'a mut self, &'a self,
transaction: &'a mut Transaction<'_, '_>, exchange: &'a Exchange<'_>,
cmd: &'a CmdDetails<'_>, cmd: &'a CmdDetails<'_>,
data: &'a TLVElement<'_>, data: &'a TLVElement<'_>,
encoder: CmdDataEncoder<'a, '_, '_>, encoder: CmdDataEncoder<'a, '_, '_>,
) -> Result<(), Error> { ) -> Result<(), Error> {
Handler::invoke(&mut self.0, transaction, cmd, data, encoder) Handler::invoke(&self.0, exchange, cmd, data, encoder)
} }
} }
@ -332,7 +484,7 @@ pub mod asynch {
} }
async fn write<'a>( async fn write<'a>(
&'a mut self, &'a self,
attr: &'a AttrDetails<'_>, attr: &'a AttrDetails<'_>,
data: AttrData<'a>, data: AttrData<'a>,
) -> Result<(), Error> { ) -> Result<(), Error> {
@ -345,16 +497,16 @@ pub mod asynch {
} }
async fn invoke<'a>( async fn invoke<'a>(
&'a mut self, &'a self,
transaction: &'a mut Transaction<'_, '_>, exchange: &'a Exchange<'_>,
cmd: &'a CmdDetails<'_>, cmd: &'a CmdDetails<'_>,
data: &'a TLVElement<'_>, data: &'a TLVElement<'_>,
encoder: CmdDataEncoder<'a, '_, '_>, encoder: CmdDataEncoder<'a, '_, '_>,
) -> Result<(), Error> { ) -> Result<(), Error> {
if self.handler_endpoint == cmd.endpoint_id && self.handler_cluster == cmd.cluster_id { if self.handler_endpoint == cmd.endpoint_id && self.handler_cluster == cmd.cluster_id {
self.handler.invoke(transaction, cmd, data, encoder).await self.handler.invoke(exchange, cmd, data, encoder).await
} else { } else {
self.next.invoke(transaction, cmd, data, encoder).await self.next.invoke(exchange, cmd, data, encoder).await
} }
} }
} }

View file

@ -0,0 +1,178 @@
use crate::data_model::objects::Node;
#[cfg(feature = "nightly")]
pub use asynch::*;
use super::HandlerCompat;
pub trait MetadataGuard {
fn node(&self) -> Node<'_>;
}
impl<T> MetadataGuard for &T
where
T: MetadataGuard,
{
fn node(&self) -> Node<'_> {
(**self).node()
}
}
impl<T> MetadataGuard for &mut T
where
T: MetadataGuard,
{
fn node(&self) -> Node<'_> {
(**self).node()
}
}
pub trait Metadata {
type MetadataGuard<'a>: MetadataGuard
where
Self: 'a;
fn lock(&self) -> Self::MetadataGuard<'_>;
}
impl<T> Metadata for &T
where
T: Metadata,
{
type MetadataGuard<'a> = T::MetadataGuard<'a> where Self: 'a;
fn lock(&self) -> Self::MetadataGuard<'_> {
(**self).lock()
}
}
impl<T> Metadata for &mut T
where
T: Metadata,
{
type MetadataGuard<'a> = T::MetadataGuard<'a> where Self: 'a;
fn lock(&self) -> Self::MetadataGuard<'_> {
(**self).lock()
}
}
impl<'a> MetadataGuard for Node<'a> {
fn node(&self) -> Node<'_> {
Node {
id: self.id,
endpoints: self.endpoints,
}
}
}
impl<'a> Metadata for Node<'a> {
type MetadataGuard<'g> = Node<'g> where Self: 'g;
fn lock(&self) -> Self::MetadataGuard<'_> {
Node {
id: self.id,
endpoints: self.endpoints,
}
}
}
impl<M, H> Metadata for (M, H)
where
M: Metadata,
{
type MetadataGuard<'a> = M::MetadataGuard<'a>
where
Self: 'a;
fn lock(&self) -> Self::MetadataGuard<'_> {
self.0.lock()
}
}
impl<T> Metadata for HandlerCompat<T>
where
T: Metadata,
{
type MetadataGuard<'a> = T::MetadataGuard<'a>
where
Self: 'a;
fn lock(&self) -> Self::MetadataGuard<'_> {
self.0.lock()
}
}
#[cfg(feature = "nightly")]
pub mod asynch {
use crate::data_model::objects::{HandlerCompat, Node};
use super::{Metadata, MetadataGuard};
pub trait AsyncMetadata {
type MetadataGuard<'a>: MetadataGuard
where
Self: 'a;
async fn lock(&self) -> Self::MetadataGuard<'_>;
}
impl<T> AsyncMetadata for &T
where
T: AsyncMetadata,
{
type MetadataGuard<'a> = T::MetadataGuard<'a> where Self: 'a;
async fn lock(&self) -> Self::MetadataGuard<'_> {
(**self).lock().await
}
}
impl<T> AsyncMetadata for &mut T
where
T: AsyncMetadata,
{
type MetadataGuard<'a> = T::MetadataGuard<'a> where Self: 'a;
async fn lock(&self) -> Self::MetadataGuard<'_> {
(**self).lock().await
}
}
impl<'a> AsyncMetadata for Node<'a> {
type MetadataGuard<'g> = Node<'g> where Self: 'g;
async fn lock(&self) -> Self::MetadataGuard<'_> {
Node {
id: self.id,
endpoints: self.endpoints,
}
}
}
impl<M, H> AsyncMetadata for (M, H)
where
M: AsyncMetadata,
{
type MetadataGuard<'a> = M::MetadataGuard<'a>
where
Self: 'a;
async fn lock(&self) -> Self::MetadataGuard<'_> {
self.0.lock().await
}
}
impl<T> AsyncMetadata for HandlerCompat<T>
where
T: Metadata,
{
type MetadataGuard<'a> = T::MetadataGuard<'a>
where
Self: 'a;
async fn lock(&self) -> Self::MetadataGuard<'_> {
self.0.lock()
}
}
}

View file

@ -41,6 +41,9 @@ pub use handler::*;
mod dataver; mod dataver;
pub use dataver::*; pub use dataver::*;
mod metadata;
pub use metadata::*;
pub type EndptId = u16; pub type EndptId = u16;
pub type ClusterId = u32; pub type ClusterId = u32;
pub type AttrId = u16; pub type AttrId = u16;

View file

@ -17,9 +17,10 @@
use crate::{ use crate::{
acl::Accessor, acl::Accessor,
alloc,
data_model::objects::Endpoint, data_model::objects::Endpoint,
interaction_model::{ interaction_model::{
core::{IMStatusCode, ResumeReadReq, ResumeSubscribeReq}, core::IMStatusCode,
messages::{ messages::{
ib::{AttrPath, AttrStatus, CmdStatus, DataVersionFilter}, ib::{AttrPath, AttrStatus, CmdStatus, DataVersionFilter},
msg::{InvReq, ReadReq, SubscribeReq, WriteReq}, msg::{InvReq, ReadReq, SubscribeReq, WriteReq},
@ -27,7 +28,7 @@ use crate::{
}, },
}, },
// TODO: This layer shouldn't really depend on the TLV layer, should create an abstraction layer // TODO: This layer shouldn't really depend on the TLV layer, should create an abstraction layer
tlv::{TLVArray, TLVArrayIter, TLVElement}, tlv::{TLVArray, TLVElement},
}; };
use core::{ use core::{
fmt, fmt,
@ -57,41 +58,6 @@ where
} }
} }
pub trait Iterable {
type Item;
type Iterator<'a>: Iterator<Item = Self::Item>
where
Self: 'a;
fn iter(&self) -> Self::Iterator<'_>;
}
impl<'a> Iterable for Option<&'a TLVArray<'a, DataVersionFilter>> {
type Item = DataVersionFilter;
type Iterator<'i> = WildcardIter<TLVArrayIter<'i, DataVersionFilter>, DataVersionFilter> where Self: 'i;
fn iter(&self) -> Self::Iterator<'_> {
if let Some(filters) = self {
WildcardIter::Wildcard(filters.iter())
} else {
WildcardIter::None
}
}
}
impl<'a> Iterable for &'a [DataVersionFilter] {
type Item = DataVersionFilter;
type Iterator<'i> = core::iter::Cloned<core::slice::Iter<'i, DataVersionFilter>> where Self: 'i;
fn iter(&self) -> Self::Iterator<'_> {
let slice: &[DataVersionFilter] = self;
slice.iter().cloned()
}
}
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Node<'a> { pub struct Node<'a> {
pub id: u16, pub id: u16,
@ -102,6 +68,7 @@ impl<'a> Node<'a> {
pub fn read<'s, 'm>( pub fn read<'s, 'm>(
&'s self, &'s self,
req: &'m ReadReq, req: &'m ReadReq,
from: Option<GenericPath>,
accessor: &'m Accessor<'m>, accessor: &'m Accessor<'m>,
) -> impl Iterator<Item = Result<AttrDetails, AttrStatus>> + 'm ) -> impl Iterator<Item = Result<AttrDetails, AttrStatus>> + 'm
where where
@ -114,30 +81,14 @@ impl<'a> Node<'a> {
req.dataver_filters.as_ref(), req.dataver_filters.as_ref(),
req.fabric_filtered, req.fabric_filtered,
accessor, accessor,
None, from,
)
}
pub fn resume_read<'s, 'm>(
&'s self,
req: &'m ResumeReadReq,
accessor: &'m Accessor<'m>,
) -> impl Iterator<Item = Result<AttrDetails, AttrStatus>> + 'm
where
's: 'm,
{
self.read_attr_requests(
req.paths.iter().cloned(),
req.filters.as_slice(),
req.fabric_filtered,
accessor,
Some(req.resume_path.clone()),
) )
} }
pub fn subscribing_read<'s, 'm>( pub fn subscribing_read<'s, 'm>(
&'s self, &'s self,
req: &'m SubscribeReq, req: &'m SubscribeReq,
from: Option<GenericPath>,
accessor: &'m Accessor<'m>, accessor: &'m Accessor<'m>,
) -> impl Iterator<Item = Result<AttrDetails, AttrStatus>> + 'm ) -> impl Iterator<Item = Result<AttrDetails, AttrStatus>> + 'm
where where
@ -150,31 +101,14 @@ impl<'a> Node<'a> {
req.dataver_filters.as_ref(), req.dataver_filters.as_ref(),
req.fabric_filtered, req.fabric_filtered,
accessor, accessor,
None, from,
) )
} }
pub fn resume_subscribing_read<'s, 'm>( fn read_attr_requests<'s, 'm, P>(
&'s self,
req: &'m ResumeSubscribeReq,
accessor: &'m Accessor<'m>,
) -> impl Iterator<Item = Result<AttrDetails, AttrStatus>> + 'm
where
's: 'm,
{
self.read_attr_requests(
req.paths.iter().cloned(),
req.filters.as_slice(),
req.fabric_filtered,
accessor,
Some(req.resume_path.clone().unwrap()),
)
}
fn read_attr_requests<'s, 'm, P, D>(
&'s self, &'s self,
attr_requests: P, attr_requests: P,
dataver_filters: D, dataver_filters: Option<&'m TLVArray<DataVersionFilter>>,
fabric_filtered: bool, fabric_filtered: bool,
accessor: &'m Accessor<'m>, accessor: &'m Accessor<'m>,
from: Option<GenericPath>, from: Option<GenericPath>,
@ -182,11 +116,9 @@ impl<'a> Node<'a> {
where where
's: 'm, 's: 'm,
P: Iterator<Item = AttrPath> + 'm, P: Iterator<Item = AttrPath> + 'm,
D: Iterable<Item = DataVersionFilter> + Clone + 'm,
{ {
attr_requests.flat_map(move |path| { alloc!(attr_requests.flat_map(move |path| {
if path.to_gp().is_wildcard() { if path.to_gp().is_wildcard() {
let dataver_filters = dataver_filters.clone();
let from = from.clone(); let from = from.clone();
let iter = self let iter = self
@ -204,10 +136,14 @@ impl<'a> Node<'a> {
.is_ok() .is_ok()
}) })
.map(move |(ep, cl, attr)| { .map(move |(ep, cl, attr)| {
let dataver = dataver_filters.iter().find_map(|filter| { let dataver = if let Some(dataver_filters) = dataver_filters {
dataver_filters.iter().find_map(|filter| {
(filter.path.endpoint == ep.id && filter.path.cluster == cl.id) (filter.path.endpoint == ep.id && filter.path.cluster == cl.id)
.then_some(filter.data_ver) .then_some(filter.data_ver)
}); })
} else {
None
};
Ok(AttrDetails { Ok(AttrDetails {
node: self, node: self,
@ -230,10 +166,14 @@ impl<'a> Node<'a> {
let result = match self.check_attribute(accessor, ep, cl, attr, false) { let result = match self.check_attribute(accessor, ep, cl, attr, false) {
Ok(()) => { Ok(()) => {
let dataver = dataver_filters.iter().find_map(|filter| { let dataver = if let Some(dataver_filters) = dataver_filters {
dataver_filters.iter().find_map(|filter| {
(filter.path.endpoint == ep && filter.path.cluster == cl) (filter.path.endpoint == ep && filter.path.cluster == cl)
.then_some(filter.data_ver) .then_some(filter.data_ver)
}); })
} else {
None
};
Ok(AttrDetails { Ok(AttrDetails {
node: self, node: self,
@ -252,7 +192,7 @@ impl<'a> Node<'a> {
WildcardIter::Single(once(result)) WildcardIter::Single(once(result))
} }
}) }))
} }
pub fn write<'m>( pub fn write<'m>(
@ -260,7 +200,7 @@ impl<'a> Node<'a> {
req: &'m WriteReq, req: &'m WriteReq,
accessor: &'m Accessor<'m>, accessor: &'m Accessor<'m>,
) -> impl Iterator<Item = Result<(AttrDetails, TLVElement<'m>), AttrStatus>> + 'm { ) -> impl Iterator<Item = Result<(AttrDetails, TLVElement<'m>), AttrStatus>> + 'm {
req.write_requests.iter().flat_map(move |attr_data| { alloc!(req.write_requests.iter().flat_map(move |attr_data| {
if attr_data.path.cluster.is_none() { if attr_data.path.cluster.is_none() {
WildcardIter::Single(once(Err(AttrStatus::new( WildcardIter::Single(once(Err(AttrStatus::new(
&attr_data.path.to_gp(), &attr_data.path.to_gp(),
@ -332,7 +272,7 @@ impl<'a> Node<'a> {
WildcardIter::Single(once(result)) WildcardIter::Single(once(result))
} }
}) }))
} }
pub fn invoke<'m>( pub fn invoke<'m>(
@ -340,7 +280,8 @@ impl<'a> Node<'a> {
req: &'m InvReq, req: &'m InvReq,
accessor: &'m Accessor<'m>, accessor: &'m Accessor<'m>,
) -> impl Iterator<Item = Result<(CmdDetails, TLVElement<'m>), CmdStatus>> + 'm { ) -> impl Iterator<Item = Result<(CmdDetails, TLVElement<'m>), CmdStatus>> + 'm {
req.inv_requests alloc!(req
.inv_requests
.iter() .iter()
.flat_map(|inv_requests| inv_requests.iter()) .flat_map(|inv_requests| inv_requests.iter())
.flat_map(move |cmd_data| { .flat_map(move |cmd_data| {
@ -393,7 +334,7 @@ impl<'a> Node<'a> {
WildcardIter::Single(once(result)) WildcardIter::Single(once(result))
} }
}) }))
} }
fn matches(path: Option<&GenericPath>, ep: EndptId, cl: ClusterId, leaf: u32) -> bool { fn matches(path: Option<&GenericPath>, ep: EndptId, cl: ClusterId, leaf: u32) -> bool {

View file

@ -46,7 +46,7 @@ pub const CLUSTERS: [Cluster<'static>; 7] = [
access_control::CLUSTER, access_control::CLUSTER,
]; ];
pub fn endpoint(id: EndptId) -> Endpoint<'static> { pub const fn endpoint(id: EndptId) -> Endpoint<'static> {
Endpoint { Endpoint {
id, id,
device_type: super::device_types::DEV_TYPE_ROOT_NODE, device_type: super::device_types::DEV_TYPE_ROOT_NODE,

View file

@ -19,11 +19,11 @@ use core::cell::RefCell;
use core::convert::TryInto; use core::convert::TryInto;
use crate::data_model::objects::*; use crate::data_model::objects::*;
use crate::interaction_model::core::Transaction;
use crate::mdns::Mdns; use crate::mdns::Mdns;
use crate::secure_channel::pake::PaseMgr; use crate::secure_channel::pake::PaseMgr;
use crate::secure_channel::spake2p::VerifierData; use crate::secure_channel::spake2p::VerifierData;
use crate::tlv::{FromTLV, Nullable, OctetStr, TLVElement}; use crate::tlv::{FromTLV, Nullable, OctetStr, TLVElement};
use crate::transport::exchange::Exchange;
use crate::utils::rand::Rand; use crate::utils::rand::Rand;
use crate::{attribute_enum, cmd_enter}; use crate::{attribute_enum, cmd_enter};
use crate::{command_enum, error::*}; use crate::{command_enum, error::*};
@ -84,8 +84,8 @@ pub const CLUSTER: Cluster<'static> = Cluster {
], ],
commands: &[ commands: &[
Commands::OpenCommWindow as _, Commands::OpenCommWindow as _,
Commands::OpenBasicCommWindow as _, // Commands::OpenBasicCommWindow as _,
Commands::RevokeComm as _, // Commands::RevokeComm as _,
], ],
}; };
@ -133,7 +133,7 @@ impl<'a> AdminCommCluster<'a> {
} }
pub fn invoke( pub fn invoke(
&mut self, &self,
cmd: &CmdDetails, cmd: &CmdDetails,
data: &TLVElement, data: &TLVElement,
_encoder: CmdDataEncoder, _encoder: CmdDataEncoder,
@ -148,7 +148,7 @@ impl<'a> AdminCommCluster<'a> {
Ok(()) Ok(())
} }
fn handle_command_opencomm_win(&mut self, data: &TLVElement) -> Result<(), Error> { fn handle_command_opencomm_win(&self, data: &TLVElement) -> Result<(), Error> {
cmd_enter!("Open Commissioning Window"); cmd_enter!("Open Commissioning Window");
let req = OpenCommWindowReq::from_tlv(data)?; let req = OpenCommWindowReq::from_tlv(data)?;
let verifier = VerifierData::new(req.verifier.0, req.iterations, req.salt.0); let verifier = VerifierData::new(req.verifier.0, req.iterations, req.salt.0);
@ -166,8 +166,8 @@ impl<'a> Handler for AdminCommCluster<'a> {
} }
fn invoke( fn invoke(
&mut self, &self,
_transaction: &mut Transaction, _exchange: &Exchange,
cmd: &CmdDetails, cmd: &CmdDetails,
data: &TLVElement, data: &TLVElement,
encoder: CmdDataEncoder, encoder: CmdDataEncoder,

View file

@ -20,8 +20,8 @@ use core::convert::TryInto;
use crate::data_model::objects::*; use crate::data_model::objects::*;
use crate::data_model::sdm::failsafe::FailSafe; use crate::data_model::sdm::failsafe::FailSafe;
use crate::interaction_model::core::Transaction;
use crate::tlv::{FromTLV, TLVElement, TLVWriter, TagType, ToTLV, UtfStr}; use crate::tlv::{FromTLV, TLVElement, TLVWriter, TagType, ToTLV, UtfStr};
use crate::transport::exchange::Exchange;
use crate::utils::rand::Rand; use crate::utils::rand::Rand;
use crate::{attribute_enum, cmd_enter}; use crate::{attribute_enum, cmd_enter};
use crate::{command_enum, error::*}; use crate::{command_enum, error::*};
@ -171,19 +171,19 @@ impl<'a> GenCommCluster<'a> {
} }
pub fn invoke( pub fn invoke(
&mut self, &self,
transaction: &mut Transaction, exchange: &Exchange,
cmd: &CmdDetails, cmd: &CmdDetails,
data: &TLVElement, data: &TLVElement,
encoder: CmdDataEncoder, encoder: CmdDataEncoder,
) -> Result<(), Error> { ) -> Result<(), Error> {
match cmd.cmd_id.try_into()? { match cmd.cmd_id.try_into()? {
Commands::ArmFailsafe => self.handle_command_armfailsafe(transaction, data, encoder)?, Commands::ArmFailsafe => self.handle_command_armfailsafe(exchange, data, encoder)?,
Commands::SetRegulatoryConfig => { Commands::SetRegulatoryConfig => {
self.handle_command_setregulatoryconfig(transaction, data, encoder)? self.handle_command_setregulatoryconfig(exchange, data, encoder)?
} }
Commands::CommissioningComplete => { Commands::CommissioningComplete => {
self.handle_command_commissioningcomplete(transaction, encoder)?; self.handle_command_commissioningcomplete(exchange, encoder)?;
} }
} }
@ -193,8 +193,8 @@ impl<'a> GenCommCluster<'a> {
} }
fn handle_command_armfailsafe( fn handle_command_armfailsafe(
&mut self, &self,
transaction: &mut Transaction, exchange: &Exchange,
data: &TLVElement, data: &TLVElement,
encoder: CmdDataEncoder, encoder: CmdDataEncoder,
) -> Result<(), Error> { ) -> Result<(), Error> {
@ -207,7 +207,7 @@ impl<'a> GenCommCluster<'a> {
.borrow_mut() .borrow_mut()
.arm( .arm(
p.expiry_len, p.expiry_len,
transaction.session().get_session_mode().clone(), exchange.with_session(|sess| Ok(sess.get_session_mode().clone()))?,
) )
.is_err() .is_err()
{ {
@ -225,13 +225,12 @@ impl<'a> GenCommCluster<'a> {
.with_command(RespCommands::ArmFailsafeResp as _)? .with_command(RespCommands::ArmFailsafeResp as _)?
.set(cmd_data)?; .set(cmd_data)?;
transaction.complete();
Ok(()) Ok(())
} }
fn handle_command_setregulatoryconfig( fn handle_command_setregulatoryconfig(
&mut self, &self,
transaction: &mut Transaction, _exchange: &Exchange,
data: &TLVElement, data: &TLVElement,
encoder: CmdDataEncoder, encoder: CmdDataEncoder,
) -> Result<(), Error> { ) -> Result<(), Error> {
@ -252,20 +251,22 @@ impl<'a> GenCommCluster<'a> {
.with_command(RespCommands::SetRegulatoryConfigResp as _)? .with_command(RespCommands::SetRegulatoryConfigResp as _)?
.set(cmd_data)?; .set(cmd_data)?;
transaction.complete();
Ok(()) Ok(())
} }
fn handle_command_commissioningcomplete( fn handle_command_commissioningcomplete(
&mut self, &self,
transaction: &mut Transaction, exchange: &Exchange,
encoder: CmdDataEncoder, encoder: CmdDataEncoder,
) -> Result<(), Error> { ) -> Result<(), Error> {
cmd_enter!("Commissioning Complete"); cmd_enter!("Commissioning Complete");
let mut status: u8 = CommissioningError::Ok as u8; let mut status: u8 = CommissioningError::Ok as u8;
// Has to be a Case Session // Has to be a Case Session
if transaction.session().get_local_fabric_idx().is_none() { if exchange
.with_session(|sess| Ok(sess.get_local_fabric_idx()))?
.is_none()
{
status = CommissioningError::ErrInvalidAuth as u8; status = CommissioningError::ErrInvalidAuth as u8;
} }
@ -274,7 +275,7 @@ impl<'a> GenCommCluster<'a> {
if self if self
.failsafe .failsafe
.borrow_mut() .borrow_mut()
.disarm(transaction.session().get_session_mode().clone()) .disarm(exchange.with_session(|sess| Ok(sess.get_session_mode().clone()))?)
.is_err() .is_err()
{ {
status = CommissioningError::ErrInvalidAuth as u8; status = CommissioningError::ErrInvalidAuth as u8;
@ -289,7 +290,6 @@ impl<'a> GenCommCluster<'a> {
.with_command(RespCommands::CommissioningCompleteResp as _)? .with_command(RespCommands::CommissioningCompleteResp as _)?
.set(cmd_data)?; .set(cmd_data)?;
transaction.complete();
Ok(()) Ok(())
} }
} }
@ -300,13 +300,13 @@ impl<'a> Handler for GenCommCluster<'a> {
} }
fn invoke( fn invoke(
&mut self, &self,
transaction: &mut Transaction, exchange: &Exchange,
cmd: &CmdDetails, cmd: &CmdDetails,
data: &TLVElement, data: &TLVElement,
encoder: CmdDataEncoder, encoder: CmdDataEncoder,
) -> Result<(), Error> { ) -> Result<(), Error> {
GenCommCluster::invoke(self, transaction, cmd, data, encoder) GenCommCluster::invoke(self, exchange, cmd, data, encoder)
} }
} }

View file

@ -24,9 +24,9 @@ use crate::crypto::{self, KeyPair};
use crate::data_model::objects::*; use crate::data_model::objects::*;
use crate::data_model::sdm::dev_att; use crate::data_model::sdm::dev_att;
use crate::fabric::{Fabric, FabricMgr, MAX_SUPPORTED_FABRICS}; use crate::fabric::{Fabric, FabricMgr, MAX_SUPPORTED_FABRICS};
use crate::interaction_model::core::Transaction;
use crate::mdns::Mdns; use crate::mdns::Mdns;
use crate::tlv::{FromTLV, OctetStr, TLVElement, TLVWriter, TagType, ToTLV, UtfStr}; use crate::tlv::{FromTLV, OctetStr, TLVElement, TLVWriter, TagType, ToTLV, UtfStr};
use crate::transport::exchange::Exchange;
use crate::transport::session::SessionMode; use crate::transport::session::SessionMode;
use crate::utils::epoch::Epoch; use crate::utils::epoch::Epoch;
use crate::utils::rand::Rand; use crate::utils::rand::Rand;
@ -289,26 +289,26 @@ impl<'a> NocCluster<'a> {
} }
pub fn invoke( pub fn invoke(
&mut self, &self,
transaction: &mut Transaction, exchange: &Exchange,
cmd: &CmdDetails, cmd: &CmdDetails,
data: &TLVElement, data: &TLVElement,
encoder: CmdDataEncoder, encoder: CmdDataEncoder,
) -> Result<(), Error> { ) -> Result<(), Error> {
match cmd.cmd_id.try_into()? { match cmd.cmd_id.try_into()? {
Commands::AddNOC => self.handle_command_addnoc(transaction, data, encoder)?, Commands::AddNOC => self.handle_command_addnoc(exchange, data, encoder)?,
Commands::CSRReq => self.handle_command_csrrequest(transaction, data, encoder)?, Commands::CSRReq => self.handle_command_csrrequest(exchange, data, encoder)?,
Commands::AddTrustedRootCert => { Commands::AddTrustedRootCert => {
self.handle_command_addtrustedrootcert(transaction, data)? self.handle_command_addtrustedrootcert(exchange, data)?
} }
Commands::AttReq => self.handle_command_attrequest(transaction, data, encoder)?, Commands::AttReq => self.handle_command_attrequest(exchange, data, encoder)?,
Commands::CertChainReq => { Commands::CertChainReq => {
self.handle_command_certchainrequest(transaction, data, encoder)? self.handle_command_certchainrequest(exchange, data, encoder)?
} }
Commands::UpdateFabricLabel => { Commands::UpdateFabricLabel => {
self.handle_command_updatefablabel(transaction, data, encoder)?; self.handle_command_updatefablabel(exchange, data, encoder)?;
} }
Commands::RemoveFabric => self.handle_command_rmfabric(transaction, data, encoder)?, Commands::RemoveFabric => self.handle_command_rmfabric(exchange, data, encoder)?,
} }
self.data_ver.changed(); self.data_ver.changed();
@ -323,13 +323,12 @@ impl<'a> NocCluster<'a> {
} }
fn _handle_command_addnoc( fn _handle_command_addnoc(
&mut self, &self,
transaction: &mut Transaction, exchange: &Exchange,
data: &TLVElement, data: &TLVElement,
) -> Result<u8, NocError> { ) -> Result<u8, NocError> {
let noc_data = transaction let noc_data = exchange
.session_mut() .with_session_mut(|sess| Ok(sess.take_noc_data()))?
.take_noc_data()
.ok_or(NocStatus::MissingCsr)?; .ok_or(NocStatus::MissingCsr)?;
if !self if !self
@ -411,15 +410,16 @@ impl<'a> NocCluster<'a> {
} }
fn handle_command_updatefablabel( fn handle_command_updatefablabel(
&mut self, &self,
transaction: &mut Transaction, exchange: &Exchange,
data: &TLVElement, data: &TLVElement,
encoder: CmdDataEncoder, encoder: CmdDataEncoder,
) -> Result<(), Error> { ) -> Result<(), Error> {
cmd_enter!("Update Fabric Label"); cmd_enter!("Update Fabric Label");
let req = UpdateFabricLabelReq::from_tlv(data).map_err(Error::map_invalid_data_type)?; let req = UpdateFabricLabelReq::from_tlv(data).map_err(Error::map_invalid_data_type)?;
let (result, fab_idx) = let (result, fab_idx) = if let SessionMode::Case(c) =
if let SessionMode::Case(c) = transaction.session().get_session_mode() { exchange.with_session(|sess| Ok(sess.get_session_mode().clone()))?
{
if self if self
.fabric_mgr .fabric_mgr
.borrow_mut() .borrow_mut()
@ -440,13 +440,12 @@ impl<'a> NocCluster<'a> {
Self::create_nocresponse(encoder, result, fab_idx, "")?; Self::create_nocresponse(encoder, result, fab_idx, "")?;
transaction.complete();
Ok(()) Ok(())
} }
fn handle_command_rmfabric( fn handle_command_rmfabric(
&mut self, &self,
transaction: &mut Transaction, _exchange: &Exchange,
data: &TLVElement, data: &TLVElement,
encoder: CmdDataEncoder, encoder: CmdDataEncoder,
) -> Result<(), Error> { ) -> Result<(), Error> {
@ -459,7 +458,7 @@ impl<'a> NocCluster<'a> {
.is_ok() .is_ok()
{ {
let _ = self.acl_mgr.borrow_mut().delete_for_fabric(req.fab_idx); let _ = self.acl_mgr.borrow_mut().delete_for_fabric(req.fab_idx);
transaction.terminate(); // TODO: transaction.terminate();
Ok(()) Ok(())
} else { } else {
Self::create_nocresponse(encoder, NocStatus::InvalidFabricIndex, req.fab_idx, "") Self::create_nocresponse(encoder, NocStatus::InvalidFabricIndex, req.fab_idx, "")
@ -467,28 +466,27 @@ impl<'a> NocCluster<'a> {
} }
fn handle_command_addnoc( fn handle_command_addnoc(
&mut self, &self,
transaction: &mut Transaction, exchange: &Exchange,
data: &TLVElement, data: &TLVElement,
encoder: CmdDataEncoder, encoder: CmdDataEncoder,
) -> Result<(), Error> { ) -> Result<(), Error> {
cmd_enter!("AddNOC"); cmd_enter!("AddNOC");
let (status, fab_idx) = match self._handle_command_addnoc(transaction, data) { let (status, fab_idx) = match self._handle_command_addnoc(exchange, data) {
Ok(fab_idx) => (NocStatus::Ok, fab_idx), Ok(fab_idx) => (NocStatus::Ok, fab_idx),
Err(NocError::Status(status)) => (status, 0), Err(NocError::Status(status)) => (status, 0),
Err(NocError::Error(error)) => Err(error)?, Err(NocError::Error(error)) => Err(error)?,
}; };
Self::create_nocresponse(encoder, status, fab_idx, "")?; Self::create_nocresponse(encoder, status, fab_idx, "")?;
transaction.complete();
Ok(()) Ok(())
} }
fn handle_command_attrequest( fn handle_command_attrequest(
&mut self, &self,
transaction: &mut Transaction, exchange: &Exchange,
data: &TLVElement, data: &TLVElement,
encoder: CmdDataEncoder, encoder: CmdDataEncoder,
) -> Result<(), Error> { ) -> Result<(), Error> {
@ -498,7 +496,10 @@ impl<'a> NocCluster<'a> {
info!("Received Attestation Nonce:{:?}", req.str); info!("Received Attestation Nonce:{:?}", req.str);
let mut attest_challenge = [0u8; crypto::SYMM_KEY_LEN_BYTES]; let mut attest_challenge = [0u8; crypto::SYMM_KEY_LEN_BYTES];
attest_challenge.copy_from_slice(transaction.session().get_att_challenge()); exchange.with_session(|sess| {
attest_challenge.copy_from_slice(sess.get_att_challenge());
Ok(())
})?;
let mut writer = encoder.with_command(RespCommands::AttReqResp as _)?; let mut writer = encoder.with_command(RespCommands::AttReqResp as _)?;
@ -522,13 +523,12 @@ impl<'a> NocCluster<'a> {
writer.complete()?; writer.complete()?;
transaction.complete();
Ok(()) Ok(())
} }
fn handle_command_certchainrequest( fn handle_command_certchainrequest(
&mut self, &self,
transaction: &mut Transaction, _exchange: &Exchange,
data: &TLVElement, data: &TLVElement,
encoder: CmdDataEncoder, encoder: CmdDataEncoder,
) -> Result<(), Error> { ) -> Result<(), Error> {
@ -549,13 +549,12 @@ impl<'a> NocCluster<'a> {
.with_command(RespCommands::CertChainResp as _)? .with_command(RespCommands::CertChainResp as _)?
.set(cmd_data)?; .set(cmd_data)?;
transaction.complete();
Ok(()) Ok(())
} }
fn handle_command_csrrequest( fn handle_command_csrrequest(
&mut self, &self,
transaction: &mut Transaction, exchange: &Exchange,
data: &TLVElement, data: &TLVElement,
encoder: CmdDataEncoder, encoder: CmdDataEncoder,
) -> Result<(), Error> { ) -> Result<(), Error> {
@ -570,7 +569,10 @@ impl<'a> NocCluster<'a> {
let noc_keypair = KeyPair::new(self.rand)?; let noc_keypair = KeyPair::new(self.rand)?;
let mut attest_challenge = [0u8; crypto::SYMM_KEY_LEN_BYTES]; let mut attest_challenge = [0u8; crypto::SYMM_KEY_LEN_BYTES];
attest_challenge.copy_from_slice(transaction.session().get_att_challenge()); exchange.with_session(|sess| {
attest_challenge.copy_from_slice(sess.get_att_challenge());
Ok(())
})?;
let mut writer = encoder.with_command(RespCommands::CSRResp as _)?; let mut writer = encoder.with_command(RespCommands::CSRResp as _)?;
@ -591,15 +593,17 @@ impl<'a> NocCluster<'a> {
let noc_data = NocData::new(noc_keypair); let noc_data = NocData::new(noc_keypair);
// Store this in the session data instead of cluster data, so it gets cleared // Store this in the session data instead of cluster data, so it gets cleared
// if the session goes away for some reason // if the session goes away for some reason
transaction.session_mut().set_noc_data(noc_data); exchange.with_session_mut(|sess| {
sess.set_noc_data(noc_data);
Ok(())
})?;
transaction.complete();
Ok(()) Ok(())
} }
fn handle_command_addtrustedrootcert( fn handle_command_addtrustedrootcert(
&mut self, &self,
transaction: &mut Transaction, exchange: &Exchange,
data: &TLVElement, data: &TLVElement,
) -> Result<(), Error> { ) -> Result<(), Error> {
cmd_enter!("AddTrustedRootCert"); cmd_enter!("AddTrustedRootCert");
@ -608,25 +612,26 @@ impl<'a> NocCluster<'a> {
} }
// This may happen on CASE or PASE. For PASE, the existence of NOC Data is necessary // This may happen on CASE or PASE. For PASE, the existence of NOC Data is necessary
match transaction.session().get_session_mode() { match exchange.with_session(|sess| Ok(sess.get_session_mode().clone()))? {
SessionMode::Case(_) => error!("CASE: AddTrustedRootCert handling pending"), // For a CASE Session, we just return success for now, SessionMode::Case(_) => error!("CASE: AddTrustedRootCert handling pending"), // For a CASE Session, we just return success for now,
SessionMode::Pase => { SessionMode::Pase => {
let noc_data = transaction exchange.with_session_mut(|sess| {
.session_mut() let noc_data = sess.get_noc_data().ok_or(ErrorCode::NoSession)?;
.get_noc_data()
.ok_or(ErrorCode::NoSession)?;
let req = CommonReq::from_tlv(data).map_err(Error::map_invalid_command)?; let req = CommonReq::from_tlv(data).map_err(Error::map_invalid_command)?;
info!("Received Trusted Cert:{:x?}", req.str); info!("Received Trusted Cert:{:x?}", req.str);
noc_data.root_ca = noc_data.root_ca = heapless::Vec::from_slice(req.str.0)
heapless::Vec::from_slice(req.str.0).map_err(|_| ErrorCode::BufferTooSmall)?; .map_err(|_| ErrorCode::BufferTooSmall)?;
Ok(())
})?;
// TODO // TODO
} }
_ => (), _ => (),
} }
transaction.complete();
Ok(()) Ok(())
} }
} }
@ -637,13 +642,13 @@ impl<'a> Handler for NocCluster<'a> {
} }
fn invoke( fn invoke(
&mut self, &self,
transaction: &mut Transaction, exchange: &Exchange,
cmd: &CmdDetails, cmd: &CmdDetails,
data: &TLVElement, data: &TLVElement,
encoder: CmdDataEncoder, encoder: CmdDataEncoder,
) -> Result<(), Error> { ) -> Result<(), Error> {
NocCluster::invoke(self, transaction, cmd, data, encoder) NocCluster::invoke(self, exchange, cmd, data, encoder)
} }
} }

View file

@ -132,7 +132,7 @@ impl<'a> AccessControlCluster<'a> {
} }
} }
pub fn write(&mut self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { pub fn write(&self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> {
match attr.attr_id.try_into()? { match attr.attr_id.try_into()? {
Attributes::Acl(_) => { Attributes::Acl(_) => {
attr_list_write(attr, data.with_dataver(self.data_ver.get())?, |op, data| { attr_list_write(attr, data.with_dataver(self.data_ver.get())?, |op, data| {
@ -151,7 +151,7 @@ impl<'a> AccessControlCluster<'a> {
/// This takes care of 4 things, add item, edit item, delete item, delete list. /// This takes care of 4 things, add item, edit item, delete item, delete list.
/// Care about fabric-scoped behaviour is taken /// Care about fabric-scoped behaviour is taken
fn write_acl_attr( fn write_acl_attr(
&mut self, &self,
op: &ListOperation, op: &ListOperation,
data: &TLVElement, data: &TLVElement,
fab_idx: u8, fab_idx: u8,
@ -185,7 +185,7 @@ impl<'a> Handler for AccessControlCluster<'a> {
AccessControlCluster::read(self, attr, encoder) AccessControlCluster::read(self, attr, encoder)
} }
fn write(&mut self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { fn write(&self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> {
AccessControlCluster::write(self, attr, data) AccessControlCluster::write(self, attr, data)
} }
} }
@ -220,7 +220,7 @@ mod tests {
let mut tw = TLVWriter::new(&mut writebuf); let mut tw = TLVWriter::new(&mut writebuf);
let acl_mgr = RefCell::new(AclMgr::new()); let acl_mgr = RefCell::new(AclMgr::new());
let mut acl = AccessControlCluster::new(&acl_mgr, dummy_rand); let acl = AccessControlCluster::new(&acl_mgr, dummy_rand);
let new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); let new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case);
new.to_tlv(&mut tw, TagType::Anonymous).unwrap(); new.to_tlv(&mut tw, TagType::Anonymous).unwrap();
@ -258,7 +258,7 @@ mod tests {
for i in &verifier { for i in &verifier {
acl_mgr.borrow_mut().add(i.clone()).unwrap(); acl_mgr.borrow_mut().add(i.clone()).unwrap();
} }
let mut acl = AccessControlCluster::new(&acl_mgr, dummy_rand); let acl = AccessControlCluster::new(&acl_mgr, dummy_rand);
let new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); let new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case);
new.to_tlv(&mut tw, TagType::Anonymous).unwrap(); new.to_tlv(&mut tw, TagType::Anonymous).unwrap();
@ -295,7 +295,7 @@ mod tests {
for i in &input { for i in &input {
acl_mgr.borrow_mut().add(i.clone()).unwrap(); acl_mgr.borrow_mut().add(i.clone()).unwrap();
} }
let mut acl = AccessControlCluster::new(&acl_mgr, dummy_rand); let acl = AccessControlCluster::new(&acl_mgr, dummy_rand);
// data is don't-care actually // data is don't-care actually
let data = TLVElement::new(TagType::Anonymous, ElementType::True); let data = TLVElement::new(TagType::Anonymous, ElementType::True);

File diff suppressed because it is too large Load diff

View file

@ -69,6 +69,7 @@
//! Start off exploring by going to the [Matter] object. //! Start off exploring by going to the [Matter] object.
#![cfg_attr(not(feature = "std"), no_std)] #![cfg_attr(not(feature = "std"), no_std)]
#![cfg_attr(feature = "nightly", feature(async_fn_in_trait))] #![cfg_attr(feature = "nightly", feature(async_fn_in_trait))]
#![cfg_attr(feature = "nightly", feature(impl_trait_projections))]
#![cfg_attr(feature = "nightly", allow(incomplete_features))] #![cfg_attr(feature = "nightly", allow(incomplete_features))]
pub mod acl; pub mod acl;
@ -90,3 +91,22 @@ pub mod transport;
pub mod utils; pub mod utils;
pub use crate::core::*; pub use crate::core::*;
#[cfg(feature = "alloc")]
extern crate alloc;
#[cfg(feature = "alloc")]
#[macro_export]
macro_rules! alloc {
($val:expr) => {
alloc::boxed::Box::new($val)
};
}
#[cfg(not(feature = "alloc"))]
#[macro_export]
macro_rules! alloc {
($val:expr) => {
$val
};
}

View file

@ -20,30 +20,25 @@ use core::cell::RefCell;
use log::{error, trace}; use log::{error, trace};
use crate::{ use crate::{
alloc,
cert::Cert, cert::Cert,
crypto::{self, KeyPair, Sha256}, crypto::{self, KeyPair, Sha256},
error::{Error, ErrorCode}, error::{Error, ErrorCode},
fabric::{Fabric, FabricMgr}, fabric::{Fabric, FabricMgr},
secure_channel::common::SCStatusCodes, secure_channel::common::{self, OpCode, PROTO_ID_SECURE_CHANNEL},
secure_channel::common::{self, OpCode}, secure_channel::common::{complete_with_status, SCStatusCodes},
tlv::{get_root_node_struct, FromTLV, OctetStr, TLVWriter, TagType}, tlv::{get_root_node_struct, FromTLV, OctetStr, TLVWriter, TagType},
transport::{ transport::{
exchange::Exchange,
network::Address, network::Address,
proto_ctx::ProtoCtx, packet::Packet,
session::{CaseDetails, CloneData, NocCatIds, SessionMode}, session::{CaseDetails, CloneData, NocCatIds, SessionMode},
}, },
utils::{rand::Rand, writebuf::WriteBuf}, utils::{rand::Rand, writebuf::WriteBuf},
}; };
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
enum State {
Sigma1Rx,
Sigma3Rx,
}
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct CaseSession { struct CaseSession {
state: State,
peer_sessid: u16, peer_sessid: u16,
local_sessid: u16, local_sessid: u16,
tt_hash: Sha256, tt_hash: Sha256,
@ -54,11 +49,11 @@ pub struct CaseSession {
} }
impl CaseSession { impl CaseSession {
pub fn new(peer_sessid: u16, local_sessid: u16) -> Result<Self, Error> { #[inline(always)]
pub fn new() -> Result<Self, Error> {
Ok(Self { Ok(Self {
state: State::Sigma1Rx, peer_sessid: 0,
peer_sessid, local_sessid: 0,
local_sessid,
tt_hash: Sha256::new()?, tt_hash: Sha256::new()?,
shared_secret: [0; crypto::ECDH_SHARED_SECRET_LEN_BYTES], shared_secret: [0; crypto::ECDH_SHARED_SECRET_LEN_BYTES],
our_pub_key: [0; crypto::EC_POINT_LEN_BYTES], our_pub_key: [0; crypto::EC_POINT_LEN_BYTES],
@ -79,39 +74,50 @@ impl<'a> Case<'a> {
Self { fabric_mgr, rand } Self { fabric_mgr, rand }
} }
pub fn casesigma3_handler( pub async fn handle(
&mut self, &mut self,
ctx: &mut ProtoCtx, exchange: &mut Exchange<'_>,
) -> Result<(bool, Option<CloneData>), Error> { rx: &mut Packet<'_>,
let mut case_session = ctx tx: &mut Packet<'_>,
.exch_ctx ) -> Result<(), Error> {
.exch let mut session = alloc!(CaseSession::new()?);
.take_case_session()
.ok_or(ErrorCode::InvalidState)?; self.handle_casesigma1(exchange, rx, tx, &mut session)
if case_session.state != State::Sigma1Rx { .await?;
Err(ErrorCode::Invalid)?; self.handle_casesigma3(exchange, rx, tx, &mut session).await
} }
case_session.state = State::Sigma3Rx;
#[allow(clippy::await_holding_refcell_ref)]
async fn handle_casesigma3(
&mut self,
exchange: &mut Exchange<'_>,
rx: &mut Packet<'_>,
tx: &mut Packet<'_>,
case_session: &mut CaseSession,
) -> Result<(), Error> {
rx.check_proto_opcode(OpCode::CASESigma3 as _)?;
let fabric_mgr = self.fabric_mgr.borrow(); let fabric_mgr = self.fabric_mgr.borrow();
let fabric = fabric_mgr.get_fabric(case_session.local_fabric_idx)?; let fabric = fabric_mgr.get_fabric(case_session.local_fabric_idx)?;
if fabric.is_none() { if fabric.is_none() {
common::create_sc_status_report( drop(fabric_mgr);
ctx.tx, complete_with_status(
exchange,
tx,
common::SCStatusCodes::NoSharedTrustRoots, common::SCStatusCodes::NoSharedTrustRoots,
None, None,
)?; )
ctx.exch_ctx.exch.close(); .await?;
return Ok((true, None)); return Ok(());
} }
// Safe to unwrap here // Safe to unwrap here
let fabric = fabric.unwrap(); let fabric = fabric.unwrap();
let root = get_root_node_struct(ctx.rx.as_slice())?; let root = get_root_node_struct(rx.as_slice())?;
let encrypted = root.find_tag(1)?.slice()?; let encrypted = root.find_tag(1)?.slice()?;
let mut decrypted: [u8; 800] = [0; 800]; let mut decrypted = alloc!([0; 800]);
if encrypted.len() > decrypted.len() { if encrypted.len() > decrypted.len() {
error!("Data too large"); error!("Data too large");
Err(ErrorCode::NoSpace)?; Err(ErrorCode::NoSpace)?;
@ -119,22 +125,29 @@ impl<'a> Case<'a> {
let decrypted = &mut decrypted[..encrypted.len()]; let decrypted = &mut decrypted[..encrypted.len()];
decrypted.copy_from_slice(encrypted); decrypted.copy_from_slice(encrypted);
let len = Case::get_sigma3_decryption(fabric.ipk.op_key(), &case_session, decrypted)?; let len = Case::get_sigma3_decryption(fabric.ipk.op_key(), case_session, decrypted)?;
let decrypted = &decrypted[..len]; let decrypted = &decrypted[..len];
let root = get_root_node_struct(decrypted)?; let root = get_root_node_struct(decrypted)?;
let d = Sigma3Decrypt::from_tlv(&root)?; let d = Sigma3Decrypt::from_tlv(&root)?;
let initiator_noc = Cert::new(d.initiator_noc.0)?; let initiator_noc = alloc!(Cert::new(d.initiator_noc.0)?);
let mut initiator_icac = None; let mut initiator_icac = None;
if let Some(icac) = d.initiator_icac { if let Some(icac) = d.initiator_icac {
initiator_icac = Some(Cert::new(icac.0)?); initiator_icac = Some(alloc!(Cert::new(icac.0)?));
} }
if let Err(e) = Case::validate_certs(fabric, &initiator_noc, &initiator_icac) {
#[cfg(feature = "alloc")]
let initiator_icac_mut = initiator_icac.as_deref();
#[cfg(not(feature = "alloc"))]
let initiator_icac_mut = initiator_icac.as_ref();
if let Err(e) = Case::validate_certs(fabric, &initiator_noc, initiator_icac_mut) {
error!("Certificate Chain doesn't match: {}", e); error!("Certificate Chain doesn't match: {}", e);
common::create_sc_status_report(ctx.tx, common::SCStatusCodes::InvalidParameter, None)?; complete_with_status(exchange, tx, common::SCStatusCodes::InvalidParameter, None)
ctx.exch_ctx.exch.close(); .await?;
return Ok((true, None)); return Ok(());
} }
if Case::validate_sigma3_sign( if Case::validate_sigma3_sign(
@ -142,39 +155,52 @@ impl<'a> Case<'a> {
d.initiator_icac.map(|a| a.0), d.initiator_icac.map(|a| a.0),
&initiator_noc, &initiator_noc,
d.signature.0, d.signature.0,
&case_session, case_session,
) )
.is_err() .is_err()
{ {
error!("Sigma3 Signature doesn't match"); error!("Sigma3 Signature doesn't match");
common::create_sc_status_report(ctx.tx, common::SCStatusCodes::InvalidParameter, None)?; complete_with_status(exchange, tx, common::SCStatusCodes::InvalidParameter, None)
ctx.exch_ctx.exch.close(); .await?;
return Ok((true, None)); return Ok(());
} }
// Only now do we add this message to the TT Hash // Only now do we add this message to the TT Hash
let mut peer_catids: NocCatIds = Default::default(); let mut peer_catids: NocCatIds = Default::default();
initiator_noc.get_cat_ids(&mut peer_catids); initiator_noc.get_cat_ids(&mut peer_catids);
case_session.tt_hash.update(ctx.rx.as_slice())?; case_session.tt_hash.update(rx.as_slice())?;
let clone_data = Case::get_session_clone_data( let clone_data = Case::get_session_clone_data(
fabric.ipk.op_key(), fabric.ipk.op_key(),
fabric.get_node_id(), fabric.get_node_id(),
initiator_noc.get_node_id()?, initiator_noc.get_node_id()?,
ctx.exch_ctx.sess.get_peer_addr(), exchange.with_session(|sess| Ok(sess.get_peer_addr()))?,
&case_session, case_session,
&peer_catids, &peer_catids,
)?; )?;
common::create_sc_status_report(ctx.tx, SCStatusCodes::SessionEstablishmentSuccess, None)?; // TODO: Handle NoSpace
ctx.exch_ctx.exch.clear_data(); exchange.with_session_mgr_mut(|sess_mgr| sess_mgr.clone_session(&clone_data))?;
ctx.exch_ctx.exch.close();
Ok((true, Some(clone_data))) complete_with_status(
exchange,
tx,
SCStatusCodes::SessionEstablishmentSuccess,
None,
)
.await
} }
pub fn casesigma1_handler(&mut self, ctx: &mut ProtoCtx) -> Result<bool, Error> { #[allow(clippy::await_holding_refcell_ref)]
ctx.tx.set_proto_opcode(OpCode::CASESigma2 as u8); async fn handle_casesigma1(
&mut self,
exchange: &mut Exchange<'_>,
rx: &mut Packet<'_>,
tx: &mut Packet<'_>,
case_session: &mut CaseSession,
) -> Result<(), Error> {
rx.check_proto_opcode(OpCode::CASESigma1 as _)?;
let rx_buf = ctx.rx.as_slice(); let rx_buf = rx.as_slice();
let root = get_root_node_struct(rx_buf)?; let root = get_root_node_struct(rx_buf)?;
let r = Sigma1Req::from_tlv(&root)?; let r = Sigma1Req::from_tlv(&root)?;
@ -184,17 +210,20 @@ impl<'a> Case<'a> {
.match_dest_id(r.initiator_random.0, r.dest_id.0); .match_dest_id(r.initiator_random.0, r.dest_id.0);
if local_fabric_idx.is_err() { if local_fabric_idx.is_err() {
error!("Fabric Index mismatch"); error!("Fabric Index mismatch");
common::create_sc_status_report( complete_with_status(
ctx.tx, exchange,
tx,
common::SCStatusCodes::NoSharedTrustRoots, common::SCStatusCodes::NoSharedTrustRoots,
None, None,
)?; )
ctx.exch_ctx.exch.close(); .await?;
return Ok(true);
return Ok(());
} }
let local_sessid = ctx.exch_ctx.sess.reserve_new_sess_id(); let local_sessid = exchange.with_session_mgr_mut(|mgr| Ok(mgr.get_next_sess_id()))?;
let mut case_session = CaseSession::new(r.initiator_sessid, local_sessid)?; case_session.peer_sessid = r.initiator_sessid;
case_session.local_sessid = local_sessid;
case_session.tt_hash.update(rx_buf)?; case_session.tt_hash.update(rx_buf)?;
case_session.local_fabric_idx = local_fabric_idx?; case_session.local_fabric_idx = local_fabric_idx?;
if r.peer_pub_key.0.len() != crypto::EC_POINT_LEN_BYTES { if r.peer_pub_key.0.len() != crypto::EC_POINT_LEN_BYTES {
@ -225,52 +254,71 @@ impl<'a> Case<'a> {
// Derive the Encrypted Part // Derive the Encrypted Part
const MAX_ENCRYPTED_SIZE: usize = 800; const MAX_ENCRYPTED_SIZE: usize = 800;
let mut encrypted: [u8; MAX_ENCRYPTED_SIZE] = [0; MAX_ENCRYPTED_SIZE]; let mut encrypted = alloc!([0; MAX_ENCRYPTED_SIZE]);
let encrypted_len = { let encrypted_len = {
let mut signature = [0u8; crypto::EC_SIGNATURE_LEN_BYTES]; let mut signature = alloc!([0u8; crypto::EC_SIGNATURE_LEN_BYTES]);
let fabric_mgr = self.fabric_mgr.borrow(); let fabric_mgr = self.fabric_mgr.borrow();
let fabric = fabric_mgr.get_fabric(case_session.local_fabric_idx)?; let fabric = fabric_mgr.get_fabric(case_session.local_fabric_idx)?;
if fabric.is_none() { if fabric.is_none() {
common::create_sc_status_report( drop(fabric_mgr);
ctx.tx, complete_with_status(
exchange,
tx,
common::SCStatusCodes::NoSharedTrustRoots, common::SCStatusCodes::NoSharedTrustRoots,
None, None,
)?; )
ctx.exch_ctx.exch.close(); .await?;
return Ok(true); return Ok(());
} }
#[cfg(feature = "alloc")]
let signature_mut = &mut *signature;
#[cfg(not(feature = "alloc"))]
let signature_mut = &mut signature;
let sign_len = Case::get_sigma2_sign( let sign_len = Case::get_sigma2_sign(
fabric.unwrap(), fabric.unwrap(),
&case_session.our_pub_key, &case_session.our_pub_key,
&case_session.peer_pub_key, &case_session.peer_pub_key,
&mut signature, signature_mut,
)?; )?;
let signature = &signature[..sign_len]; let signature = &signature[..sign_len];
#[cfg(feature = "alloc")]
let encrypted_mut = &mut *encrypted;
#[cfg(not(feature = "alloc"))]
let encrypted_mut = &mut encrypted;
Case::get_sigma2_encryption( Case::get_sigma2_encryption(
fabric.unwrap(), fabric.unwrap(),
self.rand, self.rand,
&our_random, &our_random,
&mut case_session, case_session,
signature, signature,
&mut encrypted, encrypted_mut,
)? )?
}; };
let encrypted = &encrypted[0..encrypted_len]; let encrypted = &encrypted[0..encrypted_len];
// Generate our Response Body // Generate our Response Body
let mut tw = TLVWriter::new(ctx.tx.get_writebuf()?); tx.reset();
tx.set_proto_id(PROTO_ID_SECURE_CHANNEL);
tx.set_proto_opcode(OpCode::CASESigma2 as u8);
let mut tw = TLVWriter::new(tx.get_writebuf()?);
tw.start_struct(TagType::Anonymous)?; tw.start_struct(TagType::Anonymous)?;
tw.str8(TagType::Context(1), &our_random)?; tw.str8(TagType::Context(1), &our_random)?;
tw.u16(TagType::Context(2), local_sessid)?; tw.u16(TagType::Context(2), local_sessid)?;
tw.str8(TagType::Context(3), &case_session.our_pub_key)?; tw.str8(TagType::Context(3), &case_session.our_pub_key)?;
tw.str16(TagType::Context(4), encrypted)?; tw.str16(TagType::Context(4), encrypted)?;
tw.end_container()?; tw.end_container()?;
case_session.tt_hash.update(ctx.tx.as_mut_slice())?;
ctx.exch_ctx.exch.set_case_session(case_session); case_session.tt_hash.update(tx.as_mut_slice())?;
Ok(true)
exchange.exchange(tx, rx).await
} }
fn get_session_clone_data( fn get_session_clone_data(
@ -334,7 +382,7 @@ impl<'a> Case<'a> {
Ok(()) Ok(())
} }
fn validate_certs(fabric: &Fabric, noc: &Cert, icac: &Option<Cert>) -> Result<(), Error> { fn validate_certs(fabric: &Fabric, noc: &Cert, icac: Option<&Cert>) -> Result<(), Error> {
let mut verifier = noc.verify_chain_start(); let mut verifier = noc.verify_chain_start();
if fabric.get_fabric_id() != noc.get_fabric_id()? { if fabric.get_fabric_id() != noc.get_fabric_id()? {

View file

@ -17,7 +17,10 @@
use num_derive::FromPrimitive; use num_derive::FromPrimitive;
use crate::{error::Error, transport::packet::Packet}; use crate::{
error::Error,
transport::{exchange::Exchange, packet::Packet},
};
use super::status_report::{create_status_report, GeneralCode}; use super::status_report::{create_status_report, GeneralCode};
@ -51,6 +54,17 @@ pub enum SCStatusCodes {
SessionNotFound = 5, SessionNotFound = 5,
} }
pub async fn complete_with_status(
exchange: &mut Exchange<'_>,
tx: &mut Packet<'_>,
status_code: SCStatusCodes,
proto_data: Option<&[u8]>,
) -> Result<(), Error> {
create_sc_status_report(tx, status_code, proto_data)?;
exchange.send_complete(tx).await
}
pub fn create_sc_status_report( pub fn create_sc_status_report(
proto_tx: &mut Packet, proto_tx: &mut Packet,
status_code: SCStatusCodes, status_code: SCStatusCodes,

View file

@ -15,18 +15,19 @@
* limitations under the License. * limitations under the License.
*/ */
use core::{borrow::Borrow, cell::RefCell}; use core::borrow::Borrow;
use core::cell::RefCell;
use log::error;
use crate::{ use crate::{
error::*, error::*,
fabric::FabricMgr, fabric::FabricMgr,
mdns::Mdns, mdns::Mdns,
secure_channel::common::*, secure_channel::{common::*, pake::Pake},
tlv, transport::{exchange::Exchange, packet::Packet},
transport::{proto_ctx::ProtoCtx, session::CloneData},
utils::{epoch::Epoch, rand::Rand}, utils::{epoch::Epoch, rand::Rand},
}; };
use log::{error, info};
use super::{case::Case, pake::PaseMgr}; use super::{case::Case, pake::PaseMgr};
@ -34,9 +35,10 @@ use super::{case::Case, pake::PaseMgr};
*/ */
pub struct SecureChannel<'a> { pub struct SecureChannel<'a> {
case: Case<'a>,
pase: &'a RefCell<PaseMgr>, pase: &'a RefCell<PaseMgr>,
fabric: &'a RefCell<FabricMgr>,
mdns: &'a dyn Mdns, mdns: &'a dyn Mdns,
rand: Rand,
} }
impl<'a> SecureChannel<'a> { impl<'a> SecureChannel<'a> {
@ -66,45 +68,34 @@ impl<'a> SecureChannel<'a> {
rand: Rand, rand: Rand,
) -> Self { ) -> Self {
Self { Self {
case: Case::new(fabric, rand), fabric,
pase, pase,
mdns, mdns,
rand,
} }
} }
pub fn handle(&mut self, ctx: &mut ProtoCtx) -> Result<(bool, Option<CloneData>), Error> { pub async fn handle(
let proto_opcode: OpCode = ctx.rx.get_proto_opcode()?; &self,
exchange: &mut Exchange<'_>,
ctx.tx.set_proto_id(PROTO_ID_SECURE_CHANNEL); rx: &mut Packet<'_>,
info!("Received Opcode: {:?}", proto_opcode); tx: &mut Packet<'_>,
info!("Received Data:"); ) -> Result<(), Error> {
tlv::print_tlv_list(ctx.rx.as_slice()); match rx.get_proto_opcode()? {
let (reply, clone_data) = match proto_opcode { OpCode::PBKDFParamRequest => {
OpCode::MRPStandAloneAck => Ok((false, None)), Pake::new(self.pase)
OpCode::PBKDFParamRequest => self .handle(exchange, rx, tx, self.mdns)
.pase .await
.borrow_mut() }
.pbkdfparamreq_handler(ctx) OpCode::CASESigma1 => {
.map(|reply| (reply, None)), Case::new(self.fabric, self.rand)
OpCode::PASEPake1 => self .handle(exchange, rx, tx)
.pase .await
.borrow_mut() }
.pasepake1_handler(ctx) proto_opcode => {
.map(|reply| (reply, None)), error!("OpCode not handled: {:?}", proto_opcode);
OpCode::PASEPake3 => self.pase.borrow_mut().pasepake3_handler(ctx, self.mdns),
OpCode::CASESigma1 => self.case.casesigma1_handler(ctx).map(|reply| (reply, None)),
OpCode::CASESigma3 => self.case.casesigma3_handler(ctx),
_ => {
error!("OpCode Not Handled: {:?}", proto_opcode);
Err(ErrorCode::InvalidOpcode.into()) Err(ErrorCode::InvalidOpcode.into())
} }
}?;
if reply {
info!("Sending response");
tlv::print_tlv_list(ctx.tx.as_mut_slice());
} }
Ok((reply, clone_data))
} }
} }

View file

@ -15,36 +15,35 @@
* limitations under the License. * limitations under the License.
*/ */
use core::{fmt::Write, time::Duration}; use core::{cell::RefCell, fmt::Write, time::Duration};
use super::{ use super::{
common::{create_sc_status_report, SCStatusCodes}, common::{SCStatusCodes, PROTO_ID_SECURE_CHANNEL},
spake2p::{Spake2P, VerifierData}, spake2p::{Spake2P, VerifierData},
}; };
use crate::{ use crate::{
crypto, alloc, crypto,
error::{Error, ErrorCode}, error::{Error, ErrorCode},
mdns::{Mdns, ServiceMode}, mdns::{Mdns, ServiceMode},
secure_channel::common::OpCode, secure_channel::common::{complete_with_status, OpCode},
tlv::{self, get_root_node_struct, FromTLV, OctetStr, TLVWriter, TagType, ToTLV}, tlv::{self, get_root_node_struct, FromTLV, OctetStr, TLVWriter, TagType, ToTLV},
transport::{ transport::{
exchange::ExchangeCtx, exchange::{Exchange, ExchangeId},
network::Address, packet::Packet,
proto_ctx::ProtoCtx,
session::{CloneData, SessionMode}, session::{CloneData, SessionMode},
}, },
utils::{epoch::Epoch, rand::Rand}, utils::{epoch::Epoch, rand::Rand},
}; };
use log::{error, info}; use log::{error, info};
#[allow(clippy::large_enum_variant)] struct PaseSession {
enum PaseMgrState { mdns_service_name: heapless::String<16>,
Enabled(Pake, heapless::String<16>), verifier: VerifierData,
Disabled,
} }
pub struct PaseMgr { pub struct PaseMgr {
state: PaseMgrState, session: Option<PaseSession>,
timeout: Option<Timeout>,
epoch: Epoch, epoch: Epoch,
rand: Rand, rand: Rand,
} }
@ -53,14 +52,15 @@ impl PaseMgr {
#[inline(always)] #[inline(always)]
pub fn new(epoch: Epoch, rand: Rand) -> Self { pub fn new(epoch: Epoch, rand: Rand) -> Self {
Self { Self {
state: PaseMgrState::Disabled, session: None,
timeout: None,
epoch, epoch,
rand, rand,
} }
} }
pub fn is_pase_session_enabled(&self) -> bool { pub fn is_pase_session_enabled(&self) -> bool {
matches!(&self.state, PaseMgrState::Enabled(_, _)) self.session.is_some()
} }
pub fn enable_pase_session( pub fn enable_pase_session(
@ -80,62 +80,24 @@ impl PaseMgr {
&mdns_service_name, &mdns_service_name,
ServiceMode::Commissionable(discriminator), ServiceMode::Commissionable(discriminator),
)?; )?;
self.state = PaseMgrState::Enabled(
Pake::new(verifier, self.epoch, self.rand), self.session = Some(PaseSession {
mdns_service_name, mdns_service_name,
); verifier,
});
Ok(()) Ok(())
} }
pub fn disable_pase_session(&mut self, mdns: &dyn Mdns) -> Result<(), Error> { pub fn disable_pase_session(&mut self, mdns: &dyn Mdns) -> Result<(), Error> {
if let PaseMgrState::Enabled(_, mdns_service_name) = &self.state { if let Some(session) = self.session.as_ref() {
mdns.remove(mdns_service_name)?; mdns.remove(&session.mdns_service_name)?;
} }
self.state = PaseMgrState::Disabled; self.session = None;
Ok(()) Ok(())
} }
/// If the PASE Session is enabled, execute the closure,
/// if not enabled, generate SC Status Report
fn if_enabled<F, T>(&mut self, ctx: &mut ProtoCtx, f: F) -> Result<Option<T>, Error>
where
F: FnOnce(&mut Pake, &mut ProtoCtx) -> Result<T, Error>,
{
if let PaseMgrState::Enabled(pake, _) = &mut self.state {
let data = f(pake, ctx)?;
Ok(Some(data))
} else {
error!("PASE Not enabled");
create_sc_status_report(ctx.tx, SCStatusCodes::InvalidParameter, None)?;
Ok(None)
}
}
pub fn pbkdfparamreq_handler(&mut self, ctx: &mut ProtoCtx) -> Result<bool, Error> {
ctx.tx.set_proto_opcode(OpCode::PBKDFParamResponse as u8);
self.if_enabled(ctx, |pake, ctx| pake.handle_pbkdfparamrequest(ctx))?;
Ok(true)
}
pub fn pasepake1_handler(&mut self, ctx: &mut ProtoCtx) -> Result<bool, Error> {
ctx.tx.set_proto_opcode(OpCode::PASEPake2 as u8);
self.if_enabled(ctx, |pake, ctx| pake.handle_pasepake1(ctx))?;
Ok(true)
}
pub fn pasepake3_handler(
&mut self,
ctx: &mut ProtoCtx,
mdns: &dyn Mdns,
) -> Result<(bool, Option<CloneData>), Error> {
let clone_data = self.if_enabled(ctx, |pake, ctx| pake.handle_pasepake3(ctx))?;
self.disable_pase_session(mdns)?;
Ok((true, clone_data.flatten()))
}
} }
// This file basically deals with the handlers for the PASE secure channel protocol // This file basically deals with the handlers for the PASE secure channel protocol
@ -147,96 +109,65 @@ const PASE_DISCARD_TIMEOUT_SECS: Duration = Duration::from_secs(60);
const SPAKE2_SESSION_KEYS_INFO: [u8; 11] = *b"SessionKeys"; const SPAKE2_SESSION_KEYS_INFO: [u8; 11] = *b"SessionKeys";
struct SessionData { struct Timeout {
start_time: Duration, start_time: Duration,
exch_id: u16, exch_id: ExchangeId,
peer_addr: Address,
spake2p: Spake2P,
} }
impl SessionData { impl Timeout {
fn is_sess_expired(&self, epoch: Epoch) -> Result<bool, Error> { fn new(exchange: &Exchange, epoch: Epoch) -> Self {
Ok(epoch() - self.start_time > PASE_DISCARD_TIMEOUT_SECS)
}
}
#[allow(clippy::large_enum_variant)]
enum PakeState {
Idle,
InProgress(SessionData),
}
impl PakeState {
const fn new() -> Self {
Self::Idle
}
fn take(&mut self) -> Result<SessionData, Error> {
let new = core::mem::replace(self, PakeState::Idle);
if let PakeState::InProgress(s) = new {
Ok(s)
} else {
Err(ErrorCode::InvalidSignature.into())
}
}
fn is_idle(&self) -> bool {
core::mem::discriminant(self) == core::mem::discriminant(&PakeState::Idle)
}
fn take_sess_data(&mut self, exch_ctx: &ExchangeCtx) -> Result<SessionData, Error> {
let sd = self.take()?;
if sd.exch_id != exch_ctx.exch.get_id() || sd.peer_addr != exch_ctx.sess.get_peer_addr() {
Err(ErrorCode::InvalidState.into())
} else {
Ok(sd)
}
}
fn make_in_progress(&mut self, epoch: Epoch, spake2p: Spake2P, exch_ctx: &ExchangeCtx) {
*self = PakeState::InProgress(SessionData {
start_time: epoch(),
spake2p,
exch_id: exch_ctx.exch.get_id(),
peer_addr: exch_ctx.sess.get_peer_addr(),
});
}
fn set_sess_data(&mut self, sd: SessionData) {
*self = PakeState::InProgress(sd);
}
}
impl Default for PakeState {
fn default() -> Self {
Self::new()
}
}
struct Pake {
verifier: VerifierData,
state: PakeState,
epoch: Epoch,
rand: Rand,
}
impl Pake {
pub fn new(verifier: VerifierData, epoch: Epoch, rand: Rand) -> Self {
// TODO: Can any PBKDF2 calculation be pre-computed here
Self { Self {
verifier, start_time: epoch(),
state: PakeState::new(), exch_id: exchange.id().clone(),
epoch,
rand,
} }
} }
fn is_sess_expired(&self, epoch: Epoch) -> bool {
epoch() - self.start_time > PASE_DISCARD_TIMEOUT_SECS
}
}
pub struct Pake<'a> {
pase: &'a RefCell<PaseMgr>,
}
impl<'a> Pake<'a> {
pub const fn new(pase: &'a RefCell<PaseMgr>) -> Self {
// TODO: Can any PBKDF2 calculation be pre-computed here
Self { pase }
}
pub async fn handle(
&mut self,
exchange: &mut Exchange<'_>,
rx: &mut Packet<'_>,
tx: &mut Packet<'_>,
mdns: &dyn Mdns,
) -> Result<(), Error> {
let mut spake2p = alloc!(Spake2P::new());
self.handle_pbkdfparamrequest(exchange, rx, tx, &mut spake2p)
.await?;
self.handle_pasepake1(exchange, rx, tx, &mut spake2p)
.await?;
self.handle_pasepake3(exchange, rx, tx, mdns, &mut spake2p)
.await
}
#[allow(non_snake_case)] #[allow(non_snake_case)]
pub fn handle_pasepake3(&mut self, ctx: &mut ProtoCtx) -> Result<Option<CloneData>, Error> { async fn handle_pasepake3(
let mut sd = self.state.take_sess_data(&ctx.exch_ctx)?; &mut self,
exchange: &mut Exchange<'_>,
rx: &mut Packet<'_>,
tx: &mut Packet<'_>,
mdns: &dyn Mdns,
spake2p: &mut Spake2P,
) -> Result<(), Error> {
rx.check_proto_opcode(OpCode::PASEPake3 as _)?;
self.update_timeout(exchange, tx, true).await?;
let cA = extract_pasepake_1_or_3_params(ctx.rx.as_slice())?; let cA = extract_pasepake_1_or_3_params(rx.as_slice())?;
let (status_code, ke) = sd.spake2p.handle_cA(cA); let (status_code, ke) = spake2p.handle_cA(cA);
let clone_data = if status_code == SCStatusCodes::SessionEstablishmentSuccess { let clone_data = if status_code == SCStatusCodes::SessionEstablishmentSuccess {
// Get the keys // Get the keys
@ -246,7 +177,7 @@ impl Pake {
.map_err(|_x| ErrorCode::NoSpace)?; .map_err(|_x| ErrorCode::NoSpace)?;
// Create a session // Create a session
let data = sd.spake2p.get_app_data(); let data = spake2p.get_app_data();
let peer_sessid: u16 = (data & 0xffff) as u16; let peer_sessid: u16 = (data & 0xffff) as u16;
let local_sessid: u16 = ((data >> 16) & 0xffff) as u16; let local_sessid: u16 = ((data >> 16) & 0xffff) as u16;
let mut clone_data = CloneData::new( let mut clone_data = CloneData::new(
@ -254,7 +185,7 @@ impl Pake {
0, 0,
peer_sessid, peer_sessid,
local_sessid, local_sessid,
ctx.exch_ctx.sess.get_peer_addr(), exchange.with_session(|sess| Ok(sess.get_peer_addr()))?,
SessionMode::Pase, SessionMode::Pase,
); );
clone_data.dec_key.copy_from_slice(&session_keys[0..16]); clone_data.dec_key.copy_from_slice(&session_keys[0..16]);
@ -269,48 +200,70 @@ impl Pake {
None None
}; };
create_sc_status_report(ctx.tx, status_code, None)?; if let Some(clone_data) = clone_data {
ctx.exch_ctx.exch.close(); // TODO: Handle NoSpace
Ok(clone_data) exchange.with_session_mgr_mut(|sess_mgr| sess_mgr.clone_session(&clone_data))?;
self.pase.borrow_mut().disable_pase_session(mdns)?;
}
complete_with_status(exchange, tx, status_code, None).await?;
Ok(())
} }
#[allow(non_snake_case)] #[allow(non_snake_case)]
pub fn handle_pasepake1(&mut self, ctx: &mut ProtoCtx) -> Result<(), Error> { #[allow(clippy::await_holding_refcell_ref)]
let mut sd = self.state.take_sess_data(&ctx.exch_ctx)?; async fn handle_pasepake1(
&mut self,
exchange: &mut Exchange<'_>,
rx: &mut Packet<'_>,
tx: &mut Packet<'_>,
spake2p: &mut Spake2P,
) -> Result<(), Error> {
rx.check_proto_opcode(OpCode::PASEPake1 as _)?;
self.update_timeout(exchange, tx, false).await?;
let pA = extract_pasepake_1_or_3_params(ctx.rx.as_slice())?; let pase = self.pase.borrow();
let session = pase.session.as_ref().ok_or(ErrorCode::NoSession)?;
let pA = extract_pasepake_1_or_3_params(rx.as_slice())?;
let mut pB: [u8; 65] = [0; 65]; let mut pB: [u8; 65] = [0; 65];
let mut cB: [u8; 32] = [0; 32]; let mut cB: [u8; 32] = [0; 32];
sd.spake2p.start_verifier(&self.verifier)?; spake2p.start_verifier(&session.verifier)?;
sd.spake2p.handle_pA(pA, &mut pB, &mut cB, self.rand)?; spake2p.handle_pA(pA, &mut pB, &mut cB, pase.rand)?;
let mut tw = TLVWriter::new(ctx.tx.get_writebuf()?); // Generate response
tx.reset();
tx.set_proto_id(PROTO_ID_SECURE_CHANNEL);
tx.set_proto_opcode(OpCode::PASEPake2 as u8);
let mut tw = TLVWriter::new(tx.get_writebuf()?);
let resp = Pake1Resp { let resp = Pake1Resp {
pb: OctetStr(&pB), pb: OctetStr(&pB),
cb: OctetStr(&cB), cb: OctetStr(&cB),
}; };
resp.to_tlv(&mut tw, TagType::Anonymous)?; resp.to_tlv(&mut tw, TagType::Anonymous)?;
self.state.set_sess_data(sd); drop(pase);
exchange.exchange(tx, rx).await
Ok(())
} }
pub fn handle_pbkdfparamrequest(&mut self, ctx: &mut ProtoCtx) -> Result<(), Error> { #[allow(clippy::await_holding_refcell_ref)]
if !self.state.is_idle() { async fn handle_pbkdfparamrequest(
let sd = self.state.take()?; &mut self,
if sd.is_sess_expired(self.epoch)? { exchange: &mut Exchange<'_>,
info!("Previous session expired, clearing it"); rx: &mut Packet<'_>,
self.state = PakeState::Idle; tx: &mut Packet<'_>,
} else { spake2p: &mut Spake2P,
info!("Previous session in-progress, denying new request"); ) -> Result<(), Error> {
// little-endian timeout (here we've hardcoded 500ms) rx.check_proto_opcode(OpCode::PBKDFParamRequest as _)?;
create_sc_status_report(ctx.tx, SCStatusCodes::Busy, Some(&[0xf4, 0x01]))?; self.update_timeout(exchange, tx, true).await?;
return Ok(());
}
}
let root = tlv::get_root_node(ctx.rx.as_slice())?; let pase = self.pase.borrow();
let session = pase.session.as_ref().ok_or(ErrorCode::NoSession)?;
let root = tlv::get_root_node(rx.as_slice())?;
let a = PBKDFParamReq::from_tlv(&root)?; let a = PBKDFParamReq::from_tlv(&root)?;
if a.passcode_id != 0 { if a.passcode_id != 0 {
error!("Can't yet handle passcode_id != 0"); error!("Can't yet handle passcode_id != 0");
@ -318,15 +271,18 @@ impl Pake {
} }
let mut our_random: [u8; 32] = [0; 32]; let mut our_random: [u8; 32] = [0; 32];
(self.rand)(&mut our_random); (self.pase.borrow().rand)(&mut our_random);
let local_sessid = ctx.exch_ctx.sess.reserve_new_sess_id(); let local_sessid = exchange.with_session_mgr_mut(|mgr| Ok(mgr.get_next_sess_id()))?;
let spake2p_data: u32 = ((local_sessid as u32) << 16) | a.initiator_ssid as u32; let spake2p_data: u32 = ((local_sessid as u32) << 16) | a.initiator_ssid as u32;
let mut spake2p = Spake2P::new();
spake2p.set_app_data(spake2p_data); spake2p.set_app_data(spake2p_data);
// Generate response // Generate response
let mut tw = TLVWriter::new(ctx.tx.get_writebuf()?); tx.reset();
tx.set_proto_id(PROTO_ID_SECURE_CHANNEL);
tx.set_proto_opcode(OpCode::PBKDFParamResponse as u8);
let mut tw = TLVWriter::new(tx.get_writebuf()?);
let mut resp = PBKDFParamResp { let mut resp = PBKDFParamResp {
init_random: a.initiator_random, init_random: a.initiator_random,
our_random: OctetStr(&our_random), our_random: OctetStr(&our_random),
@ -335,19 +291,77 @@ impl Pake {
}; };
if !a.has_params { if !a.has_params {
let params_resp = PBKDFParamRespParams { let params_resp = PBKDFParamRespParams {
count: self.verifier.count, count: session.verifier.count,
salt: OctetStr(&self.verifier.salt), salt: OctetStr(&session.verifier.salt),
}; };
resp.params = Some(params_resp); resp.params = Some(params_resp);
} }
resp.to_tlv(&mut tw, TagType::Anonymous)?; resp.to_tlv(&mut tw, TagType::Anonymous)?;
spake2p.set_context(ctx.rx.as_slice(), ctx.tx.as_mut_slice())?; spake2p.set_context(rx.as_slice(), tx.as_mut_slice())?;
self.state
.make_in_progress(self.epoch, spake2p, &ctx.exch_ctx); drop(pase);
exchange.exchange(tx, rx).await
}
#[allow(clippy::await_holding_refcell_ref)]
async fn update_timeout(
&mut self,
exchange: &mut Exchange<'_>,
tx: &mut Packet<'_>,
new: bool,
) -> Result<(), Error> {
self.check_session(exchange, tx).await?;
let mut pase = self.pase.borrow_mut();
if pase
.timeout
.as_ref()
.map(|sd| sd.is_sess_expired(pase.epoch))
.unwrap_or(false)
{
pase.timeout = None;
}
let status = if let Some(sd) = pase.timeout.as_mut() {
if &sd.exch_id != exchange.id() {
info!("Other PAKE session in progress");
Some(SCStatusCodes::Busy)
} else {
None
}
} else if new {
None
} else {
error!("PAKE session not found or expired");
Some(SCStatusCodes::SessionNotFound)
};
if let Some(status) = status {
drop(pase);
complete_with_status(exchange, tx, status, None).await
} else {
pase.timeout = Some(Timeout::new(exchange, pase.epoch));
Ok(()) Ok(())
} }
}
async fn check_session(
&mut self,
exchange: &mut Exchange<'_>,
tx: &mut Packet<'_>,
) -> Result<(), Error> {
if self.pase.borrow().session.is_none() {
error!("PASE not enabled");
complete_with_status(exchange, tx, SCStatusCodes::InvalidParameter, None).await
} else {
Ok(())
}
}
} }
#[derive(ToTLV)] #[derive(ToTLV)]

View file

@ -15,189 +15,51 @@
* limitations under the License. * limitations under the License.
*/ */
use core::{borrow::Borrow, cell::RefCell};
use crate::{error::ErrorCode, secure_channel::common::OpCode, Matter};
use embassy_futures::select::select;
use embassy_time::{Duration, Timer};
use log::info; use log::info;
use crate::{error::*, CommissioningData, Matter}; use crate::{
error::Error, secure_channel::common::PROTO_ID_SECURE_CHANNEL, transport::packet::Packet,
};
use crate::secure_channel::common::PROTO_ID_SECURE_CHANNEL; use super::{
use crate::secure_channel::core::SecureChannel; exchange::{
use crate::transport::mrp::ReliableMessage; Exchange, ExchangeCtr, ExchangeCtx, ExchangeId, ExchangeState, Notification, Role,
use crate::transport::{exchange, network::Address, packet::Packet}; MAX_EXCHANGES,
use super::proto_ctx::ProtoCtx;
use super::session::CloneData;
enum RecvState {
New,
OpenExchange,
AddSession(CloneData),
EvictSession,
EvictSession2(CloneData),
Ack,
}
pub enum RecvAction<'r, 'p> {
Send(Address, &'r [u8]),
Interact(ProtoCtx<'r, 'p>),
}
pub struct RecvCompletion<'r, 'a> {
transport: &'r mut Transport<'a>,
rx: Packet<'r>,
tx: Packet<'r>,
state: RecvState,
}
impl<'r, 'a> RecvCompletion<'r, 'a> {
pub fn next_action(&mut self) -> Result<Option<RecvAction<'_, 'r>>, Error> {
loop {
// Polonius will remove the need for unsafe one day
let this = unsafe { (self as *mut RecvCompletion).as_mut().unwrap() };
if let Some(action) = this.maybe_next_action()? {
return Ok(action);
}
}
}
fn maybe_next_action(&mut self) -> Result<Option<Option<RecvAction<'_, 'r>>>, Error> {
self.transport.exch_mgr.purge();
self.tx.reset();
let (state, next) = match core::mem::replace(&mut self.state, RecvState::New) {
RecvState::New => {
self.rx.plain_hdr_decode()?;
(RecvState::OpenExchange, None)
}
RecvState::OpenExchange => match self.transport.exch_mgr.recv(&mut self.rx) {
Ok(Some(exch_ctx)) => {
if self.rx.get_proto_id() == PROTO_ID_SECURE_CHANNEL {
let mut proto_ctx = ProtoCtx::new(exch_ctx, &self.rx, &mut self.tx);
let mut secure_channel = SecureChannel::new(self.transport.matter);
let (reply, clone_data) = secure_channel.handle(&mut proto_ctx)?;
let state = if let Some(clone_data) = clone_data {
RecvState::AddSession(clone_data)
} else {
RecvState::Ack
};
if reply {
if proto_ctx.send()? {
(
state,
Some(Some(RecvAction::Send(self.tx.peer, self.tx.as_slice()))),
)
} else {
(state, None)
}
} else {
(state, None)
}
} else {
let proto_ctx = ProtoCtx::new(exch_ctx, &self.rx, &mut self.tx);
(RecvState::Ack, Some(Some(RecvAction::Interact(proto_ctx))))
}
}
Ok(None) => (RecvState::Ack, None),
Err(e) => match e.code() {
ErrorCode::Duplicate => (RecvState::Ack, None),
ErrorCode::NoSpace => (RecvState::EvictSession, None),
_ => Err(e)?,
}, },
}, mrp::ReliableMessage,
RecvState::AddSession(clone_data) => { session::SessionMgr,
match self.transport.exch_mgr.add_session(&clone_data) { };
Ok(_) => (RecvState::Ack, None),
Err(e) => match e.code() {
ErrorCode::NoSpace => (RecvState::EvictSession2(clone_data), None),
_ => Err(e)?,
},
}
}
RecvState::EvictSession => {
if self.transport.exch_mgr.evict_session(&mut self.tx)? {
(
RecvState::OpenExchange,
Some(Some(RecvAction::Send(self.tx.peer, self.tx.as_slice()))),
)
} else {
(RecvState::EvictSession, None)
}
}
RecvState::EvictSession2(clone_data) => {
if self.transport.exch_mgr.evict_session(&mut self.tx)? {
(
RecvState::AddSession(clone_data),
Some(Some(RecvAction::Send(self.tx.peer, self.tx.as_slice()))),
)
} else {
(RecvState::EvictSession2(clone_data), None)
}
}
RecvState::Ack => {
if let Some(exch_id) = self.transport.exch_mgr.pending_ack() {
info!("Sending MRP Standalone ACK for exch {}", exch_id);
ReliableMessage::prepare_ack(exch_id, &mut self.tx); #[derive(Debug)]
enum OpCodeDescriptor {
if self.transport.exch_mgr.send(exch_id, &mut self.tx)? { SecureChannel(OpCode),
( InteractionModel(crate::interaction_model::core::OpCode),
RecvState::Ack, Unknown(u8),
Some(Some(RecvAction::Send(self.tx.peer, self.tx.as_slice()))),
)
} else {
(RecvState::Ack, None)
}
} else {
(RecvState::Ack, Some(None))
}
}
};
self.state = state;
Ok(next)
}
} }
enum NotifyState {} impl From<u8> for OpCodeDescriptor {
fn from(value: u8) -> Self {
pub enum NotifyAction<'r, 'p> { if let Some(opcode) = num::FromPrimitive::from_u8(value) {
Send(&'r [u8]), Self::SecureChannel(opcode)
Notify(ProtoCtx<'r, 'p>), } else if let Some(opcode) = num::FromPrimitive::from_u8(value) {
} Self::InteractionModel(opcode)
} else {
pub struct NotifyCompletion<'r, 'a> { Self::Unknown(value)
// TODO
_transport: &'r mut Transport<'a>,
_rx: Packet<'r>,
_tx: Packet<'r>,
_state: NotifyState,
}
impl<'r, 'a> NotifyCompletion<'r, 'a> {
pub fn next_action(&mut self) -> Result<Option<NotifyAction<'_, 'r>>, Error> {
loop {
// Polonius will remove the need for unsafe one day
let this = unsafe { (self as *mut NotifyCompletion).as_mut().unwrap() };
if let Some(action) = this.maybe_next_action()? {
return Ok(action);
} }
} }
}
fn maybe_next_action(&mut self) -> Result<Option<Option<NotifyAction<'_, 'r>>>, Error> {
Ok(Some(None)) // TODO: Future
}
} }
pub struct Transport<'a> { pub struct Transport<'a> {
matter: &'a Matter<'a>, matter: &'a Matter<'a>,
exch_mgr: exchange::ExchangeMgr, pub(crate) exchanges: RefCell<heapless::Vec<ExchangeCtx, MAX_EXCHANGES>>,
pub(crate) send_notification: Notification,
pub(crate) persist_notification: Notification,
pub session_mgr: RefCell<SessionMgr>,
} }
impl<'a> Transport<'a> { impl<'a> Transport<'a> {
@ -208,44 +70,358 @@ impl<'a> Transport<'a> {
Self { Self {
matter, matter,
exch_mgr: exchange::ExchangeMgr::new(epoch, rand), exchanges: RefCell::new(heapless::Vec::new()),
send_notification: Notification::new(),
persist_notification: Notification::new(),
session_mgr: RefCell::new(SessionMgr::new(epoch, rand)),
} }
} }
pub fn matter(&self) -> &Matter<'a> { pub fn matter(&self) -> &'a Matter<'a> {
self.matter self.matter
} }
pub fn start(&mut self, dev_comm: CommissioningData, buf: &mut [u8]) -> Result<(), Error> { pub async fn initiate(&self, _fabric_id: u64, _node_id: u64) -> Result<Exchange<'a>, Error> {
info!("Starting Matter transport"); unimplemented!()
}
if self.matter().start_comissioning(dev_comm, buf)? { pub fn process_rx<'r>(
info!("Comissioning started"); &'r self,
construction_notification: &'r Notification,
src_rx: &mut Packet<'_>,
) -> Result<Option<ExchangeCtr<'r>>, Error> {
self.purge()?;
let mut exchanges = self.exchanges.borrow_mut();
let (ctx, new) = match self.post_recv(&mut exchanges, src_rx) {
Ok((ctx, new)) => (ctx, new),
Err(e) => match e.code() {
ErrorCode::Duplicate => {
self.send_notification.signal(());
return Ok(None);
}
_ => Err(e)?,
},
};
src_rx.log("Got packet");
if src_rx.proto.is_ack() {
if new {
Err(ErrorCode::Invalid)?;
} else {
let state = &mut ctx.state;
match state {
ExchangeState::ExchangeRecv {
tx_acknowledged, ..
} => {
*tx_acknowledged = true;
}
ExchangeState::CompleteAcknowledge { notification, .. } => {
unsafe { notification.as_ref() }.unwrap().signal(());
ctx.state = ExchangeState::Closed;
}
_ => {
// TODO: Error handling
todo!()
}
}
self.notify_changed();
}
}
if new {
let constructor = ExchangeCtr {
exchange: Exchange {
id: ctx.id.clone(),
transport: self,
notification: Notification::new(),
},
construction_notification,
};
self.notify_changed();
Ok(Some(constructor))
} else if src_rx.proto.proto_id == PROTO_ID_SECURE_CHANNEL
&& src_rx.proto.proto_opcode == OpCode::MRPStandAloneAck as u8
{
// Standalone ack, do nothing
Ok(None)
} else {
let state = &mut ctx.state;
match state {
ExchangeState::ExchangeRecv {
rx, notification, ..
} => {
let rx = unsafe { rx.as_mut() }.unwrap();
rx.load(src_rx)?;
unsafe { notification.as_ref() }.unwrap().signal(());
*state = ExchangeState::Active;
}
_ => {
// TODO: Error handling
todo!()
}
}
self.notify_changed();
Ok(None)
}
}
pub async fn wait_construction(
&self,
construction_notification: &Notification,
src_rx: &Packet<'_>,
exchange_id: &ExchangeId,
) -> Result<(), Error> {
construction_notification.wait().await;
let mut exchanges = self.exchanges.borrow_mut();
let ctx = Self::get(&mut exchanges, exchange_id).unwrap();
let state = &mut ctx.state;
match state {
ExchangeState::Construction { rx, notification } => {
let rx = unsafe { rx.as_mut() }.unwrap();
rx.load(src_rx)?;
unsafe { notification.as_ref() }.unwrap().signal(());
*state = ExchangeState::Active;
}
_ => unreachable!(),
} }
Ok(()) Ok(())
} }
pub fn recv<'r>( pub async fn wait_tx(&self) -> Result<(), Error> {
&'r mut self, select(
addr: Address, self.send_notification.wait(),
rx_buf: &'r mut [u8], Timer::after(Duration::from_millis(100)),
tx_buf: &'r mut [u8], )
) -> RecvCompletion<'r, 'a> { .await;
let mut rx = Packet::new_rx(rx_buf);
let tx = Packet::new_tx(tx_buf);
rx.peer = addr; Ok(())
}
RecvCompletion { pub async fn pull_tx(&self, dest_tx: &mut Packet<'_>) -> Result<bool, Error> {
transport: self, self.purge()?;
rx,
let mut exchanges = self.exchanges.borrow_mut();
let ctx = exchanges.iter_mut().find(|ctx| {
matches!(
&ctx.state,
ExchangeState::Acknowledge { .. }
| ExchangeState::ExchangeSend { .. }
// | ExchangeState::ExchangeRecv {
// tx_acknowledged: false,
// ..
// }
| ExchangeState::Complete { .. } // | ExchangeState::CompleteAcknowledge { .. }
) || ctx.mrp.is_ack_ready(*self.matter.borrow())
});
if let Some(ctx) = ctx {
self.notify_changed();
let state = &mut ctx.state;
let send = match state {
ExchangeState::Acknowledge { notification } => {
ReliableMessage::prepare_ack(ctx.id.id, dest_tx);
unsafe { notification.as_ref() }.unwrap().signal(());
*state = ExchangeState::Active;
true
}
ExchangeState::ExchangeSend {
tx, tx,
state: RecvState::New, rx,
notification,
} => {
let tx = unsafe { tx.as_ref() }.unwrap();
dest_tx.load(tx)?;
*state = ExchangeState::ExchangeRecv {
_tx: tx,
tx_acknowledged: false,
rx: *rx,
notification: *notification,
};
true
}
// ExchangeState::ExchangeRecv { .. } => {
// // TODO: Re-send the tx package if due
// false
// }
ExchangeState::Complete { tx, notification } => {
let tx = unsafe { tx.as_ref() }.unwrap();
dest_tx.load(tx)?;
*state = ExchangeState::CompleteAcknowledge {
_tx: tx as *const _,
notification: *notification,
};
true
}
// ExchangeState::CompleteAcknowledge { .. } => {
// // TODO: Re-send the tx package if due
// false
// }
_ => {
ReliableMessage::prepare_ack(ctx.id.id, dest_tx);
true
}
};
if send {
dest_tx.log("Sending packet");
self.pre_send(ctx, dest_tx)?;
self.notify_changed();
return Ok(true);
} }
} }
pub fn notify(&mut self, _tx: &mut Packet) -> Result<bool, Error> {
Ok(false) Ok(false)
} }
fn purge(&self) -> Result<(), Error> {
loop {
let mut exchanges = self.exchanges.borrow_mut();
if let Some(index) = exchanges.iter_mut().enumerate().find_map(|(index, ctx)| {
matches!(ctx.state, ExchangeState::Closed).then_some(index)
}) {
exchanges.swap_remove(index);
} else {
break;
}
}
Ok(())
}
fn post_recv<'r>(
&self,
exchanges: &'r mut heapless::Vec<ExchangeCtx, MAX_EXCHANGES>,
rx: &mut Packet<'_>,
) -> Result<(&'r mut ExchangeCtx, bool), Error> {
rx.plain_hdr_decode()?;
// Get the session
let mut session_mgr = self.session_mgr.borrow_mut();
let sess_index = session_mgr.post_recv(rx)?;
let session = session_mgr.mut_by_index(sess_index).unwrap();
// Decrypt the message
session.recv(self.matter.epoch, rx)?;
// Get the exchange
// TODO: Handle out of space
let (exch, new) = Self::register(
exchanges,
ExchangeId::load(rx),
Role::complementary(rx.proto.is_initiator()),
// We create a new exchange, only if the peer is the initiator
rx.proto.is_initiator(),
)?;
// Message Reliability Protocol
exch.mrp.recv(rx, self.matter.epoch)?;
Ok((exch, new))
}
fn pre_send(&self, ctx: &mut ExchangeCtx, tx: &mut Packet) -> Result<(), Error> {
let mut session_mgr = self.session_mgr.borrow_mut();
let sess_index = session_mgr
.get(
ctx.id.session_id.id,
ctx.id.session_id.peer_addr,
ctx.id.session_id.peer_nodeid,
ctx.id.session_id.is_encrypted,
)
.ok_or(ErrorCode::NoSession)?;
let session = session_mgr.mut_by_index(sess_index).unwrap();
tx.proto.exch_id = ctx.id.id;
if ctx.role == Role::Initiator {
tx.proto.set_initiator();
}
session.pre_send(tx)?;
ctx.mrp.pre_send(tx)?;
session_mgr.send(sess_index, tx)
}
fn register(
exchanges: &mut heapless::Vec<ExchangeCtx, MAX_EXCHANGES>,
id: ExchangeId,
role: Role,
create_new: bool,
) -> Result<(&mut ExchangeCtx, bool), Error> {
let exchange_index = exchanges
.iter_mut()
.enumerate()
.find_map(|(index, exchange)| (exchange.id == id).then_some(index));
if let Some(exchange_index) = exchange_index {
let exchange = &mut exchanges[exchange_index];
if exchange.role == role {
Ok((exchange, false))
} else {
Err(ErrorCode::NoExchange.into())
}
} else if create_new {
info!("Creating new exchange: {:?}", id);
let exchange = ExchangeCtx {
id,
role,
mrp: ReliableMessage::new(),
state: ExchangeState::Active,
};
exchanges.push(exchange).map_err(|_| ErrorCode::NoSpace)?;
Ok((exchanges.iter_mut().next_back().unwrap(), true))
} else {
Err(ErrorCode::NoExchange.into())
}
}
pub(crate) fn get<'r>(
exchanges: &'r mut heapless::Vec<ExchangeCtx, MAX_EXCHANGES>,
id: &ExchangeId,
) -> Option<&'r mut ExchangeCtx> {
exchanges.iter_mut().find(|exchange| exchange.id == *id)
}
pub fn notify_changed(&self) {
if self.matter().is_changed() {
self.persist_notification.signal(());
}
}
pub async fn wait_changed(&self) {
self.persist_notification.wait().await
}
} }

View file

@ -1,625 +1,320 @@
/* use embassy_sync::blocking_mutex::raw::NoopRawMutex;
*
* Copyright (c) 2020-2022 Project CHIP Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
use core::fmt; use crate::{
use core::time::Duration; acl::Accessor,
use log::{error, info, trace}; error::{Error, ErrorCode},
use owo_colors::OwoColorize; Matter,
};
use crate::error::{Error, ErrorCode}; use super::{
use crate::interaction_model::core::{ResumeReadReq, ResumeSubscribeReq}; core::Transport,
use crate::secure_channel; mrp::ReliableMessage,
use crate::secure_channel::case::CaseSession; network::Address,
use crate::utils::epoch::Epoch; packet::Packet,
use crate::utils::rand::Rand; session::{Session, SessionMgr},
};
use heapless::LinearMap; pub const MAX_EXCHANGES: usize = 8;
use super::session::CloneData; pub type Notification = embassy_sync::signal::Signal<NoopRawMutex, ()>;
use super::{mrp::ReliableMessage, packet::Packet, session::SessionHandle, session::SessionMgr};
pub struct ExchangeCtx<'a> {
pub exch: &'a mut Exchange,
pub sess: SessionHandle<'a>,
pub epoch: Epoch,
}
impl<'a> ExchangeCtx<'a> {
pub fn send(&mut self, tx: &mut Packet) -> Result<bool, Error> {
self.exch.send(tx, &mut self.sess)
}
}
#[derive(Debug, PartialEq, Eq, Copy, Clone, Default)] #[derive(Debug, PartialEq, Eq, Copy, Clone, Default)]
pub enum Role { pub(crate) enum Role {
#[default] #[default]
Initiator = 0, Initiator = 0,
Responder = 1, Responder = 1,
} }
#[derive(Debug, PartialEq, Default)] impl Role {
enum State { pub fn complementary(is_initiator: bool) -> Self {
/// The exchange is open and active
#[default]
Open,
/// The exchange is closed, but keys are active since retransmissions/acks may be pending
Close,
/// The exchange is terminated, keys are destroyed, no communication can happen
Terminate,
}
// Instead of just doing an Option<>, we create some special handling
// where the commonly used higher layer data store does't have to do a Box
#[derive(Default)]
pub enum DataOption {
CaseSession(CaseSession),
Time(Duration),
SuspendedReadReq(ResumeReadReq),
SubscriptionId(u32),
SuspendedSubscibeReq(ResumeSubscribeReq),
#[default]
None,
}
#[derive(Default)]
pub struct Exchange {
id: u16,
sess_idx: usize,
role: Role,
state: State,
mrp: ReliableMessage,
// Currently I see this primarily used in PASE and CASE. If that is the limited use
// of this, we might move this into a separate data structure, so as not to burden
// all 'exchanges'.
data: DataOption,
}
impl Exchange {
pub fn new(id: u16, sess_idx: usize, role: Role) -> Exchange {
Exchange {
id,
sess_idx,
role,
state: State::Open,
mrp: ReliableMessage::new(),
..Default::default()
}
}
pub fn terminate(&mut self) {
self.data = DataOption::None;
self.state = State::Terminate;
}
pub fn close(&mut self) {
self.data = DataOption::None;
self.state = State::Close;
}
pub fn is_state_open(&self) -> bool {
self.state == State::Open
}
pub fn is_purgeable(&self) -> bool {
// No Users, No pending ACKs/Retrans
self.state == State::Terminate || (self.state == State::Close && self.mrp.is_empty())
}
pub fn get_id(&self) -> u16 {
self.id
}
pub fn get_role(&self) -> Role {
self.role
}
pub fn clear_data(&mut self) {
self.data = DataOption::None;
}
pub fn set_case_session(&mut self, session: CaseSession) {
self.data = DataOption::CaseSession(session);
}
pub fn get_case_session(&mut self) -> Option<&mut CaseSession> {
if let DataOption::CaseSession(session) = &mut self.data {
Some(session)
} else {
None
}
}
pub fn take_case_session(&mut self) -> Option<CaseSession> {
let old = core::mem::replace(&mut self.data, DataOption::None);
if let DataOption::CaseSession(session) = old {
Some(session)
} else {
self.data = old;
None
}
}
pub fn set_suspended_read_req(&mut self, req: ResumeReadReq) {
self.data = DataOption::SuspendedReadReq(req);
}
pub fn take_suspended_read_req(&mut self) -> Option<ResumeReadReq> {
let old = core::mem::replace(&mut self.data, DataOption::None);
if let DataOption::SuspendedReadReq(req) = old {
Some(req)
} else {
self.data = old;
None
}
}
pub fn set_subscription_id(&mut self, id: u32) {
self.data = DataOption::SubscriptionId(id);
}
pub fn take_subscription_id(&mut self) -> Option<u32> {
let old = core::mem::replace(&mut self.data, DataOption::None);
if let DataOption::SubscriptionId(id) = old {
Some(id)
} else {
self.data = old;
None
}
}
pub fn set_suspended_subscribe_req(&mut self, req: ResumeSubscribeReq) {
self.data = DataOption::SuspendedSubscibeReq(req);
}
pub fn take_suspended_subscribe_req(&mut self) -> Option<ResumeSubscribeReq> {
let old = core::mem::replace(&mut self.data, DataOption::None);
if let DataOption::SuspendedSubscibeReq(req) = old {
Some(req)
} else {
self.data = old;
None
}
}
pub fn set_data_time(&mut self, expiry_ts: Option<Duration>) {
if let Some(t) = expiry_ts {
self.data = DataOption::Time(t);
}
}
pub fn get_data_time(&self) -> Option<Duration> {
match self.data {
DataOption::Time(t) => Some(t),
_ => None,
}
}
pub(crate) fn send(
&mut self,
tx: &mut Packet,
session: &mut SessionHandle,
) -> Result<bool, Error> {
if self.state == State::Terminate {
info!("Skipping tx for terminated exchange {}", self.id);
return Ok(false);
}
trace!("payload: {:x?}", tx.as_slice());
info!(
"{} with proto id: {} opcode: {}, tlv:\n",
"Sending".blue(),
tx.get_proto_id(),
tx.get_proto_raw_opcode(),
);
//print_tlv_list(tx.as_slice());
tx.proto.exch_id = self.id;
if self.role == Role::Initiator {
tx.proto.set_initiator();
}
session.pre_send(tx)?;
self.mrp.pre_send(tx)?;
session.send(tx)?;
Ok(true)
}
}
impl fmt::Display for Exchange {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"exch_id: {:?}, sess_index: {}, role: {:?}, mrp: {:?}, state: {:?}",
self.id, self.sess_idx, self.role, self.mrp, self.state
)
}
}
pub fn get_role(is_initiator: bool) -> Role {
if is_initiator { if is_initiator {
Role::Initiator Self::Responder
} else { } else {
Role::Responder Self::Initiator
}
} }
} }
pub fn get_complementary_role(is_initiator: bool) -> Role { #[derive(Debug)]
if is_initiator { pub(crate) struct ExchangeCtx {
Role::Responder pub(crate) id: ExchangeId,
} else { pub(crate) role: Role,
Role::Initiator pub(crate) mrp: ReliableMessage,
} pub(crate) state: ExchangeState,
} }
const MAX_EXCHANGES: usize = 8; #[derive(Debug, Clone)]
pub(crate) enum ExchangeState {
pub struct ExchangeMgr { Construction {
// keys: exch-id rx: *mut Packet<'static>,
exchanges: LinearMap<u16, Exchange, MAX_EXCHANGES>, notification: *const Notification,
sess_mgr: SessionMgr,
epoch: Epoch,
}
pub const MAX_MRP_ENTRIES: usize = 4;
impl ExchangeMgr {
#[inline(always)]
pub fn new(epoch: Epoch, rand: Rand) -> Self {
Self {
sess_mgr: SessionMgr::new(epoch, rand),
exchanges: LinearMap::new(),
epoch,
}
}
pub fn get_sess_mgr(&mut self) -> &mut SessionMgr {
&mut self.sess_mgr
}
pub fn _get_with_id(
exchanges: &mut LinearMap<u16, Exchange, MAX_EXCHANGES>,
exch_id: u16,
) -> Option<&mut Exchange> {
exchanges.get_mut(&exch_id)
}
pub fn get_with_id(&mut self, exch_id: u16) -> Option<&mut Exchange> {
ExchangeMgr::_get_with_id(&mut self.exchanges, exch_id)
}
fn _get(
exchanges: &mut LinearMap<u16, Exchange, MAX_EXCHANGES>,
sess_idx: usize,
id: u16,
role: Role,
create_new: bool,
) -> Result<&mut Exchange, Error> {
// I don't prefer that we scan the list twice here (once for contains_key and other)
if !exchanges.contains_key(&(id)) {
if create_new {
// If an exchange doesn't exist, create a new one
info!("Creating new exchange");
let e = Exchange::new(id, sess_idx, role);
if exchanges.insert(id, e).is_err() {
Err(ErrorCode::NoSpace)?;
}
} else {
Err(ErrorCode::NoSpace)?;
}
}
// At this point, we would either have inserted the record if 'create_new' was set
// or it existed already
if let Some(result) = exchanges.get_mut(&id) {
if result.get_role() == role && sess_idx == result.sess_idx {
Ok(result)
} else {
Err(ErrorCode::NoExchange.into())
}
} else {
error!("This should never happen");
Err(ErrorCode::NoSpace.into())
}
}
/// The Exchange Mgr receive is like a big processing function
pub fn recv(&mut self, rx: &mut Packet) -> Result<Option<ExchangeCtx>, Error> {
// Get the session
let index = self.sess_mgr.post_recv(rx)?;
let mut session = self.sess_mgr.get_session_handle(index);
// Decrypt the message
session.recv(self.epoch, rx)?;
// Get the exchange
let exch = ExchangeMgr::_get(
&mut self.exchanges,
index,
rx.proto.exch_id,
get_complementary_role(rx.proto.is_initiator()),
// We create a new exchange, only if the peer is the initiator
rx.proto.is_initiator(),
)?;
// Message Reliability Protocol
exch.mrp.recv(rx, self.epoch)?;
if exch.is_state_open() {
Ok(Some(ExchangeCtx {
exch,
sess: session,
epoch: self.epoch,
}))
} else {
// Instead of an error, we send None here, because it is likely that
// we just processed an acknowledgement that cleared the exchange
Ok(None)
}
}
pub fn send(&mut self, exch_id: u16, tx: &mut Packet) -> Result<bool, Error> {
let exchange =
ExchangeMgr::_get_with_id(&mut self.exchanges, exch_id).ok_or(ErrorCode::NoExchange)?;
let mut session = self.sess_mgr.get_session_handle(exchange.sess_idx);
exchange.send(tx, &mut session)
}
pub fn purge(&mut self) {
let mut to_purge: LinearMap<u16, (), MAX_EXCHANGES> = LinearMap::new();
for (exch_id, exchange) in self.exchanges.iter() {
if exchange.is_purgeable() {
let _ = to_purge.insert(*exch_id, ());
}
}
for (exch_id, _) in to_purge.iter() {
self.exchanges.remove(exch_id);
}
}
pub fn pending_ack(&mut self) -> Option<u16> {
self.exchanges
.iter()
.find(|(_, exchange)| exchange.mrp.is_ack_ready(self.epoch))
.map(|(exch_id, _)| *exch_id)
}
pub fn evict_session(&mut self, tx: &mut Packet) -> Result<bool, Error> {
if let Some(index) = self.sess_mgr.get_session_for_eviction() {
info!("Sessions full, vacating session with index: {}", index);
// If we enter here, we have an LRU session that needs to be reclaimed
// As per the spec, we need to send a CLOSE here
let mut session = self.sess_mgr.get_session_handle(index);
secure_channel::common::create_sc_status_report(
tx,
secure_channel::common::SCStatusCodes::CloseSession,
None,
)?;
if let Some((_, exchange)) =
self.exchanges.iter_mut().find(|(_, e)| e.sess_idx == index)
{
// Send Close_session on this exchange, and then close the session
// Should this be done for all exchanges?
error!("Sending Close Session");
exchange.send(tx, &mut session)?;
// TODO: This wouldn't actually send it out, because 'transport' isn't owned yet.
}
let remove_exchanges: heapless::Vec<u16, MAX_EXCHANGES> = self
.exchanges
.iter()
.filter_map(|(eid, e)| {
if e.sess_idx == index {
Some(*eid)
} else {
None
}
})
.collect();
info!(
"Terminating the following exchanges: {:?}",
remove_exchanges
);
for exch_id in remove_exchanges {
// Remove from exchange list
self.exchanges.remove(&exch_id);
}
self.sess_mgr.remove(index);
Ok(true)
} else {
Ok(false)
}
}
pub fn add_session(&mut self, clone_data: &CloneData) -> Result<SessionHandle, Error> {
let sess_idx = self.sess_mgr.clone_session(clone_data)?;
Ok(self.sess_mgr.get_session_handle(sess_idx))
}
}
impl fmt::Display for ExchangeMgr {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(f, "{{ Session Mgr: {},", self.sess_mgr)?;
writeln!(f, " Exchanges: [")?;
for s in &self.exchanges {
writeln!(f, "{{ {}, }},", s.1)?;
}
writeln!(f, " ]")?;
write!(f, "}}")
}
}
#[cfg(test)]
#[allow(clippy::bool_assert_comparison)]
mod tests {
use crate::{
error::ErrorCode,
transport::{
network::Address,
session::{CloneData, SessionMode},
}, },
utils::{epoch::dummy_epoch, rand::dummy_rand}, Active,
}; Acknowledge {
notification: *const Notification,
},
ExchangeSend {
tx: *const Packet<'static>,
rx: *mut Packet<'static>,
notification: *const Notification,
},
ExchangeRecv {
_tx: *const Packet<'static>,
tx_acknowledged: bool,
rx: *mut Packet<'static>,
notification: *const Notification,
},
Complete {
tx: *const Packet<'static>,
notification: *const Notification,
},
CompleteAcknowledge {
_tx: *const Packet<'static>,
notification: *const Notification,
},
Closed,
}
use super::{ExchangeMgr, Role}; pub struct ExchangeCtr<'a> {
pub(crate) exchange: Exchange<'a>,
pub(crate) construction_notification: &'a Notification,
}
#[test] impl<'a> ExchangeCtr<'a> {
fn test_purge() { pub const fn id(&self) -> &ExchangeId {
let mut mgr = ExchangeMgr::new(dummy_epoch, dummy_rand); self.exchange.id()
let _ = ExchangeMgr::_get(&mut mgr.exchanges, 1, 2, Role::Responder, true).unwrap();
let _ = ExchangeMgr::_get(&mut mgr.exchanges, 1, 3, Role::Responder, true).unwrap();
mgr.purge();
assert_eq!(
ExchangeMgr::_get(&mut mgr.exchanges, 1, 2, Role::Responder, false).is_ok(),
true
);
assert_eq!(
ExchangeMgr::_get(&mut mgr.exchanges, 1, 3, Role::Responder, false).is_ok(),
true
);
// Close e1
let e1 = ExchangeMgr::_get(&mut mgr.exchanges, 1, 2, Role::Responder, false).unwrap();
e1.close();
mgr.purge();
assert_eq!(
ExchangeMgr::_get(&mut mgr.exchanges, 1, 2, Role::Responder, false).is_ok(),
false
);
assert_eq!(
ExchangeMgr::_get(&mut mgr.exchanges, 1, 3, Role::Responder, false).is_ok(),
true
);
} }
fn get_clone_data(peer_sess_id: u16, local_sess_id: u16) -> CloneData { pub async fn get(mut self, rx: &mut Packet<'_>) -> Result<Exchange<'a>, Error> {
CloneData::new( let construction_notification = self.construction_notification;
12341234,
43211234, self.exchange.with_ctx_mut(move |exchange, ctx| {
peer_sess_id, if !matches!(ctx.state, ExchangeState::Active) {
local_sess_id, Err(ErrorCode::NoExchange)?;
Address::default(),
SessionMode::Pase,
)
} }
fn fill_sessions(mgr: &mut ExchangeMgr, count: usize) { let rx: &'static mut Packet<'static> = unsafe { core::mem::transmute(rx) };
let mut local_sess_id = 1; let notification: &'static Notification =
let mut peer_sess_id = 100; unsafe { core::mem::transmute(&exchange.notification) };
for _ in 1..count {
let clone_data = get_clone_data(peer_sess_id, local_sess_id);
match mgr.add_session(&clone_data) {
Ok(s) => assert_eq!(peer_sess_id, s.get_peer_sess_id()),
Err(e) => {
if e.code() == ErrorCode::NoSpace {
break;
} else {
panic!("Could not create sessions");
}
}
}
local_sess_id += 1;
peer_sess_id += 1;
}
}
#[cfg(feature = "std")] ctx.state = ExchangeState::Construction { rx, notification };
#[test]
/// We purposefuly overflow the sessions
/// and when the overflow happens, we confirm that
/// - The sessions are evicted in LRU
/// - The exchanges associated with those sessions are evicted too
fn test_sess_evict() {
use crate::transport::packet::{Packet, MAX_TX_BUF_SIZE};
use crate::transport::session::MAX_SESSIONS;
let mut mgr = ExchangeMgr::new(crate::utils::epoch::sys_epoch, dummy_rand); construction_notification.signal(());
fill_sessions(&mut mgr, MAX_SESSIONS + 1); Ok(())
// Sessions are now full from local session id 1 to 16 })?;
// Create exchanges for sessions 2 (i.e. session index 1) and 3 (session index 2) self.exchange.notification.wait().await;
// Exchange IDs are 20 and 30 respectively
let _ = ExchangeMgr::_get(&mut mgr.exchanges, 1, 20, Role::Responder, true).unwrap();
let _ = ExchangeMgr::_get(&mut mgr.exchanges, 2, 30, Role::Responder, true).unwrap();
// Confirm that session ids 1 to MAX_SESSIONS exists Ok(self.exchange)
for i in 1..(MAX_SESSIONS + 1) { }
assert_eq!(mgr.sess_mgr.get_with_id(i as u16).is_none(), false); }
}
// Confirm that the exchanges are around #[derive(Debug, Clone, Eq, PartialEq)]
assert_eq!(mgr.get_with_id(20).is_none(), false); pub struct ExchangeId {
assert_eq!(mgr.get_with_id(30).is_none(), false); pub id: u16,
let mut old_local_sess_id = 1; pub session_id: SessionId,
let mut new_local_sess_id = 100; }
let mut new_peer_sess_id = 200;
impl ExchangeId {
for i in 1..(MAX_SESSIONS + 1) { pub fn load(rx: &Packet) -> Self {
// Now purposefully overflow the sessions by adding another session Self {
let result = mgr.add_session(&get_clone_data(new_peer_sess_id, new_local_sess_id)); id: rx.proto.exch_id,
assert!(matches!( session_id: SessionId::load(rx),
result.map_err(|e| e.code()), }
Err(ErrorCode::NoSpace) }
)); }
#[derive(Debug, Clone, Eq, PartialEq)]
let mut buf = [0; MAX_TX_BUF_SIZE]; pub struct SessionId {
let tx = &mut Packet::new_tx(&mut buf); pub id: u16,
let evicted = mgr.evict_session(tx).unwrap(); pub peer_addr: Address,
assert!(evicted); pub peer_nodeid: Option<u64>,
pub is_encrypted: bool,
let session = mgr }
.add_session(&get_clone_data(new_peer_sess_id, new_local_sess_id))
.unwrap(); impl SessionId {
assert_eq!(session.get_peer_sess_id(), new_peer_sess_id); pub fn load(rx: &Packet) -> Self {
Self {
// This should have evicted session with local sess_id id: rx.plain.sess_id,
assert_eq!(mgr.sess_mgr.get_with_id(old_local_sess_id).is_none(), true); peer_addr: rx.peer,
peer_nodeid: rx.plain.get_src_u64(),
new_local_sess_id += 1; is_encrypted: rx.plain.is_encrypted(),
new_peer_sess_id += 1; }
old_local_sess_id += 1; }
}
match i { pub struct Exchange<'a> {
1 => { pub(crate) id: ExchangeId,
// Both exchanges should exist pub(crate) transport: &'a Transport<'a>,
assert_eq!(mgr.get_with_id(20).is_none(), false); pub(crate) notification: Notification,
assert_eq!(mgr.get_with_id(30).is_none(), false); }
}
2 => { impl<'a> Exchange<'a> {
// Exchange 20 would have been evicted pub const fn id(&self) -> &ExchangeId {
assert_eq!(mgr.get_with_id(20).is_none(), true); &self.id
assert_eq!(mgr.get_with_id(30).is_none(), false); }
}
3 => { pub fn matter(&self) -> &Matter<'a> {
// Exchange 20 and 30 would have been evicted self.transport.matter()
assert_eq!(mgr.get_with_id(20).is_none(), true); }
assert_eq!(mgr.get_with_id(30).is_none(), true);
} pub fn transport(&self) -> &Transport<'a> {
_ => {} self.transport
} }
}
// println!("Session mgr {}", mgr.sess_mgr); pub fn accessor(&self) -> Result<Accessor<'a>, Error> {
self.with_session(|sess| {
Ok(Accessor::for_session(
sess,
&self.transport.matter().acl_mgr,
))
})
}
pub fn with_session_mut<F, T>(&self, f: F) -> Result<T, Error>
where
F: FnOnce(&mut Session) -> Result<T, Error>,
{
self.with_ctx(|_self, ctx| {
let mut session_mgr = _self.transport.session_mgr.borrow_mut();
let sess_index = session_mgr
.get(
ctx.id.session_id.id,
ctx.id.session_id.peer_addr,
ctx.id.session_id.peer_nodeid,
ctx.id.session_id.is_encrypted,
)
.ok_or(ErrorCode::NoSession)?;
f(session_mgr.mut_by_index(sess_index).unwrap())
})
}
pub fn with_session<F, T>(&self, f: F) -> Result<T, Error>
where
F: FnOnce(&Session) -> Result<T, Error>,
{
self.with_session_mut(|sess| f(sess))
}
pub fn with_session_mgr_mut<F, T>(&self, f: F) -> Result<T, Error>
where
F: FnOnce(&mut SessionMgr) -> Result<T, Error>,
{
let mut session_mgr = self.transport.session_mgr.borrow_mut();
f(&mut session_mgr)
}
pub async fn initiate(&mut self, fabric_id: u64, node_id: u64) -> Result<Exchange<'a>, Error> {
self.transport.initiate(fabric_id, node_id).await
}
pub async fn acknowledge(&mut self) -> Result<(), Error> {
let wait = self.with_ctx_mut(|_self, ctx| {
if !matches!(ctx.state, ExchangeState::Active) {
Err(ErrorCode::NoExchange)?;
}
if ctx.mrp.is_empty() {
Ok(false)
} else {
ctx.state = ExchangeState::Acknowledge {
notification: &_self.notification as *const _,
};
_self.transport.send_notification.signal(());
Ok(true)
}
})?;
if wait {
self.notification.wait().await;
}
Ok(())
}
pub async fn exchange(&mut self, tx: &Packet<'_>, rx: &mut Packet<'_>) -> Result<(), Error> {
let tx: &Packet<'static> = unsafe { core::mem::transmute(tx) };
let rx: &mut Packet<'static> = unsafe { core::mem::transmute(rx) };
self.with_ctx_mut(|_self, ctx| {
if !matches!(ctx.state, ExchangeState::Active) {
Err(ErrorCode::NoExchange)?;
}
ctx.state = ExchangeState::ExchangeSend {
tx: tx as *const _,
rx: rx as *mut _,
notification: &_self.notification as *const _,
};
_self.transport.send_notification.signal(());
Ok(())
})?;
self.notification.wait().await;
Ok(())
}
pub async fn complete(mut self, tx: &Packet<'_>) -> Result<(), Error> {
self.send_complete(tx).await
}
pub async fn send_complete(&mut self, tx: &Packet<'_>) -> Result<(), Error> {
let tx: &Packet<'static> = unsafe { core::mem::transmute(tx) };
self.with_ctx_mut(|_self, ctx| {
if !matches!(ctx.state, ExchangeState::Active) {
Err(ErrorCode::NoExchange)?;
}
ctx.state = ExchangeState::Complete {
tx: tx as *const _,
notification: &_self.notification as *const _,
};
_self.transport.send_notification.signal(());
Ok(())
})?;
self.notification.wait().await;
Ok(())
}
fn with_ctx<F, T>(&self, f: F) -> Result<T, Error>
where
F: FnOnce(&Self, &ExchangeCtx) -> Result<T, Error>,
{
let mut exchanges = self.transport.exchanges.borrow_mut();
let exchange = Transport::get(&mut exchanges, &self.id).ok_or(ErrorCode::NoExchange)?; // TODO
f(self, exchange)
}
fn with_ctx_mut<F, T>(&mut self, f: F) -> Result<T, Error>
where
F: FnOnce(&mut Self, &mut ExchangeCtx) -> Result<T, Error>,
{
let mut exchanges = self.transport.exchanges.borrow_mut();
let exchange = Transport::get(&mut exchanges, &self.id).ok_or(ErrorCode::NoExchange)?; // TODO
f(self, exchange)
}
}
impl<'a> Drop for Exchange<'a> {
fn drop(&mut self) {
let _ = self.with_ctx_mut(|_self, ctx| {
ctx.state = ExchangeState::Closed;
_self.transport.send_notification.signal(());
Ok(())
});
} }
} }

View file

@ -23,7 +23,7 @@ pub mod network;
pub mod packet; pub mod packet;
pub mod pipe; pub mod pipe;
pub mod plain_hdr; pub mod plain_hdr;
pub mod proto_ctx;
pub mod proto_hdr; pub mod proto_hdr;
pub mod runner;
pub mod session; pub mod session;
pub mod udp; pub mod udp;

View file

@ -1,41 +0,0 @@
/*
*
* Copyright (c) 2020-2022 Project CHIP Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
use crate::error::Error;
use super::exchange::ExchangeCtx;
use super::packet::Packet;
/// This is the context in which a receive packet is being processed
pub struct ProtoCtx<'a, 'b> {
/// This is the exchange context, that includes the exchange and the session
pub exch_ctx: ExchangeCtx<'a>,
/// This is the received buffer for this transaction
pub rx: &'a Packet<'b>,
/// This is the transmit buffer for this transaction
pub tx: &'a mut Packet<'b>,
}
impl<'a, 'b> ProtoCtx<'a, 'b> {
pub fn new(exch_ctx: ExchangeCtx<'a>, rx: &'a Packet<'b>, tx: &'a mut Packet<'b>) -> Self {
Self { exch_ctx, rx, tx }
}
pub fn send(&mut self) -> Result<bool, Error> {
self.exch_ctx.exch.send(self.tx, &mut self.exch_ctx.sess)
}
}

View file

@ -0,0 +1,392 @@
/*
*
* Copyright (c) 2020-2022 Project CHIP Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
use core::{mem::MaybeUninit, pin::pin};
use crate::{
alloc,
data_model::{core::DataModel, objects::DataModelHandler},
interaction_model::core::PROTO_ID_INTERACTION_MODEL,
transport::network::{Address, IpAddr, Ipv6Addr, SocketAddr},
CommissioningData, Matter,
};
use embassy_futures::select::{select, select3, select_slice, Either};
use embassy_sync::{blocking_mutex::raw::NoopRawMutex, channel::Channel};
use log::{error, info, warn};
use crate::{
error::Error,
secure_channel::{common::PROTO_ID_SECURE_CHANNEL, core::SecureChannel},
transport::packet::{Packet, MAX_RX_BUF_SIZE},
utils::select::EitherUnwrap,
};
use super::{
core::Transport,
exchange::{ExchangeCtr, Notification, MAX_EXCHANGES},
packet::{MAX_RX_STATUS_BUF_SIZE, MAX_TX_BUF_SIZE},
pipe::{Chunk, Pipe},
udp::UdpListener,
};
pub type TxBuf = MaybeUninit<[u8; MAX_TX_BUF_SIZE]>;
pub type RxBuf = MaybeUninit<[u8; MAX_RX_BUF_SIZE]>;
type SxBuf = MaybeUninit<[u8; MAX_RX_STATUS_BUF_SIZE]>;
struct PacketPools {
tx: [TxBuf; MAX_EXCHANGES],
rx: [RxBuf; MAX_EXCHANGES],
sx: [SxBuf; MAX_EXCHANGES],
}
impl PacketPools {
const TX_ELEM: TxBuf = MaybeUninit::uninit();
const RX_ELEM: RxBuf = MaybeUninit::uninit();
const SX_ELEM: SxBuf = MaybeUninit::uninit();
const TX_INIT: [TxBuf; MAX_EXCHANGES] = [Self::TX_ELEM; MAX_EXCHANGES];
const RX_INIT: [RxBuf; MAX_EXCHANGES] = [Self::RX_ELEM; MAX_EXCHANGES];
const SX_INIT: [SxBuf; MAX_EXCHANGES] = [Self::SX_ELEM; MAX_EXCHANGES];
#[inline(always)]
pub const fn new() -> Self {
Self {
tx: Self::TX_INIT,
rx: Self::RX_INIT,
sx: Self::SX_INIT,
}
}
}
/// This struct implements an executor-agnostic option to run the Matter transport stack end-to-end.
///
/// Since it is not possible to use executor tasks spawning in an executor-agnostic way (yet),
/// the async loops are arranged as one giant future. Therefore, the cost is a slightly slower execution
/// due to the generated future being relatively big and deeply nested.
///
/// Users are free to implement their own async execution loop, by utilizing the `Transport`
/// struct directly with their async executor of choice.
pub struct TransportRunner<'a> {
transport: Transport<'a>,
pools: PacketPools,
}
impl<'a> TransportRunner<'a> {
#[inline(always)]
pub fn new(matter: &'a Matter<'a>) -> Self {
Self::wrap(Transport::new(matter))
}
#[inline(always)]
pub const fn wrap(transport: Transport<'a>) -> Self {
Self {
transport,
pools: PacketPools::new(),
}
}
pub fn transport(&self) -> &Transport {
&self.transport
}
pub async fn run_udp<H>(
&mut self,
tx_buf: &mut TxBuf,
rx_buf: &mut RxBuf,
dev_comm: CommissioningData,
handler: &H,
) -> Result<(), Error>
where
H: DataModelHandler,
{
let udp = UdpListener::new(SocketAddr::new(
IpAddr::V6(Ipv6Addr::UNSPECIFIED),
self.transport.matter().port,
))
.await?;
let tx_pipe = Pipe::new(unsafe { tx_buf.assume_init_mut() });
let rx_pipe = Pipe::new(unsafe { rx_buf.assume_init_mut() });
let tx_pipe = &tx_pipe;
let rx_pipe = &rx_pipe;
let udp = &udp;
let mut tx = pin!(async move {
loop {
{
let mut data = tx_pipe.data.lock().await;
if let Some(chunk) = data.chunk {
udp.send(chunk.addr.unwrap_udp(), &data.buf[chunk.start..chunk.end])
.await?;
data.chunk = None;
tx_pipe.data_consumed_notification.signal(());
}
}
tx_pipe.data_supplied_notification.wait().await;
}
});
let mut rx = pin!(async move {
loop {
{
let mut data = rx_pipe.data.lock().await;
if data.chunk.is_none() {
let (len, addr) = udp.recv(data.buf).await?;
data.chunk = Some(Chunk {
start: 0,
end: len,
addr: Address::Udp(addr),
});
rx_pipe.data_supplied_notification.signal(());
}
}
rx_pipe.data_consumed_notification.wait().await;
}
});
let mut run = pin!(async move { self.run(tx_pipe, rx_pipe, dev_comm, handler).await });
select3(&mut tx, &mut rx, &mut run).await.unwrap()
}
pub async fn run<H>(
&mut self,
tx_pipe: &Pipe<'_>,
rx_pipe: &Pipe<'_>,
dev_comm: CommissioningData,
handler: &H,
) -> Result<(), Error>
where
H: DataModelHandler,
{
info!("Running Matter transport");
let buf = unsafe { self.pools.rx[0].assume_init_mut() };
if self.transport.matter().start_comissioning(dev_comm, buf)? {
info!("Comissioning started");
}
let construction_notification = Notification::new();
let mut rx = pin!(Self::handle_rx(
&self.transport,
&mut self.pools,
rx_pipe,
&construction_notification,
handler
));
let mut tx = pin!(Self::handle_tx(&self.transport, tx_pipe));
select(&mut rx, &mut tx).await.unwrap()
}
async fn handle_rx<H>(
transport: &Transport<'_>,
pools: &mut PacketPools,
rx_pipe: &Pipe<'_>,
construction_notification: &Notification,
handler: &H,
) -> Result<(), Error>
where
H: DataModelHandler,
{
info!("Creating queue for {} exchanges", 1);
let channel = Channel::<NoopRawMutex, _, 1>::new();
info!("Creating {} handlers", MAX_EXCHANGES);
let mut handlers = heapless::Vec::<_, MAX_EXCHANGES>::new();
info!("Handlers size: {}", core::mem::size_of_val(&handlers));
let pools = &mut *pools as *mut _;
for index in 0..MAX_EXCHANGES {
let channel = &channel;
let handler_id = index;
handlers
.push(async move {
loop {
let exchange_ctr: ExchangeCtr<'_> = channel.recv().await;
info!(
"Handler {}: Got exchange {:?}",
handler_id,
exchange_ctr.id()
);
let result = Self::handle_exchange(
transport,
pools,
handler_id,
exchange_ctr,
handler,
)
.await;
if let Err(err) = result {
warn!(
"Handler {}: Exchange closed because of error: {:?}",
handler_id, err
);
} else {
info!("Handler {}: Exchange completed", handler_id);
}
}
})
.map_err(|_| ())
.unwrap();
}
let mut rx = pin!(async {
loop {
info!("Transport: waiting for incoming packets");
{
let mut data = rx_pipe.data.lock().await;
if let Some(chunk) = data.chunk {
let mut rx = alloc!(Packet::new_rx(&mut data.buf[chunk.start..chunk.end]));
rx.peer = chunk.addr;
if let Some(exchange_ctr) =
transport.process_rx(construction_notification, &mut rx)?
{
let exchange_id = exchange_ctr.id().clone();
info!("Transport: got new exchange: {:?}", exchange_id);
channel.send(exchange_ctr).await;
info!("Transport: exchange sent");
transport
.wait_construction(construction_notification, &rx, &exchange_id)
.await?;
info!("Transport: exchange started");
}
data.chunk = None;
rx_pipe.data_consumed_notification.signal(());
}
}
rx_pipe.data_supplied_notification.wait().await
}
#[allow(unreachable_code)]
Ok::<_, Error>(())
});
let result = select(&mut rx, select_slice(&mut handlers)).await;
if let Either::First(result) = result {
if let Err(e) = &result {
error!("Exitting RX loop due to an error: {:?}", e);
}
result?;
}
Ok(())
}
async fn handle_tx(transport: &Transport<'_>, tx_pipe: &Pipe<'_>) -> Result<(), Error> {
loop {
loop {
{
let mut data = tx_pipe.data.lock().await;
if data.chunk.is_none() {
let mut tx = alloc!(Packet::new_tx(data.buf));
if transport.pull_tx(&mut tx).await? {
data.chunk = Some(Chunk {
start: tx.get_writebuf()?.get_start(),
end: tx.get_writebuf()?.get_tail(),
addr: tx.peer,
});
tx_pipe.data_supplied_notification.signal(());
} else {
break;
}
}
}
tx_pipe.data_consumed_notification.wait().await;
}
transport.wait_tx().await?;
}
}
#[cfg_attr(feature = "nightly", allow(clippy::await_holding_refcell_ref))] // Fine because of the async mutex
async fn handle_exchange<H>(
transport: &Transport<'_>,
pools: *mut PacketPools,
handler_id: usize,
exchange_ctr: ExchangeCtr<'_>,
handler: &H,
) -> Result<(), Error>
where
H: DataModelHandler,
{
let pools = unsafe { pools.as_mut() }.unwrap();
let tx_buf = unsafe { pools.tx[handler_id].assume_init_mut() };
let rx_buf = unsafe { pools.rx[handler_id].assume_init_mut() };
let rx_status_buf = unsafe { pools.sx[handler_id].assume_init_mut() };
let mut rx = alloc!(Packet::new_rx(rx_buf.as_mut()));
let mut tx = alloc!(Packet::new_tx(tx_buf.as_mut()));
let mut exchange = alloc!(exchange_ctr.get(&mut rx).await?);
match rx.get_proto_id() {
PROTO_ID_SECURE_CHANNEL => {
let sc = SecureChannel::new(transport.matter());
sc.handle(&mut exchange, &mut rx, &mut tx).await?;
transport.notify_changed();
}
PROTO_ID_INTERACTION_MODEL => {
let dm = DataModel::new(handler);
let mut rx_status = alloc!(Packet::new_rx(rx_status_buf));
dm.handle(&mut exchange, &mut rx, &mut tx, &mut rx_status)
.await?;
transport.notify_changed();
}
other => {
error!("Unknown Proto-ID: {}", other);
}
}
Ok(())
}
}

View file

@ -15,10 +15,9 @@
* limitations under the License. * limitations under the License.
*/ */
use std::{ use core::cell::Cell;
convert::TryInto, use core::convert::TryInto;
sync::{Arc, Mutex, Once}, use std::sync::{Arc, Mutex, Once};
};
use matter::{ use matter::{
attribute_enum, command_enum, attribute_enum, command_enum,
@ -28,11 +27,9 @@ use matter::{
Quality, ATTRIBUTE_LIST, FEATURE_MAP, Quality, ATTRIBUTE_LIST, FEATURE_MAP,
}, },
error::{Error, ErrorCode}, error::{Error, ErrorCode},
interaction_model::{ interaction_model::messages::ib::{attr_list_write, ListOperation},
core::Transaction,
messages::ib::{attr_list_write, ListOperation},
},
tlv::{TLVElement, TagType}, tlv::{TLVElement, TagType},
transport::exchange::Exchange,
utils::rand::Rand, utils::rand::Rand,
}; };
use num_derive::FromPrimitive; use num_derive::FromPrimitive;
@ -132,10 +129,10 @@ pub const WRITE_LIST_MAX: usize = 5;
pub struct EchoCluster { pub struct EchoCluster {
pub data_ver: Dataver, pub data_ver: Dataver,
pub multiplier: u8, pub multiplier: u8,
pub att1: u16, pub att1: Cell<u16>,
pub att2: u16, pub att2: Cell<u16>,
pub att_write: u16, pub att_write: Cell<u16>,
pub att_custom: u32, pub att_custom: Cell<u32>,
} }
impl EchoCluster { impl EchoCluster {
@ -143,10 +140,10 @@ impl EchoCluster {
Self { Self {
data_ver: Dataver::new(rand), data_ver: Dataver::new(rand),
multiplier, multiplier,
att1: 0x1234, att1: Cell::new(0x1234),
att2: 0x5678, att2: Cell::new(0x5678),
att_write: ATTR_WRITE_DEFAULT_VALUE, att_write: Cell::new(ATTR_WRITE_DEFAULT_VALUE),
att_custom: ATTR_CUSTOM_VALUE, att_custom: Cell::new(ATTR_CUSTOM_VALUE),
} }
} }
@ -179,14 +176,14 @@ impl EchoCluster {
} }
} }
pub fn write(&mut self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { pub fn write(&self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> {
let data = data.with_dataver(self.data_ver.get())?; let data = data.with_dataver(self.data_ver.get())?;
match attr.attr_id.try_into()? { match attr.attr_id.try_into()? {
Attributes::Att1(codec) => self.att1 = codec.decode(data)?, Attributes::Att1(codec) => self.att1.set(codec.decode(data)?),
Attributes::Att2(codec) => self.att2 = codec.decode(data)?, Attributes::Att2(codec) => self.att2.set(codec.decode(data)?),
Attributes::AttWrite(codec) => self.att_write = codec.decode(data)?, Attributes::AttWrite(codec) => self.att_write.set(codec.decode(data)?),
Attributes::AttCustom(codec) => self.att_custom = codec.decode(data)?, Attributes::AttCustom(codec) => self.att_custom.set(codec.decode(data)?),
Attributes::AttWriteList(_) => { Attributes::AttWriteList(_) => {
attr_list_write(attr, data, |op, data| self.write_attr_list(&op, data))? attr_list_write(attr, data, |op, data| self.write_attr_list(&op, data))?
} }
@ -198,8 +195,8 @@ impl EchoCluster {
} }
pub fn invoke( pub fn invoke(
&mut self, &self,
_transaction: &mut Transaction, _exchange: &Exchange,
cmd: &CmdDetails, cmd: &CmdDetails,
data: &TLVElement, data: &TLVElement,
encoder: CmdDataEncoder, encoder: CmdDataEncoder,
@ -222,7 +219,7 @@ impl EchoCluster {
} }
} }
fn write_attr_list(&mut self, op: &ListOperation, data: &TLVElement) -> Result<(), Error> { fn write_attr_list(&self, op: &ListOperation, data: &TLVElement) -> Result<(), Error> {
let tc_handle = TestChecker::get().unwrap(); let tc_handle = TestChecker::get().unwrap();
let mut tc = tc_handle.lock().unwrap(); let mut tc = tc_handle.lock().unwrap();
match op { match op {
@ -272,18 +269,18 @@ impl Handler for EchoCluster {
EchoCluster::read(self, attr, encoder) EchoCluster::read(self, attr, encoder)
} }
fn write(&mut self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { fn write(&self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> {
EchoCluster::write(self, attr, data) EchoCluster::write(self, attr, data)
} }
fn invoke( fn invoke(
&mut self, &self,
transaction: &mut Transaction, exchange: &Exchange,
cmd: &CmdDetails, cmd: &CmdDetails,
data: &TLVElement, data: &TLVElement,
encoder: CmdDataEncoder, encoder: CmdDataEncoder,
) -> Result<(), Error> { ) -> Result<(), Error> {
EchoCluster::invoke(self, transaction, cmd, data, encoder) EchoCluster::invoke(self, exchange, cmd, data, encoder)
} }
} }

View file

@ -1,8 +1,6 @@
use core::time;
use std::thread;
use log::{info, warn}; use log::{info, warn};
use matter::{ use matter::{
error::ErrorCode,
interaction_model::{ interaction_model::{
core::{IMStatusCode, OpCode}, core::{IMStatusCode, OpCode},
messages::{ messages::{
@ -14,17 +12,12 @@ use matter::{
}, },
}, },
tlv::{self, FromTLV, TLVArray, ToTLV}, tlv::{self, FromTLV, TLVArray, ToTLV},
transport::{
exchange::{self, Exchange},
session::NocCatIds,
},
Matter,
}; };
use super::{ use super::{
attributes::assert_attr_report, attributes::assert_attr_report,
commands::{assert_inv_response, ExpectedInvResp}, commands::{assert_inv_response, ExpectedInvResp},
im_engine::{ImEngine, ImInput, IM_ENGINE_PEER_ID}, im_engine::{ImEngine, ImEngineHandler, ImInput, ImOutput},
}; };
pub enum WriteResponse<'a> { pub enum WriteResponse<'a> {
@ -38,72 +31,71 @@ pub enum TimedInvResponse<'a> {
} }
impl<'a> ImEngine<'a> { impl<'a> ImEngine<'a> {
pub fn read_reqs(input: &[AttrPath], expected: &[AttrResp]) {
let im = ImEngine::new_default();
im.add_default_acl();
im.handle_read_reqs(&im.handler(), input, expected);
}
// Helper for handling Read Req sequences for this file // Helper for handling Read Req sequences for this file
pub fn handle_read_reqs( pub fn handle_read_reqs(
&mut self, &self,
peer_node_id: u64, handler: &ImEngineHandler,
input: &[AttrPath], input: &[AttrPath],
expected: &[AttrResp], expected: &[AttrResp],
) { ) {
let mut out_buf = [0u8; 400]; let mut out = heapless::Vec::<_, 1>::new();
let received = self.gen_read_reqs_output(peer_node_id, input, None, &mut out_buf); let received = self.gen_read_reqs_output(handler, input, None, &mut out);
assert_attr_report(&received, expected) assert_attr_report(&received, expected)
} }
pub fn new_with_read_reqs( pub fn gen_read_reqs_output<'c, const N: usize>(
matter: &'a Matter<'a>, &self,
handler: &ImEngineHandler,
input: &[AttrPath], input: &[AttrPath],
expected: &[AttrResp], dataver_filters: Option<TLVArray<'_, DataVersionFilter>>,
) -> Self { out: &'c mut heapless::Vec<ImOutput, N>,
let mut im = Self::new(matter); ) -> ReportDataMsg<'c> {
let mut out_buf = [0u8; 400];
let received = im.gen_read_reqs_output(IM_ENGINE_PEER_ID, input, None, &mut out_buf);
assert_attr_report(&received, expected);
im
}
pub fn gen_read_reqs_output<'b>(
&mut self,
peer_node_id: u64,
input: &[AttrPath],
dataver_filters: Option<TLVArray<'b, DataVersionFilter>>,
out_buf: &'b mut [u8],
) -> ReportDataMsg<'b> {
let mut read_req = ReadReq::new(true).set_attr_requests(input); let mut read_req = ReadReq::new(true).set_attr_requests(input);
read_req.dataver_filters = dataver_filters; read_req.dataver_filters = dataver_filters;
let mut input = ImInput::new(OpCode::ReadRequest, &read_req); let input = ImInput::new(OpCode::ReadRequest, &read_req);
input.set_peer_node_id(peer_node_id);
let (_, out_buf) = self.process(&input, out_buf); self.process(handler, &[&input], out).unwrap();
tlv::print_tlv_list(out_buf); for o in &*out {
let root = tlv::get_root_node_struct(out_buf).unwrap(); tlv::print_tlv_list(&o.data);
}
let root = tlv::get_root_node_struct(&out[0].data).unwrap();
ReportDataMsg::from_tlv(&root).unwrap() ReportDataMsg::from_tlv(&root).unwrap()
} }
pub fn write_reqs(input: &[AttrData], expected: &[AttrStatus]) {
let im = ImEngine::new_default();
im.add_default_acl();
im.handle_write_reqs(&im.handler(), input, expected);
}
pub fn handle_write_reqs( pub fn handle_write_reqs(
&mut self, &self,
peer_node_id: u64, handler: &ImEngineHandler,
peer_cat_ids: Option<&NocCatIds>,
input: &[AttrData], input: &[AttrData],
expected: &[AttrStatus], expected: &[AttrStatus],
) { ) {
let mut out_buf = [0u8; 400];
let write_req = WriteReq::new(false, input); let write_req = WriteReq::new(false, input);
let mut input = ImInput::new(OpCode::WriteRequest, &write_req); let input = ImInput::new(OpCode::WriteRequest, &write_req);
input.set_peer_node_id(peer_node_id); let mut out = heapless::Vec::<_, 1>::new();
if let Some(cat_ids) = peer_cat_ids { self.process(handler, &[&input], &mut out).unwrap();
input.set_cat_ids(cat_ids);
for o in &out {
tlv::print_tlv_list(&o.data);
} }
let (_, out_buf) = self.process(&input, &mut out_buf); let root = tlv::get_root_node_struct(&out[0].data).unwrap();
tlv::print_tlv_list(out_buf);
let root = tlv::get_root_node_struct(out_buf).unwrap();
let mut index = 0; let mut index = 0;
let response_iter = root let response_iter = root
@ -124,194 +116,184 @@ impl<'a> ImEngine<'a> {
assert_eq!(index, expected.len()); assert_eq!(index, expected.len());
} }
pub fn new_with_write_reqs( pub fn commands(input: &[CmdData], expected: &[ExpectedInvResp]) {
matter: &'a Matter<'a>, let im = ImEngine::new_default();
input: &[AttrData],
expected: &[AttrStatus],
) -> Self {
let mut im = Self::new(matter);
im.handle_write_reqs(IM_ENGINE_PEER_ID, None, input, expected); im.add_default_acl();
im.handle_commands(&im.handler(), input, expected)
im
} }
// Helper for handling Invoke Command sequences // Helper for handling Invoke Command sequences
pub fn handle_commands( pub fn handle_commands(
&mut self, &self,
peer_node_id: u64, handler: &ImEngineHandler,
input: &[CmdData], input: &[CmdData],
expected: &[ExpectedInvResp], expected: &[ExpectedInvResp],
) { ) {
let mut out_buf = [0u8; 400];
let req = InvReq { let req = InvReq {
suppress_response: Some(false), suppress_response: Some(false),
timed_request: Some(false), timed_request: Some(false),
inv_requests: Some(TLVArray::Slice(input)), inv_requests: Some(TLVArray::Slice(input)),
}; };
let mut input = ImInput::new(OpCode::InvokeRequest, &req); let input = ImInput::new(OpCode::InvokeRequest, &req);
input.set_peer_node_id(peer_node_id);
let (_, out_buf) = self.process(&input, &mut out_buf); let mut out = heapless::Vec::<_, 1>::new();
tlv::print_tlv_list(out_buf); self.process(handler, &[&input], &mut out).unwrap();
let root = tlv::get_root_node_struct(out_buf).unwrap();
for o in &out {
tlv::print_tlv_list(&o.data);
}
let root = tlv::get_root_node_struct(&out[0].data).unwrap();
let resp = msg::InvResp::from_tlv(&root).unwrap(); let resp = msg::InvResp::from_tlv(&root).unwrap();
assert_inv_response(&resp, expected) assert_inv_response(&resp, expected)
} }
pub fn new_with_commands( fn gen_timed_reqs_output<const N: usize>(
matter: &'a Matter<'a>, &self,
input: &[CmdData], handler: &ImEngineHandler,
expected: &[ExpectedInvResp],
) -> Self {
let mut im = ImEngine::new(matter);
im.handle_commands(IM_ENGINE_PEER_ID, input, expected);
im
}
fn handle_timed_reqs<'b>(
&mut self,
opcode: OpCode, opcode: OpCode,
request: &dyn ToTLV, request: &dyn ToTLV,
timeout: u16, timeout: u16,
delay: u16, delay: u16,
output: &'b mut [u8], out: &mut heapless::Vec<ImOutput, N>,
) -> (u8, &'b [u8]) { ) {
// Use the same exchange for all parts of the transaction let mut inp = heapless::Vec::<_, 2>::new();
self.exch = Some(Exchange::new(1, 0, exchange::Role::Responder));
let timed_req = TimedReq { timeout };
let im_input = ImInput::new_delayed(OpCode::TimedRequest, &timed_req, Some(delay));
if timeout != 0 { if timeout != 0 {
// Send Timed Req // Send Timed Req
let mut tmp_buf = [0u8; 400]; inp.push(&im_input).map_err(|_| ErrorCode::NoSpace).unwrap();
let timed_req = TimedReq { timeout };
let im_input = ImInput::new(OpCode::TimedRequest, &timed_req);
let (_, out_buf) = self.process(&im_input, &mut tmp_buf);
tlv::print_tlv_list(out_buf);
} else { } else {
warn!("Skipping timed request"); warn!("Skipping timed request");
} }
// Process any delays
let delay = time::Duration::from_millis(delay.into());
thread::sleep(delay);
// Send Write Req // Send Write Req
let input = ImInput::new(opcode, request); let input = ImInput::new(opcode, request);
let (resp_opcode, output) = self.process(&input, output); inp.push(&input).map_err(|_| ErrorCode::NoSpace).unwrap();
(resp_opcode, output)
self.process(handler, &inp, out).unwrap();
drop(inp);
for o in out {
tlv::print_tlv_list(&o.data);
}
} }
// Helper for handling Write Attribute sequences pub fn timed_write_reqs(
pub fn handle_timed_write_reqs(
&mut self,
input: &[AttrData], input: &[AttrData],
expected: &WriteResponse, expected: &WriteResponse,
timeout: u16, timeout: u16,
delay: u16, delay: u16,
) { ) {
let mut out_buf = [0u8; 400]; let im = ImEngine::new_default();
im.add_default_acl();
im.handle_timed_write_reqs(&im.handler(), input, expected, timeout, delay);
}
// Helper for handling Write Attribute sequences
pub fn handle_timed_write_reqs(
&self,
handler: &ImEngineHandler,
input: &[AttrData],
expected: &WriteResponse,
timeout: u16,
delay: u16,
) {
let mut out = heapless::Vec::<_, 2>::new();
let write_req = WriteReq::new(false, input); let write_req = WriteReq::new(false, input);
let (resp_opcode, out_buf) = self.handle_timed_reqs( self.gen_timed_reqs_output(
handler,
OpCode::WriteRequest, OpCode::WriteRequest,
&write_req, &write_req,
timeout, timeout,
delay, delay,
&mut out_buf, &mut out,
); );
tlv::print_tlv_list(out_buf);
let root = tlv::get_root_node_struct(out_buf).unwrap(); let out = &out[out.len() - 1];
let root = tlv::get_root_node_struct(&out.data).unwrap();
match expected { match expected {
WriteResponse::TransactionSuccess(t) => { WriteResponse::TransactionSuccess(t) => {
assert_eq!( assert_eq!(out.action, OpCode::WriteResponse);
num::FromPrimitive::from_u8(resp_opcode),
Some(OpCode::WriteResponse)
);
let resp = WriteResp::from_tlv(&root).unwrap(); let resp = WriteResp::from_tlv(&root).unwrap();
assert_eq!(resp.write_responses, t); assert_eq!(resp.write_responses, t);
} }
WriteResponse::TransactionError => { WriteResponse::TransactionError => {
assert_eq!( assert_eq!(out.action, OpCode::StatusResponse);
num::FromPrimitive::from_u8(resp_opcode),
Some(OpCode::StatusResponse)
);
let status_resp = StatusResp::from_tlv(&root).unwrap(); let status_resp = StatusResp::from_tlv(&root).unwrap();
assert_eq!(status_resp.status, IMStatusCode::Timeout); assert_eq!(status_resp.status, IMStatusCode::Timeout);
} }
} }
} }
pub fn new_with_timed_write_reqs( pub fn timed_commands(
matter: &'a Matter<'a>,
input: &[AttrData],
expected: &WriteResponse,
timeout: u16,
delay: u16,
) -> Self {
let mut im = ImEngine::new(matter);
im.handle_timed_write_reqs(input, expected, timeout, delay);
im
}
// Helper for handling Invoke Command sequences
pub fn handle_timed_commands(
&mut self,
input: &[CmdData], input: &[CmdData],
expected: &TimedInvResponse, expected: &TimedInvResponse,
timeout: u16, timeout: u16,
delay: u16, delay: u16,
set_timed_request: bool, set_timed_request: bool,
) { ) {
let mut out_buf = [0u8; 400]; let im = ImEngine::new_default();
im.add_default_acl();
im.handle_timed_commands(
&im.handler(),
input,
expected,
timeout,
delay,
set_timed_request,
);
}
// Helper for handling Invoke Command sequences
pub fn handle_timed_commands(
&self,
handler: &ImEngineHandler,
input: &[CmdData],
expected: &TimedInvResponse,
timeout: u16,
delay: u16,
set_timed_request: bool,
) {
let mut out = heapless::Vec::<_, 2>::new();
let req = InvReq { let req = InvReq {
suppress_response: Some(false), suppress_response: Some(false),
timed_request: Some(set_timed_request), timed_request: Some(set_timed_request),
inv_requests: Some(TLVArray::Slice(input)), inv_requests: Some(TLVArray::Slice(input)),
}; };
let (resp_opcode, out_buf) = self.gen_timed_reqs_output(
self.handle_timed_reqs(OpCode::InvokeRequest, &req, timeout, delay, &mut out_buf); handler,
tlv::print_tlv_list(out_buf); OpCode::InvokeRequest,
let root = tlv::get_root_node_struct(out_buf).unwrap(); &req,
timeout,
delay,
&mut out,
);
let out = &out[out.len() - 1];
let root = tlv::get_root_node_struct(&out.data).unwrap();
match expected { match expected {
TimedInvResponse::TransactionSuccess(t) => { TimedInvResponse::TransactionSuccess(t) => {
assert_eq!( assert_eq!(out.action, OpCode::InvokeResponse);
num::FromPrimitive::from_u8(resp_opcode),
Some(OpCode::InvokeResponse)
);
let resp = msg::InvResp::from_tlv(&root).unwrap(); let resp = msg::InvResp::from_tlv(&root).unwrap();
assert_inv_response(&resp, t) assert_inv_response(&resp, t)
} }
TimedInvResponse::TransactionError(e) => { TimedInvResponse::TransactionError(e) => {
assert_eq!( assert_eq!(out.action, OpCode::StatusResponse);
num::FromPrimitive::from_u8(resp_opcode),
Some(OpCode::StatusResponse)
);
let status_resp = StatusResp::from_tlv(&root).unwrap(); let status_resp = StatusResp::from_tlv(&root).unwrap();
assert_eq!(status_resp.status, *e); assert_eq!(status_resp.status, *e);
} }
} }
} }
pub fn new_with_timed_commands(
matter: &'a Matter<'a>,
input: &[CmdData],
expected: &TimedInvResponse,
timeout: u16,
delay: u16,
set_timed_request: bool,
) -> Self {
let mut im = ImEngine::new(matter);
im.handle_timed_commands(input, expected, timeout, delay, set_timed_request);
im
}
} }

View file

@ -17,14 +17,19 @@
use crate::common::echo_cluster; use crate::common::echo_cluster;
use core::borrow::Borrow; use core::borrow::Borrow;
use core::future::pending;
use core::time::Duration;
use embassy_futures::select::select3;
use matter::{ use matter::{
acl::{AclEntry, AuthMode}, acl::{AclEntry, AuthMode},
data_model::{ data_model::{
cluster_basic_information::{self, BasicInfoConfig}, cluster_basic_information::{self, BasicInfoConfig},
cluster_on_off::{self, OnOffCluster}, cluster_on_off::{self, OnOffCluster},
core::DataModel,
device_types::{DEV_TYPE_ON_OFF_LIGHT, DEV_TYPE_ROOT_NODE}, device_types::{DEV_TYPE_ON_OFF_LIGHT, DEV_TYPE_ROOT_NODE},
objects::{Endpoint, Node, Privilege}, objects::{
AttrData, AttrDataEncoder, AttrDetails, Endpoint, Handler, HandlerCompat, Metadata,
Node, NonBlockingHandler, Privilege,
},
root_endpoint::{self, RootEndpointHandler}, root_endpoint::{self, RootEndpointHandler},
sdm::{ sdm::{
admin_commissioning, admin_commissioning,
@ -36,21 +41,24 @@ use matter::{
descriptor::{self, DescriptorCluster}, descriptor::{self, DescriptorCluster},
}, },
}, },
error::Error, error::{Error, ErrorCode},
handler_chain_type, handler_chain_type,
interaction_model::core::{InteractionModel, OpCode}, interaction_model::core::{OpCode, PROTO_ID_INTERACTION_MODEL},
mdns::Mdns, mdns::DummyMdns,
secure_channel::{self, common::PROTO_ID_SECURE_CHANNEL, spake2p::VerifierData},
tlv::{TLVWriter, TagType, ToTLV}, tlv::{TLVWriter, TagType, ToTLV},
transport::packet::Packet,
transport::{ transport::{
exchange::{self, Exchange, ExchangeCtx}, exchange::Notification,
network::{Address, IpAddr, Ipv4Addr, SocketAddr}, packet::{Packet, MAX_RX_BUF_SIZE, MAX_TX_BUF_SIZE},
packet::MAX_RX_BUF_SIZE, pipe::Pipe,
proto_ctx::ProtoCtx, runner::TransportRunner,
session::{CaseDetails, CloneData, NocCatIds, SessionMgr, SessionMode},
}, },
utils::{rand::dummy_rand, writebuf::WriteBuf}, transport::{
Matter, network::Address,
session::{CaseDetails, CloneData, NocCatIds, SessionMode},
},
utils::select::EitherUnwrap,
CommissioningData, Matter, MATTER_PORT,
}; };
use super::echo_cluster::EchoCluster; use super::echo_cluster::EchoCluster;
@ -74,66 +82,9 @@ impl DevAttDataFetcher for DummyDevAtt {
} }
pub const IM_ENGINE_PEER_ID: u64 = 445566; pub const IM_ENGINE_PEER_ID: u64 = 445566;
pub const IM_ENGINE_REMOTE_PEER_ID: u64 = 123456;
pub struct ImInput<'a> { const NODE: Node<'static> = Node {
action: OpCode,
data: &'a dyn ToTLV,
peer_id: u64,
cat_ids: NocCatIds,
}
impl<'a> ImInput<'a> {
pub fn new(action: OpCode, data: &'a dyn ToTLV) -> Self {
Self {
action,
data,
peer_id: IM_ENGINE_PEER_ID,
cat_ids: Default::default(),
}
}
pub fn set_peer_node_id(&mut self, peer: u64) {
self.peer_id = peer;
}
pub fn set_cat_ids(&mut self, cat_ids: &NocCatIds) {
self.cat_ids = *cat_ids;
}
}
pub type DmHandler<'a> = handler_chain_type!(OnOffCluster, EchoCluster, DescriptorCluster<'a>, EchoCluster | RootEndpointHandler<'a>);
pub fn matter(mdns: &mut dyn Mdns) -> Matter<'_> {
#[cfg(feature = "std")]
use matter::utils::epoch::sys_epoch as epoch;
#[cfg(not(feature = "std"))]
use matter::utils::epoch::dummy_epoch as epoch;
Matter::new(&BASIC_INFO, &DummyDevAtt, mdns, epoch, dummy_rand, 5540)
}
/// An Interaction Model Engine to facilitate easy testing
pub struct ImEngine<'a> {
pub matter: &'a Matter<'a>,
pub im: InteractionModel<DataModel<'a, DmHandler<'a>>>,
// By default, a new exchange is created for every run, if you wish to instead using a specific
// exchange, set this variable. This is helpful in situations where you have to run multiple
// actions in the same transaction (exchange)
pub exch: Option<Exchange>,
}
impl<'a> ImEngine<'a> {
/// Create the interaction model engine
pub fn new(matter: &'a Matter<'a>) -> Self {
let mut default_acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case);
// Only allow the standard peer node id of the IM Engine
default_acl.add_subject(IM_ENGINE_PEER_ID).unwrap();
matter.acl_mgr.borrow_mut().add(default_acl).unwrap();
let dm = DataModel::new(
matter.borrow(),
&Node {
id: 0, id: 0,
endpoints: &[ endpoints: &[
Endpoint { Endpoint {
@ -160,97 +111,292 @@ impl<'a> ImEngine<'a> {
device_type: DEV_TYPE_ON_OFF_LIGHT, device_type: DEV_TYPE_ON_OFF_LIGHT,
}, },
], ],
}, };
root_endpoint::handler(0, matter)
pub struct ImInput<'a> {
action: OpCode,
data: &'a dyn ToTLV,
delay: Option<u16>,
}
impl<'a> ImInput<'a> {
pub fn new(action: OpCode, data: &'a dyn ToTLV) -> Self {
Self::new_delayed(action, data, None)
}
pub fn new_delayed(action: OpCode, data: &'a dyn ToTLV, delay: Option<u16>) -> Self {
Self {
action,
data,
delay,
}
}
}
pub struct ImOutput {
pub action: OpCode,
pub data: heapless::Vec<u8, MAX_TX_BUF_SIZE>,
}
pub struct ImEngineHandler<'a> {
handler: handler_chain_type!(OnOffCluster, EchoCluster, DescriptorCluster<'static>, EchoCluster | RootEndpointHandler<'a>),
}
impl<'a> ImEngineHandler<'a> {
pub fn new(matter: &'a Matter<'a>) -> Self {
let handler = root_endpoint::handler(0, matter)
.chain(0, echo_cluster::ID, EchoCluster::new(2, *matter.borrow())) .chain(0, echo_cluster::ID, EchoCluster::new(2, *matter.borrow()))
.chain(1, descriptor::ID, DescriptorCluster::new(*matter.borrow())) .chain(1, descriptor::ID, DescriptorCluster::new(*matter.borrow()))
.chain(1, echo_cluster::ID, EchoCluster::new(3, *matter.borrow())) .chain(1, echo_cluster::ID, EchoCluster::new(3, *matter.borrow()))
.chain(1, cluster_on_off::ID, OnOffCluster::new(*matter.borrow())), .chain(1, cluster_on_off::ID, OnOffCluster::new(*matter.borrow()));
);
Self { Self { handler }
matter,
im: InteractionModel(dm),
exch: None,
}
} }
pub fn echo_cluster(&self, endpoint: u16) -> &EchoCluster { pub fn echo_cluster(&self, endpoint: u16) -> &EchoCluster {
match endpoint { match endpoint {
0 => &self.im.0.handler.next.next.next.handler, 0 => &self.handler.next.next.next.handler,
1 => &self.im.0.handler.next.handler, 1 => &self.handler.next.handler,
_ => panic!(), _ => panic!(),
} }
} }
}
/// Run a transaction through the interaction model engine impl<'a> Handler for ImEngineHandler<'a> {
pub fn process<'b>(&mut self, input: &ImInput, data_out: &'b mut [u8]) -> (u8, &'b [u8]) { fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> {
let mut new_exch = Exchange::new(1, 0, exchange::Role::Responder); self.handler.read(attr, encoder)
// Choose whether to use a new exchange, or use the one from the ImEngine configuration
let exch = self.exch.as_mut().unwrap_or(&mut new_exch);
let mut sess_mgr = SessionMgr::new(*self.matter.borrow(), *self.matter.borrow());
let clone_data = CloneData::new(
123456,
input.peer_id,
10,
30,
Address::Udp(SocketAddr::new(
IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
5542,
)),
SessionMode::Case(CaseDetails::new(1, &input.cat_ids)),
);
let sess_idx = sess_mgr.clone_session(&clone_data).unwrap();
let sess = sess_mgr.get_session_handle(sess_idx);
let exch_ctx = ExchangeCtx {
exch,
sess,
epoch: *self.matter.borrow(),
};
let mut rx_buf = [0; MAX_RX_BUF_SIZE];
let mut tx_buf = [0; 1440]; // For the long read tests to run unchanged
let mut rx = Packet::new_rx(&mut rx_buf);
let mut tx = Packet::new_tx(&mut tx_buf);
// Create fake rx packet
rx.set_proto_id(0x01);
rx.set_proto_opcode(input.action as u8);
rx.peer = Address::default();
{
let mut buf = [0u8; 400];
let mut wb = WriteBuf::new(&mut buf);
let mut tw = TLVWriter::new(&mut wb);
input.data.to_tlv(&mut tw, TagType::Anonymous).unwrap();
let input_data = wb.as_slice();
let in_data_len = input_data.len();
let rx_buf = rx.as_mut_slice();
rx_buf[..in_data_len].copy_from_slice(input_data);
rx.get_parsebuf().unwrap().set_len(in_data_len);
} }
let mut ctx = ProtoCtx::new(exch_ctx, &rx, &mut tx); fn write(&self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> {
self.im.handle(&mut ctx).unwrap(); self.handler.write(attr, data)
let out_data_len = ctx.tx.as_slice().len(); }
data_out[..out_data_len].copy_from_slice(ctx.tx.as_slice());
let response = ctx.tx.get_proto_raw_opcode(); fn invoke(
(response, &data_out[..out_data_len]) &self,
exchange: &matter::transport::exchange::Exchange,
cmd: &matter::data_model::objects::CmdDetails,
data: &matter::tlv::TLVElement,
encoder: matter::data_model::objects::CmdDataEncoder,
) -> Result<(), Error> {
self.handler.invoke(exchange, cmd, data, encoder)
} }
} }
// TODO - Remove? impl<'a> NonBlockingHandler for ImEngineHandler<'a> {}
// // Create an Interaction Model, Data Model and run a rx/tx transaction through it
// pub fn im_engine<'a>( impl<'a> Metadata for ImEngineHandler<'a> {
// matter: &'a Matter, type MetadataGuard<'g> = Node<'g> where Self: 'g;
// action: OpCode,
// data: &dyn ToTLV, fn lock(&self) -> Self::MetadataGuard<'_> {
// data_out: &'a mut [u8], NODE
// ) -> (DmHandler<'a>, u8, &'a mut [u8]) { }
// let mut engine = ImEngine::new(matter); }
// let input = ImInput::new(action, data);
// let (response, output) = engine.process(&input, data_out); static mut DNS: DummyMdns = DummyMdns;
// (engine.dm.handler, response, output)
// } /// An Interaction Model Engine to facilitate easy testing
pub struct ImEngine<'a> {
pub matter: Matter<'a>,
cat_ids: NocCatIds,
}
impl<'a> ImEngine<'a> {
pub fn new_default() -> Self {
Self::new(Default::default())
}
/// Create the interaction model engine
pub fn new(cat_ids: NocCatIds) -> Self {
#[cfg(feature = "std")]
use matter::utils::epoch::sys_epoch as epoch;
#[cfg(not(feature = "std"))]
use matter::utils::epoch::dummy_epoch as epoch;
#[cfg(feature = "std")]
use matter::utils::rand::sys_rand as rand;
#[cfg(not(feature = "std"))]
use matter::utils::rand::dummy_rand as rand;
let matter = Matter::new(
&BASIC_INFO,
&DummyDevAtt,
unsafe { &mut DNS },
epoch,
rand,
MATTER_PORT,
);
Self { matter, cat_ids }
}
pub fn add_default_acl(&self) {
// Only allow the standard peer node id of the IM Engine
let mut default_acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case);
default_acl.add_subject(IM_ENGINE_PEER_ID).unwrap();
self.matter.acl_mgr.borrow_mut().add(default_acl).unwrap();
}
pub fn handler(&self) -> ImEngineHandler<'_> {
ImEngineHandler::new(&self.matter)
}
pub fn process<const N: usize>(
&self,
handler: &ImEngineHandler,
input: &[&ImInput],
out: &mut heapless::Vec<ImOutput, N>,
) -> Result<(), Error> {
let mut runner = TransportRunner::new(&self.matter);
let clone_data = CloneData::new(
IM_ENGINE_REMOTE_PEER_ID,
IM_ENGINE_PEER_ID,
1,
1,
Address::default(),
SessionMode::Case(CaseDetails::new(1, &self.cat_ids)),
);
let sess_idx = runner
.transport()
.session_mgr
.borrow_mut()
.clone_session(&clone_data)
.unwrap();
let mut tx_pipe_buf = [0; MAX_RX_BUF_SIZE];
let mut rx_pipe_buf = [0; MAX_TX_BUF_SIZE];
let mut tx_buf = [0; MAX_RX_BUF_SIZE];
let mut rx_buf = [0; MAX_TX_BUF_SIZE];
let tx_pipe = Pipe::new(&mut tx_buf);
let rx_pipe = Pipe::new(&mut rx_buf);
let tx_pipe = &tx_pipe;
let rx_pipe = &rx_pipe;
let tx_pipe_buf = &mut tx_pipe_buf;
let rx_pipe_buf = &mut rx_pipe_buf;
let handler = &handler;
let runner = &mut runner;
let mut msg_ctr = runner
.transport()
.session_mgr
.borrow_mut()
.mut_by_index(sess_idx)
.unwrap()
.get_msg_ctr();
let resp_notif = Notification::new();
let resp_notif = &resp_notif;
embassy_futures::block_on(async move {
select3(
runner.run(
tx_pipe,
rx_pipe,
CommissioningData {
// TODO: Hard-coded for now
verifier: VerifierData::new_with_pw(123456, *self.matter.borrow()),
discriminator: 250,
},
&HandlerCompat(handler),
),
async move {
let mut acknowledge = false;
for ip in input {
Self::send(ip, tx_pipe_buf, rx_pipe, msg_ctr, acknowledge).await?;
resp_notif.wait().await;
if let Some(delay) = ip.delay {
if delay > 0 {
#[cfg(feature = "std")]
std::thread::sleep(Duration::from_millis(delay as _));
}
}
msg_ctr += 2;
acknowledge = true;
}
pending::<()>().await;
Ok(())
},
async move {
out.clear();
while out.len() < input.len() {
let (len, _) = tx_pipe.recv(rx_pipe_buf).await;
let mut rx = Packet::new_rx(&mut rx_pipe_buf[..len]);
rx.plain_hdr_decode()?;
rx.proto_decode(IM_ENGINE_REMOTE_PEER_ID, Some(&[0u8; 16]))?;
if rx.get_proto_id() != PROTO_ID_SECURE_CHANNEL
|| rx.get_proto_opcode::<secure_channel::common::OpCode>()?
!= secure_channel::common::OpCode::MRPStandAloneAck
{
out.push(ImOutput {
action: rx.get_proto_opcode()?,
data: heapless::Vec::from_slice(rx.as_slice())
.map_err(|_| ErrorCode::NoSpace)?,
})
.map_err(|_| ErrorCode::NoSpace)?;
resp_notif.signal(());
}
}
Ok(())
},
)
.await
.unwrap()
})?;
Ok(())
}
async fn send(
input: &ImInput<'_>,
tx_buf: &mut [u8],
rx_pipe: &Pipe<'_>,
msg_ctr: u32,
acknowledge: bool,
) -> Result<(), Error> {
let mut tx = Packet::new_tx(tx_buf);
tx.set_proto_id(PROTO_ID_INTERACTION_MODEL);
tx.set_proto_opcode(input.action as u8);
let mut tw = TLVWriter::new(tx.get_writebuf()?);
input.data.to_tlv(&mut tw, TagType::Anonymous)?;
tx.plain.ctr = msg_ctr + 1;
tx.plain.sess_id = 1;
tx.proto.set_initiator();
if acknowledge {
tx.proto.set_ack(msg_ctr - 1);
}
tx.proto_encode(
Address::default(),
Some(IM_ENGINE_REMOTE_PEER_ID),
IM_ENGINE_PEER_ID,
false,
Some(&[0u8; 16]),
)?;
rx_pipe.send(Address::default(), tx.as_slice()).await;
Ok(())
}
}

View file

@ -26,7 +26,6 @@ use matter::{
messages::ib::{AttrData, AttrPath, AttrResp, AttrStatus, ClusterPath, DataVersionFilter}, messages::ib::{AttrData, AttrPath, AttrResp, AttrStatus, ClusterPath, DataVersionFilter},
messages::GenericPath, messages::GenericPath,
}, },
mdns::DummyMdns,
tlv::{ElementType, TLVArray, TLVElement, TLVWriter, TagType}, tlv::{ElementType, TLVArray, TLVElement, TLVWriter, TagType},
}; };
@ -35,7 +34,7 @@ use crate::{
common::{ common::{
attributes::*, attributes::*,
echo_cluster::{self, ATTR_WRITE_DEFAULT_VALUE}, echo_cluster::{self, ATTR_WRITE_DEFAULT_VALUE},
im_engine::{matter, ImEngine}, im_engine::{ImEngine, IM_ENGINE_PEER_ID},
init_env_logger, init_env_logger,
}, },
}; };
@ -62,30 +61,28 @@ fn wc_read_attribute() {
Some(echo_cluster::AttributesDiscriminants::Att1 as u32), Some(echo_cluster::AttributesDiscriminants::Att1 as u32),
); );
let peer = 98765; let im = ImEngine::new_default();
let mut mdns = DummyMdns {}; let handler = im.handler();
let matter = matter(&mut mdns);
let mut im = ImEngine::new(&matter);
// Test1: Empty Response as no ACL matches // Test1: Empty Response as no ACL matches
let input = &[AttrPath::new(&wc_att1)]; let input = &[AttrPath::new(&wc_att1)];
let expected = &[]; let expected = &[];
im.handle_read_reqs(peer, input, expected); im.handle_read_reqs(&handler, input, expected);
// Add ACL to allow our peer to only access endpoint 0 // Add ACL to allow our peer to only access endpoint 0
let mut acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); let mut acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case);
acl.add_subject(peer).unwrap(); acl.add_subject(IM_ENGINE_PEER_ID).unwrap();
acl.add_target(Target::new(Some(0), None, None)).unwrap(); acl.add_target(Target::new(Some(0), None, None)).unwrap();
im.matter.acl_mgr.borrow_mut().add(acl).unwrap(); im.matter.acl_mgr.borrow_mut().add(acl).unwrap();
// Test2: Only Single response as only single endpoint is allowed // Test2: Only Single response as only single endpoint is allowed
let input = &[AttrPath::new(&wc_att1)]; let input = &[AttrPath::new(&wc_att1)];
let expected = &[attr_data_path!(ep0_att1, ElementType::U16(0x1234))]; let expected = &[attr_data_path!(ep0_att1, ElementType::U16(0x1234))];
im.handle_read_reqs(peer, input, expected); im.handle_read_reqs(&handler, input, expected);
// Add ACL to allow our peer to also access endpoint 1 // Add ACL to allow our peer to also access endpoint 1
let mut acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); let mut acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case);
acl.add_subject(peer).unwrap(); acl.add_subject(IM_ENGINE_PEER_ID).unwrap();
acl.add_target(Target::new(Some(1), None, None)).unwrap(); acl.add_target(Target::new(Some(1), None, None)).unwrap();
im.matter.acl_mgr.borrow_mut().add(acl).unwrap(); im.matter.acl_mgr.borrow_mut().add(acl).unwrap();
@ -95,7 +92,7 @@ fn wc_read_attribute() {
attr_data_path!(ep0_att1, ElementType::U16(0x1234)), attr_data_path!(ep0_att1, ElementType::U16(0x1234)),
attr_data_path!(ep1_att1, ElementType::U16(0x1234)), attr_data_path!(ep1_att1, ElementType::U16(0x1234)),
]; ];
im.handle_read_reqs(peer, input, expected); im.handle_read_reqs(&handler, input, expected);
} }
#[test] #[test]
@ -115,25 +112,23 @@ fn exact_read_attribute() {
Some(echo_cluster::AttributesDiscriminants::Att1 as u32), Some(echo_cluster::AttributesDiscriminants::Att1 as u32),
); );
let peer = 98765; let im = ImEngine::new_default();
let mut mdns = DummyMdns {}; let handler = im.handler();
let matter = matter(&mut mdns);
let mut im = ImEngine::new(&matter);
// Test1: Unsupported Access error as no ACL matches // Test1: Unsupported Access error as no ACL matches
let input = &[AttrPath::new(&wc_att1)]; let input = &[AttrPath::new(&wc_att1)];
let expected = &[attr_status!(&ep0_att1, IMStatusCode::UnsupportedAccess)]; let expected = &[attr_status!(&ep0_att1, IMStatusCode::UnsupportedAccess)];
im.handle_read_reqs(peer, input, expected); im.handle_read_reqs(&handler, input, expected);
// Add ACL to allow our peer to access any endpoint // Add ACL to allow our peer to access any endpoint
let mut acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); let mut acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case);
acl.add_subject(peer).unwrap(); acl.add_subject(IM_ENGINE_PEER_ID).unwrap();
im.matter.acl_mgr.borrow_mut().add(acl).unwrap(); im.matter.acl_mgr.borrow_mut().add(acl).unwrap();
// Test2: Only Single response as only single endpoint is allowed // Test2: Only Single response as only single endpoint is allowed
let input = &[AttrPath::new(&wc_att1)]; let input = &[AttrPath::new(&wc_att1)];
let expected = &[attr_data_path!(ep0_att1, ElementType::U16(0x1234))]; let expected = &[attr_data_path!(ep0_att1, ElementType::U16(0x1234))];
im.handle_read_reqs(peer, input, expected); im.handle_read_reqs(&handler, input, expected);
} }
#[test] #[test]
@ -177,52 +172,54 @@ fn wc_write_attribute() {
EncodeValue::Closure(&attr_data1), EncodeValue::Closure(&attr_data1),
)]; )];
let peer = 98765; let im = ImEngine::new_default();
let mut mdns = DummyMdns {}; let handler = im.handler();
let matter = matter(&mut mdns);
let mut im = ImEngine::new(&matter);
// Test 1: Wildcard write to an attribute without permission should return // Test 1: Wildcard write to an attribute without permission should return
// no error // no error
im.handle_write_reqs(peer, None, input0, &[]); im.handle_write_reqs(&handler, input0, &[]);
assert_eq!(ATTR_WRITE_DEFAULT_VALUE, im.echo_cluster(0).att_write); assert_eq!(
ATTR_WRITE_DEFAULT_VALUE,
handler.echo_cluster(0).att_write.get()
);
// Add ACL to allow our peer to access one endpoint // Add ACL to allow our peer to access one endpoint
let mut acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); let mut acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case);
acl.add_subject(peer).unwrap(); acl.add_subject(IM_ENGINE_PEER_ID).unwrap();
acl.add_target(Target::new(Some(0), None, None)).unwrap(); acl.add_target(Target::new(Some(0), None, None)).unwrap();
im.matter.acl_mgr.borrow_mut().add(acl).unwrap(); im.matter.acl_mgr.borrow_mut().add(acl).unwrap();
// Test 2: Wildcard write to attributes will only return attributes // Test 2: Wildcard write to attributes will only return attributes
// where the writes were successful // where the writes were successful
im.handle_write_reqs( im.handle_write_reqs(
peer, &handler,
None,
input0, input0,
&[AttrStatus::new(&ep0_att, IMStatusCode::Success, 0)], &[AttrStatus::new(&ep0_att, IMStatusCode::Success, 0)],
); );
assert_eq!(val0, im.echo_cluster(0).att_write); assert_eq!(val0, handler.echo_cluster(0).att_write.get());
assert_eq!(ATTR_WRITE_DEFAULT_VALUE, im.echo_cluster(1).att_write); assert_eq!(
ATTR_WRITE_DEFAULT_VALUE,
handler.echo_cluster(1).att_write.get()
);
// Add ACL to allow our peer to access another endpoint // Add ACL to allow our peer to access another endpoint
let mut acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); let mut acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case);
acl.add_subject(peer).unwrap(); acl.add_subject(IM_ENGINE_PEER_ID).unwrap();
acl.add_target(Target::new(Some(1), None, None)).unwrap(); acl.add_target(Target::new(Some(1), None, None)).unwrap();
im.matter.acl_mgr.borrow_mut().add(acl).unwrap(); im.matter.acl_mgr.borrow_mut().add(acl).unwrap();
// Test 3: Wildcard write to attributes will return multiple attributes // Test 3: Wildcard write to attributes will return multiple attributes
// where the writes were successful // where the writes were successful
im.handle_write_reqs( im.handle_write_reqs(
peer, &handler,
None,
input1, input1,
&[ &[
AttrStatus::new(&ep0_att, IMStatusCode::Success, 0), AttrStatus::new(&ep0_att, IMStatusCode::Success, 0),
AttrStatus::new(&ep1_att, IMStatusCode::Success, 0), AttrStatus::new(&ep1_att, IMStatusCode::Success, 0),
], ],
); );
assert_eq!(val1, im.echo_cluster(0).att_write); assert_eq!(val1, handler.echo_cluster(0).att_write.get());
assert_eq!(val1, im.echo_cluster(1).att_write); assert_eq!(val1, handler.echo_cluster(1).att_write.get());
} }
#[test] #[test]
@ -253,25 +250,26 @@ fn exact_write_attribute() {
)]; )];
let expected_success = &[AttrStatus::new(&ep0_att, IMStatusCode::Success, 0)]; let expected_success = &[AttrStatus::new(&ep0_att, IMStatusCode::Success, 0)];
let peer = 98765; let im = ImEngine::new_default();
let mut mdns = DummyMdns {}; let handler = im.handler();
let matter = matter(&mut mdns);
let mut im = ImEngine::new(&matter);
// Test 1: Exact write to an attribute without permission should return // Test 1: Exact write to an attribute without permission should return
// Unsupported Access Error // Unsupported Access Error
im.handle_write_reqs(peer, None, input, expected_fail); im.handle_write_reqs(&handler, input, expected_fail);
assert_eq!(ATTR_WRITE_DEFAULT_VALUE, im.echo_cluster(0).att_write); assert_eq!(
ATTR_WRITE_DEFAULT_VALUE,
handler.echo_cluster(0).att_write.get()
);
// Add ACL to allow our peer to access any endpoint // Add ACL to allow our peer to access any endpoint
let mut acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); let mut acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case);
acl.add_subject(peer).unwrap(); acl.add_subject(IM_ENGINE_PEER_ID).unwrap();
im.matter.acl_mgr.borrow_mut().add(acl).unwrap(); im.matter.acl_mgr.borrow_mut().add(acl).unwrap();
// Test 1: Exact write to an attribute with permission should grant // Test 1: Exact write to an attribute with permission should grant
// access // access
im.handle_write_reqs(peer, None, input, expected_success); im.handle_write_reqs(&handler, input, expected_success);
assert_eq!(val0, im.echo_cluster(0).att_write); assert_eq!(val0, handler.echo_cluster(0).att_write.get());
} }
#[test] #[test]
@ -303,19 +301,20 @@ fn exact_write_attribute_noc_cat() {
)]; )];
let expected_success = &[AttrStatus::new(&ep0_att, IMStatusCode::Success, 0)]; let expected_success = &[AttrStatus::new(&ep0_att, IMStatusCode::Success, 0)];
let peer = 98765;
/* CAT in NOC is 1 more, in version, than that in ACL */ /* CAT in NOC is 1 more, in version, than that in ACL */
let noc_cat = gen_noc_cat(0xABCD, 2); let noc_cat = gen_noc_cat(0xABCD, 2);
let cat_in_acl = gen_noc_cat(0xABCD, 1); let cat_in_acl = gen_noc_cat(0xABCD, 1);
let cat_ids = [noc_cat, 0, 0]; let cat_ids = [noc_cat, 0, 0];
let mut mdns = DummyMdns; let im = ImEngine::new(cat_ids);
let matter = matter(&mut mdns); let handler = im.handler();
let mut im = ImEngine::new(&matter);
// Test 1: Exact write to an attribute without permission should return // Test 1: Exact write to an attribute without permission should return
// Unsupported Access Error // Unsupported Access Error
im.handle_write_reqs(peer, Some(&cat_ids), input, expected_fail); im.handle_write_reqs(&handler, input, expected_fail);
assert_eq!(ATTR_WRITE_DEFAULT_VALUE, im.echo_cluster(0).att_write); assert_eq!(
ATTR_WRITE_DEFAULT_VALUE,
handler.echo_cluster(0).att_write.get()
);
// Add ACL to allow our peer to access any endpoint // Add ACL to allow our peer to access any endpoint
let mut acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); let mut acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case);
@ -324,8 +323,8 @@ fn exact_write_attribute_noc_cat() {
// Test 1: Exact write to an attribute with permission should grant // Test 1: Exact write to an attribute with permission should grant
// access // access
im.handle_write_reqs(peer, Some(&cat_ids), input, expected_success); im.handle_write_reqs(&handler, input, expected_success);
assert_eq!(val0, im.echo_cluster(0).att_write); assert_eq!(val0, handler.echo_cluster(0).att_write.get());
} }
#[test] #[test]
@ -347,21 +346,18 @@ fn insufficient_perms_write() {
EncodeValue::Closure(&attr_data0), EncodeValue::Closure(&attr_data0),
)]; )];
let peer = 98765; let im = ImEngine::new_default();
let mut mdns = DummyMdns {}; let handler = im.handler();
let matter = matter(&mut mdns);
let mut im = ImEngine::new(&matter);
// Add ACL to allow our peer with only OPERATE permission // Add ACL to allow our peer with only OPERATE permission
let mut acl = AclEntry::new(1, Privilege::OPERATE, AuthMode::Case); let mut acl = AclEntry::new(1, Privilege::OPERATE, AuthMode::Case);
acl.add_subject(peer).unwrap(); acl.add_subject(IM_ENGINE_PEER_ID).unwrap();
acl.add_target(Target::new(Some(0), None, None)).unwrap(); acl.add_target(Target::new(Some(0), None, None)).unwrap();
im.matter.acl_mgr.borrow_mut().add(acl).unwrap(); im.matter.acl_mgr.borrow_mut().add(acl).unwrap();
// Test: Not enough permission should return error // Test: Not enough permission should return error
im.handle_write_reqs( im.handle_write_reqs(
peer, &handler,
None,
input0, input0,
&[AttrStatus::new( &[AttrStatus::new(
&ep0_att, &ep0_att,
@ -369,7 +365,10 @@ fn insufficient_perms_write() {
0, 0,
)], )],
); );
assert_eq!(ATTR_WRITE_DEFAULT_VALUE, im.echo_cluster(0).att_write); assert_eq!(
ATTR_WRITE_DEFAULT_VALUE,
handler.echo_cluster(0).att_write.get()
);
} }
#[test] #[test]
@ -381,10 +380,9 @@ fn insufficient_perms_write() {
/// - Write Attr to Echo Cluster again (successful this time) /// - Write Attr to Echo Cluster again (successful this time)
fn write_with_runtime_acl_add() { fn write_with_runtime_acl_add() {
init_env_logger(); init_env_logger();
let peer = 98765;
let mut mdns = DummyMdns {}; let im = ImEngine::new_default();
let matter = matter(&mut mdns); let handler = im.handler();
let mut im = ImEngine::new(&matter);
let val0 = 10; let val0 = 10;
let attr_data0 = |tag, t: &mut TLVWriter| { let attr_data0 = |tag, t: &mut TLVWriter| {
@ -403,7 +401,7 @@ fn write_with_runtime_acl_add() {
// Create ACL to allow our peer ADMIN on everything // Create ACL to allow our peer ADMIN on everything
let mut allow_acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); let mut allow_acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case);
allow_acl.add_subject(peer).unwrap(); allow_acl.add_subject(IM_ENGINE_PEER_ID).unwrap();
let acl_att = GenericPath::new( let acl_att = GenericPath::new(
Some(0), Some(0),
@ -418,7 +416,7 @@ fn write_with_runtime_acl_add() {
// Create ACL that only allows write to the ACL Cluster // Create ACL that only allows write to the ACL Cluster
let mut basic_acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); let mut basic_acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case);
basic_acl.add_subject(peer).unwrap(); basic_acl.add_subject(IM_ENGINE_PEER_ID).unwrap();
basic_acl basic_acl
.add_target(Target::new(Some(0), Some(access_control::ID), None)) .add_target(Target::new(Some(0), Some(access_control::ID), None))
.unwrap(); .unwrap();
@ -426,8 +424,7 @@ fn write_with_runtime_acl_add() {
// Test: deny write (with error), then ACL is added, then allow write // Test: deny write (with error), then ACL is added, then allow write
im.handle_write_reqs( im.handle_write_reqs(
peer, &handler,
None,
// write to echo-cluster attribute, write to acl attribute, write to echo-cluster attribute // write to echo-cluster attribute, write to acl attribute, write to echo-cluster attribute
&[input0.clone(), acl_input, input0], &[input0.clone(), acl_input, input0],
&[ &[
@ -436,7 +433,7 @@ fn write_with_runtime_acl_add() {
AttrStatus::new(&ep0_att, IMStatusCode::Success, 0), AttrStatus::new(&ep0_att, IMStatusCode::Success, 0),
], ],
); );
assert_eq!(val0, im.echo_cluster(0).att_write); assert_eq!(val0, handler.echo_cluster(0).att_write.get());
} }
#[test] #[test]
@ -448,10 +445,9 @@ fn test_read_data_ver() {
// - wildcard endpoint, att1 // - wildcard endpoint, att1
// - 2 responses are expected // - 2 responses are expected
init_env_logger(); init_env_logger();
let peer = 98765;
let mut mdns = DummyMdns {}; let im = ImEngine::new_default();
let matter = matter(&mut mdns); let handler = im.handler();
let mut im = ImEngine::new(&matter);
// Add ACL to allow our peer with only OPERATE permission // Add ACL to allow our peer with only OPERATE permission
let acl = AclEntry::new(1, Privilege::OPERATE, AuthMode::Case); let acl = AclEntry::new(1, Privilege::OPERATE, AuthMode::Case);
@ -482,10 +478,11 @@ fn test_read_data_ver() {
ElementType::U16(0x1234) ElementType::U16(0x1234)
), ),
]; ];
let mut out_buf = [0u8; 400];
let mut out = heapless::Vec::new();
// Test 1: Simple read to retrieve the current Data Version of Cluster at Endpoint 0 // Test 1: Simple read to retrieve the current Data Version of Cluster at Endpoint 0
let received = im.gen_read_reqs_output(peer, input, None, &mut out_buf); let received = im.gen_read_reqs_output::<1>(&handler, input, None, &mut out);
assert_attr_report(&received, expected); assert_attr_report(&received, expected);
let data_ver_cluster_at_0 = received let data_ver_cluster_at_0 = received
@ -507,11 +504,12 @@ fn test_read_data_ver() {
}]; }];
// Test 2: Add Dataversion filter for cluster at endpoint 0 only single entry should be retrieved // Test 2: Add Dataversion filter for cluster at endpoint 0 only single entry should be retrieved
let received = im.gen_read_reqs_output( let mut out = heapless::Vec::new();
peer, let received = im.gen_read_reqs_output::<1>(
&handler,
input, input,
Some(TLVArray::Slice(&dataver_filter)), Some(TLVArray::Slice(&dataver_filter)),
&mut out_buf, &mut out,
); );
let expected_only_one = &[attr_data_path!( let expected_only_one = &[attr_data_path!(
GenericPath::new( GenericPath::new(
@ -532,10 +530,10 @@ fn test_read_data_ver() {
); );
let input = &[AttrPath::new(&ep0_att1)]; let input = &[AttrPath::new(&ep0_att1)];
let received = im.gen_read_reqs_output( let received = im.gen_read_reqs_output(
peer, &handler,
input, input,
Some(TLVArray::Slice(&dataver_filter)), Some(TLVArray::Slice(&dataver_filter)),
&mut out_buf, &mut out,
); );
let expected_error = &[]; let expected_error = &[];
@ -551,10 +549,9 @@ fn test_write_data_ver() {
// - wildcard endpoint, att1 // - wildcard endpoint, att1
// - 2 responses are expected // - 2 responses are expected
init_env_logger(); init_env_logger();
let peer = 98765;
let mut mdns = DummyMdns {}; let im = ImEngine::new_default();
let matter = matter(&mut mdns); let handler = im.handler();
let mut im = ImEngine::new(&matter);
// Add ACL to allow our peer with only OPERATE permission // Add ACL to allow our peer with only OPERATE permission
let acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); let acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case);
@ -576,7 +573,7 @@ fn test_write_data_ver() {
let attr_data0 = EncodeValue::Value(&val0); let attr_data0 = EncodeValue::Value(&val0);
let attr_data1 = EncodeValue::Value(&val1); let attr_data1 = EncodeValue::Value(&val1);
let initial_data_ver = im.echo_cluster(0).data_ver.get(); let initial_data_ver = handler.echo_cluster(0).data_ver.get();
// Test 1: Write with correct dataversion should succeed // Test 1: Write with correct dataversion should succeed
let input_correct_dataver = &[AttrData::new( let input_correct_dataver = &[AttrData::new(
@ -585,12 +582,11 @@ fn test_write_data_ver() {
attr_data0, attr_data0,
)]; )];
im.handle_write_reqs( im.handle_write_reqs(
peer, &handler,
None,
input_correct_dataver, input_correct_dataver,
&[AttrStatus::new(&ep0_attwrite, IMStatusCode::Success, 0)], &[AttrStatus::new(&ep0_attwrite, IMStatusCode::Success, 0)],
); );
assert_eq!(val0, im.echo_cluster(0).att_write); assert_eq!(val0, handler.echo_cluster(0).att_write.get());
// Test 2: Write with incorrect dataversion should fail // Test 2: Write with incorrect dataversion should fail
// Now the data version would have incremented due to the previous write // Now the data version would have incremented due to the previous write
@ -600,8 +596,7 @@ fn test_write_data_ver() {
attr_data1.clone(), attr_data1.clone(),
)]; )];
im.handle_write_reqs( im.handle_write_reqs(
peer, &handler,
None,
input_correct_dataver, input_correct_dataver,
&[AttrStatus::new( &[AttrStatus::new(
&ep0_attwrite, &ep0_attwrite,
@ -609,12 +604,12 @@ fn test_write_data_ver() {
0, 0,
)], )],
); );
assert_eq!(val0, im.echo_cluster(0).att_write); assert_eq!(val0, handler.echo_cluster(0).att_write.get());
// Test 3: Wildcard write with incorrect dataversion should ignore that cluster // Test 3: Wildcard write with incorrect dataversion should ignore that cluster
// In this case, while the data version is correct for endpoint 0, the endpoint 1's // In this case, while the data version is correct for endpoint 0, the endpoint 1's
// data version would not match // data version would not match
let new_data_ver = im.echo_cluster(0).data_ver.get(); let new_data_ver = handler.echo_cluster(0).data_ver.get();
let input_correct_dataver = &[AttrData::new( let input_correct_dataver = &[AttrData::new(
Some(new_data_ver), Some(new_data_ver),
@ -622,12 +617,11 @@ fn test_write_data_ver() {
attr_data1, attr_data1,
)]; )];
im.handle_write_reqs( im.handle_write_reqs(
peer, &handler,
None,
input_correct_dataver, input_correct_dataver,
&[AttrStatus::new(&ep0_attwrite, IMStatusCode::Success, 0)], &[AttrStatus::new(&ep0_attwrite, IMStatusCode::Success, 0)],
); );
assert_eq!(val1, im.echo_cluster(0).att_write); assert_eq!(val1, handler.echo_cluster(0).att_write.get());
assert_eq!(initial_data_ver + 1, new_data_ver); assert_eq!(initial_data_ver + 1, new_data_ver);
} }

View file

@ -22,13 +22,12 @@ use matter::{
messages::ib::{AttrData, AttrPath, AttrStatus}, messages::ib::{AttrData, AttrPath, AttrStatus},
messages::GenericPath, messages::GenericPath,
}, },
mdns::DummyMdns,
tlv::Nullable, tlv::Nullable,
}; };
use crate::common::{ use crate::common::{
echo_cluster::{self, TestChecker}, echo_cluster::{self, TestChecker},
im_engine::{matter, ImEngine}, im_engine::ImEngine,
init_env_logger, init_env_logger,
}; };
@ -65,8 +64,8 @@ fn attr_list_ops() {
EncodeValue::Value(&val0), EncodeValue::Value(&val0),
)]; )];
let expected = &[AttrStatus::new(&att_data, IMStatusCode::Success, 0)]; let expected = &[AttrStatus::new(&att_data, IMStatusCode::Success, 0)];
ImEngine::new_with_write_reqs(&matter(&mut DummyMdns), input, expected);
ImEngine::write_reqs(input, expected);
{ {
let tc = tc_handle.lock().unwrap(); let tc = tc_handle.lock().unwrap();
assert_eq!([Some(val0), None, None, None, None], tc.write_list); assert_eq!([Some(val0), None, None, None, None], tc.write_list);
@ -79,8 +78,8 @@ fn attr_list_ops() {
EncodeValue::Value(&val1), EncodeValue::Value(&val1),
)]; )];
let expected = &[AttrStatus::new(&att_data, IMStatusCode::Success, 0)]; let expected = &[AttrStatus::new(&att_data, IMStatusCode::Success, 0)];
ImEngine::new_with_write_reqs(&matter(&mut DummyMdns), input, expected);
ImEngine::write_reqs(input, expected);
{ {
let tc = tc_handle.lock().unwrap(); let tc = tc_handle.lock().unwrap();
assert_eq!([Some(val0), Some(val1), None, None, None], tc.write_list); assert_eq!([Some(val0), Some(val1), None, None, None], tc.write_list);
@ -94,8 +93,8 @@ fn attr_list_ops() {
EncodeValue::Value(&val0), EncodeValue::Value(&val0),
)]; )];
let expected = &[AttrStatus::new(&att_data, IMStatusCode::Success, 0)]; let expected = &[AttrStatus::new(&att_data, IMStatusCode::Success, 0)];
ImEngine::new_with_write_reqs(&matter(&mut DummyMdns), input, expected);
ImEngine::write_reqs(input, expected);
{ {
let tc = tc_handle.lock().unwrap(); let tc = tc_handle.lock().unwrap();
assert_eq!([Some(val0), Some(val0), None, None, None], tc.write_list); assert_eq!([Some(val0), Some(val0), None, None, None], tc.write_list);
@ -105,8 +104,8 @@ fn attr_list_ops() {
att_path.list_index = Some(Nullable::NotNull(0)); att_path.list_index = Some(Nullable::NotNull(0));
let input = &[AttrData::new(None, att_path.clone(), delete_item)]; let input = &[AttrData::new(None, att_path.clone(), delete_item)];
let expected = &[AttrStatus::new(&att_data, IMStatusCode::Success, 0)]; let expected = &[AttrStatus::new(&att_data, IMStatusCode::Success, 0)];
ImEngine::new_with_write_reqs(&matter(&mut DummyMdns), input, expected);
ImEngine::write_reqs(input, expected);
{ {
let tc = tc_handle.lock().unwrap(); let tc = tc_handle.lock().unwrap();
assert_eq!([None, Some(val0), None, None, None], tc.write_list); assert_eq!([None, Some(val0), None, None, None], tc.write_list);
@ -121,8 +120,8 @@ fn attr_list_ops() {
EncodeValue::Value(&overwrite_val), EncodeValue::Value(&overwrite_val),
)]; )];
let expected = &[AttrStatus::new(&att_data, IMStatusCode::Success, 0)]; let expected = &[AttrStatus::new(&att_data, IMStatusCode::Success, 0)];
ImEngine::new_with_write_reqs(&matter(&mut DummyMdns), input, expected);
ImEngine::write_reqs(input, expected);
{ {
let tc = tc_handle.lock().unwrap(); let tc = tc_handle.lock().unwrap();
assert_eq!([Some(20), Some(21), None, None, None], tc.write_list); assert_eq!([Some(20), Some(21), None, None, None], tc.write_list);
@ -132,8 +131,8 @@ fn attr_list_ops() {
att_path.list_index = None; att_path.list_index = None;
let input = &[AttrData::new(None, att_path, delete_all)]; let input = &[AttrData::new(None, att_path, delete_all)];
let expected = &[AttrStatus::new(&att_data, IMStatusCode::Success, 0)]; let expected = &[AttrStatus::new(&att_data, IMStatusCode::Success, 0)];
ImEngine::new_with_write_reqs(&matter(&mut DummyMdns), input, expected);
ImEngine::write_reqs(input, expected);
{ {
let tc = tc_handle.lock().unwrap(); let tc = tc_handle.lock().unwrap();
assert_eq!([None, None, None, None, None], tc.write_list); assert_eq!([None, None, None, None, None], tc.write_list);

View file

@ -25,18 +25,12 @@ use matter::{
messages::ib::{AttrData, AttrPath, AttrResp, AttrStatus}, messages::ib::{AttrData, AttrPath, AttrResp, AttrStatus},
messages::GenericPath, messages::GenericPath,
}, },
mdns::DummyMdns,
tlv::{ElementType, TLVElement, TLVWriter, TagType}, tlv::{ElementType, TLVElement, TLVWriter, TagType},
}; };
use crate::{ use crate::{
attr_data, attr_data_path, attr_status, attr_data, attr_data_path, attr_status,
common::{ common::{attributes::*, echo_cluster, im_engine::ImEngine, init_env_logger},
attributes::*,
echo_cluster,
im_engine::{matter, ImEngine},
init_env_logger,
},
}; };
#[test] #[test]
@ -75,7 +69,7 @@ fn test_read_success() {
ElementType::U32(echo_cluster::ATTR_CUSTOM_VALUE) ElementType::U32(echo_cluster::ATTR_CUSTOM_VALUE)
), ),
]; ];
ImEngine::new_with_read_reqs(&matter(&mut DummyMdns), input, expected); ImEngine::read_reqs(input, expected);
} }
#[test] #[test]
@ -122,7 +116,7 @@ fn test_read_unsupported_fields() {
attr_status!(&invalid_cluster, IMStatusCode::UnsupportedCluster), attr_status!(&invalid_cluster, IMStatusCode::UnsupportedCluster),
attr_status!(&invalid_attribute, IMStatusCode::UnsupportedAttribute), attr_status!(&invalid_attribute, IMStatusCode::UnsupportedAttribute),
]; ];
ImEngine::new_with_read_reqs(&matter(&mut DummyMdns), input, expected); ImEngine::read_reqs(input, expected);
} }
#[test] #[test]
@ -153,7 +147,7 @@ fn test_read_wc_endpoint_all_have_clusters() {
ElementType::U16(0x1234) ElementType::U16(0x1234)
), ),
]; ];
ImEngine::new_with_read_reqs(&matter(&mut DummyMdns), input, expected); ImEngine::read_reqs(input, expected);
} }
#[test] #[test]
@ -178,7 +172,7 @@ fn test_read_wc_endpoint_only_1_has_cluster() {
), ),
ElementType::False ElementType::False
)]; )];
ImEngine::new_with_read_reqs(&matter(&mut DummyMdns), input, expected); ImEngine::read_reqs(input, expected);
} }
#[test] #[test]
@ -285,7 +279,7 @@ fn test_read_wc_endpoint_wc_attribute() {
ElementType::U32(echo_cluster::ATTR_CUSTOM_VALUE) ElementType::U32(echo_cluster::ATTR_CUSTOM_VALUE)
), ),
]; ];
ImEngine::new_with_read_reqs(&matter(&mut DummyMdns), input, expected); ImEngine::read_reqs(input, expected);
} }
#[test] #[test]
@ -331,11 +325,14 @@ fn test_write_success() {
AttrStatus::new(&ep1_att, IMStatusCode::Success, 0), AttrStatus::new(&ep1_att, IMStatusCode::Success, 0),
]; ];
let mut mdns = DummyMdns; let im = ImEngine::new_default();
let matter = matter(&mut mdns); let handler = im.handler();
let im = ImEngine::new_with_write_reqs(&matter, input, expected);
assert_eq!(val0, im.echo_cluster(0).att_write); im.add_default_acl();
assert_eq!(val1, im.echo_cluster(1).att_write); im.handle_write_reqs(&handler, input, expected);
assert_eq!(val0, handler.echo_cluster(0).att_write.get());
assert_eq!(val1, handler.echo_cluster(1).att_write.get());
} }
#[test] #[test]
@ -375,10 +372,13 @@ fn test_write_wc_endpoint() {
AttrStatus::new(&ep1_att, IMStatusCode::Success, 0), AttrStatus::new(&ep1_att, IMStatusCode::Success, 0),
]; ];
let mut mdns = DummyMdns; let im = ImEngine::new_default();
let matter = matter(&mut mdns); let handler = im.handler();
let im = ImEngine::new_with_write_reqs(&matter, input, expected);
assert_eq!(val0, im.echo_cluster(0).att_write); im.add_default_acl();
im.handle_write_reqs(&handler, input, expected);
assert_eq!(val0, handler.echo_cluster(0).att_write.get());
} }
#[test] #[test]
@ -467,11 +467,14 @@ fn test_write_unsupported_fields() {
AttrStatus::new(&wc_cluster, IMStatusCode::UnsupportedCluster, 0), AttrStatus::new(&wc_cluster, IMStatusCode::UnsupportedCluster, 0),
AttrStatus::new(&wc_attribute, IMStatusCode::UnsupportedAttribute, 0), AttrStatus::new(&wc_attribute, IMStatusCode::UnsupportedAttribute, 0),
]; ];
let mut mdns = DummyMdns; let im = ImEngine::new_default();
let matter = matter(&mut mdns); let handler = im.handler();
let im = ImEngine::new_with_write_reqs(&matter, input, expected);
im.add_default_acl();
im.handle_write_reqs(&handler, input, expected);
assert_eq!( assert_eq!(
echo_cluster::ATTR_WRITE_DEFAULT_VALUE, echo_cluster::ATTR_WRITE_DEFAULT_VALUE,
im.echo_cluster(0).att_write handler.echo_cluster(0).att_write.get()
); );
} }

View file

@ -17,12 +17,7 @@
use crate::{ use crate::{
cmd_data, cmd_data,
common::{ common::{commands::*, echo_cluster, im_engine::ImEngine, init_env_logger},
commands::*,
echo_cluster,
im_engine::{matter, ImEngine},
init_env_logger,
},
echo_req, echo_resp, echo_req, echo_resp,
}; };
@ -32,7 +27,6 @@ use matter::{
core::IMStatusCode, core::IMStatusCode,
messages::ib::{CmdData, CmdPath, CmdStatus}, messages::ib::{CmdData, CmdPath, CmdStatus},
}, },
mdns::DummyMdns,
}; };
#[test] #[test]
@ -44,7 +38,7 @@ fn test_invoke_cmds_success() {
let input = &[echo_req!(0, 5), echo_req!(1, 10)]; let input = &[echo_req!(0, 5), echo_req!(1, 10)];
let expected = &[echo_resp!(0, 10), echo_resp!(1, 30)]; let expected = &[echo_resp!(0, 10), echo_resp!(1, 30)];
ImEngine::new_with_commands(&matter(&mut DummyMdns), input, expected); ImEngine::commands(input, expected);
} }
#[test] #[test]
@ -99,7 +93,7 @@ fn test_invoke_cmds_unsupported_fields() {
0, 0,
)), )),
]; ];
ImEngine::new_with_commands(&matter(&mut DummyMdns), input, expected); ImEngine::commands(input, expected);
} }
#[test] #[test]
@ -114,7 +108,7 @@ fn test_invoke_cmd_wc_endpoint_all_have_clusters() {
); );
let input = &[cmd_data!(path, 5)]; let input = &[cmd_data!(path, 5)];
let expected = &[echo_resp!(0, 10), echo_resp!(1, 15)]; let expected = &[echo_resp!(0, 10), echo_resp!(1, 15)];
ImEngine::new_with_commands(&matter(&mut DummyMdns), input, expected); ImEngine::commands(input, expected);
} }
#[test] #[test]
@ -139,5 +133,5 @@ fn test_invoke_cmd_wc_endpoint_only_1_has_cluster() {
IMStatusCode::Success, IMStatusCode::Success,
0, 0,
))]; ))];
ImEngine::new_with_commands(&matter(&mut DummyMdns), input, expected); ImEngine::commands(input, expected);
} }

View file

@ -30,13 +30,7 @@ use matter::{
}, },
messages::{msg::SubscribeReq, GenericPath}, messages::{msg::SubscribeReq, GenericPath},
}, },
mdns::DummyMdns, tlv::{self, ElementType, FromTLV, TLVElement, TagType},
tlv::{self, ElementType, FromTLV, TLVElement, TagType, ToTLV},
transport::{
exchange::{self, Exchange},
packet::MAX_RX_BUF_SIZE,
},
Matter,
}; };
use crate::{ use crate::{
@ -44,35 +38,11 @@ use crate::{
common::{ common::{
attributes::*, attributes::*,
echo_cluster as echo, echo_cluster as echo,
im_engine::{matter, ImEngine, ImInput}, im_engine::{ImEngine, ImInput},
init_env_logger, init_env_logger,
}, },
}; };
pub struct LongRead<'a> {
im_engine: ImEngine<'a>,
}
impl<'a> LongRead<'a> {
pub fn new(matter: &'a Matter<'a>) -> Self {
let mut im_engine = ImEngine::new(matter);
// Use the same exchange for all parts of the transaction
im_engine.exch = Some(Exchange::new(1, 0, exchange::Role::Responder));
Self { im_engine }
}
pub fn process<'p>(
&mut self,
action: OpCode,
data: &dyn ToTLV,
data_out: &'p mut [u8],
) -> (u8, &'p [u8]) {
let input = ImInput::new(action, data);
let (response, output) = self.im_engine.process(&input, data_out);
(response, output)
}
}
fn wildcard_read_resp(part: u8) -> Vec<AttrResp<'static>> { fn wildcard_read_resp(part: u8) -> Vec<AttrResp<'static>> {
// For brevity, we only check the AttrPath, not the actual 'data' // For brevity, we only check the AttrPath, not the actual 'data'
let dont_care = ElementType::U8(0); let dont_care = ElementType::U8(0);
@ -215,6 +185,9 @@ fn wildcard_read_resp(part: u8) -> Vec<AttrResp<'static>> {
acl::AttributesDiscriminants::Extension, acl::AttributesDiscriminants::Extension,
dont_care.clone() dont_care.clone()
), ),
];
let part2 = vec![
attr_data!( attr_data!(
0, 0,
31, 31,
@ -266,9 +239,6 @@ fn wildcard_read_resp(part: u8) -> Vec<AttrResp<'static>> {
descriptor::Attributes::DeviceTypeList, descriptor::Attributes::DeviceTypeList,
dont_care.clone() dont_care.clone()
), ),
];
let part2 = vec![
attr_data!(1, 29, descriptor::Attributes::ServerList, dont_care.clone()), attr_data!(1, 29, descriptor::Attributes::ServerList, dont_care.clone()),
attr_data!(1, 29, descriptor::Attributes::PartsList, dont_care.clone()), attr_data!(1, 29, descriptor::Attributes::PartsList, dont_care.clone()),
attr_data!(1, 29, descriptor::Attributes::ClientList, dont_care.clone()), attr_data!(1, 29, descriptor::Attributes::ClientList, dont_care.clone()),
@ -318,74 +288,103 @@ fn wildcard_read_resp(part: u8) -> Vec<AttrResp<'static>> {
fn test_long_read_success() { fn test_long_read_success() {
// Read the entire attribute database, which requires 2 reads to complete // Read the entire attribute database, which requires 2 reads to complete
init_env_logger(); init_env_logger();
let mut mdns = DummyMdns;
let matter = matter(&mut mdns); let mut out = heapless::Vec::<_, 3>::new();
let mut lr = LongRead::new(&matter); let im = ImEngine::new_default();
let mut output = [0_u8; MAX_RX_BUF_SIZE + 100]; let handler = im.handler();
im.add_default_acl();
let wc_path = GenericPath::new(None, None, None); let wc_path = GenericPath::new(None, None, None);
let read_all = [AttrPath::new(&wc_path)]; let read_all = [AttrPath::new(&wc_path)];
let read_req = ReadReq::new(true).set_attr_requests(&read_all); let read_req = ReadReq::new(true).set_attr_requests(&read_all);
let expected_part1 = wildcard_read_resp(1); let expected_part1 = wildcard_read_resp(1);
let (out_code, out_data) = lr.process(OpCode::ReadRequest, &read_req, &mut output);
let root = tlv::get_root_node_struct(out_data).unwrap();
let report_data = ReportDataMsg::from_tlv(&root).unwrap();
assert_attr_report_skip_data(&report_data, &expected_part1);
assert_eq!(report_data.more_chunks, Some(true));
assert_eq!(out_code, OpCode::ReportData as u8);
// Ask for the next read by sending a status report
let status_report = StatusResp { let status_report = StatusResp {
status: IMStatusCode::Success, status: IMStatusCode::Success,
}; };
let expected_part2 = wildcard_read_resp(2); let expected_part2 = wildcard_read_resp(2);
let (out_code, out_data) = lr.process(OpCode::StatusResponse, &status_report, &mut output);
let root = tlv::get_root_node_struct(out_data).unwrap(); im.process(
&handler,
&[
&ImInput::new(OpCode::ReadRequest, &read_req),
&ImInput::new(OpCode::StatusResponse, &status_report),
],
&mut out,
)
.unwrap();
assert_eq!(out.len(), 2);
assert_eq!(out[0].action, OpCode::ReportData);
let root = tlv::get_root_node_struct(&out[0].data).unwrap();
let report_data = ReportDataMsg::from_tlv(&root).unwrap();
assert_attr_report_skip_data(&report_data, &expected_part1);
assert_eq!(report_data.more_chunks, Some(true));
assert_eq!(out[1].action, OpCode::ReportData);
let root = tlv::get_root_node_struct(&out[1].data).unwrap();
let report_data = ReportDataMsg::from_tlv(&root).unwrap(); let report_data = ReportDataMsg::from_tlv(&root).unwrap();
assert_attr_report_skip_data(&report_data, &expected_part2); assert_attr_report_skip_data(&report_data, &expected_part2);
assert_eq!(report_data.more_chunks, None); assert_eq!(report_data.more_chunks, None);
assert_eq!(out_code, OpCode::ReportData as u8);
} }
#[test] #[test]
fn test_long_read_subscription_success() { fn test_long_read_subscription_success() {
// Subscribe to the entire attribute database, which requires 2 reads to complete // Subscribe to the entire attribute database, which requires 2 reads to complete
init_env_logger(); init_env_logger();
let mut mdns = DummyMdns;
let matter = matter(&mut mdns); let mut out = heapless::Vec::<_, 3>::new();
let mut lr = LongRead::new(&matter); let im = ImEngine::new_default();
let mut output = [0_u8; MAX_RX_BUF_SIZE + 100]; let handler = im.handler();
im.add_default_acl();
let wc_path = GenericPath::new(None, None, None); let wc_path = GenericPath::new(None, None, None);
let read_all = [AttrPath::new(&wc_path)]; let read_all = [AttrPath::new(&wc_path)];
let subs_req = SubscribeReq::new(true, 1, 20).set_attr_requests(&read_all); let subs_req = SubscribeReq::new(true, 1, 20).set_attr_requests(&read_all);
let expected_part1 = wildcard_read_resp(1); let expected_part1 = wildcard_read_resp(1);
let (out_code, out_data) = lr.process(OpCode::SubscribeRequest, &subs_req, &mut output);
let root = tlv::get_root_node_struct(out_data).unwrap();
let report_data = ReportDataMsg::from_tlv(&root).unwrap();
assert_attr_report_skip_data(&report_data, &expected_part1);
assert_eq!(report_data.more_chunks, Some(true));
assert_eq!(out_code, OpCode::ReportData as u8);
// Ask for the next read by sending a status report
let status_report = StatusResp { let status_report = StatusResp {
status: IMStatusCode::Success, status: IMStatusCode::Success,
}; };
let expected_part2 = wildcard_read_resp(2); let expected_part2 = wildcard_read_resp(2);
let (out_code, out_data) = lr.process(OpCode::StatusResponse, &status_report, &mut output);
let root = tlv::get_root_node_struct(out_data).unwrap(); im.process(
&handler,
&[
&ImInput::new(OpCode::SubscribeRequest, &subs_req),
&ImInput::new(OpCode::StatusResponse, &status_report),
&ImInput::new(OpCode::StatusResponse, &status_report),
],
&mut out,
)
.unwrap();
assert_eq!(out.len(), 3);
assert_eq!(out[0].action, OpCode::ReportData);
let root = tlv::get_root_node_struct(&out[0].data).unwrap();
let report_data = ReportDataMsg::from_tlv(&root).unwrap();
assert_attr_report_skip_data(&report_data, &expected_part1);
assert_eq!(report_data.more_chunks, Some(true));
assert_eq!(out[1].action, OpCode::ReportData);
let root = tlv::get_root_node_struct(&out[1].data).unwrap();
let report_data = ReportDataMsg::from_tlv(&root).unwrap(); let report_data = ReportDataMsg::from_tlv(&root).unwrap();
assert_attr_report_skip_data(&report_data, &expected_part2); assert_attr_report_skip_data(&report_data, &expected_part2);
assert_eq!(report_data.more_chunks, None); assert_eq!(report_data.more_chunks, None);
assert_eq!(out_code, OpCode::ReportData as u8);
// Finally confirm subscription assert_eq!(out[2].action, OpCode::SubscribeResponse);
let (out_code, out_data) = lr.process(OpCode::StatusResponse, &status_report, &mut output);
tlv::print_tlv_list(out_data); let root = tlv::get_root_node_struct(&out[2].data).unwrap();
let root = tlv::get_root_node_struct(out_data).unwrap();
let subs_resp = SubscribeResp::from_tlv(&root).unwrap(); let subs_resp = SubscribeResp::from_tlv(&root).unwrap();
assert_eq!(out_code, OpCode::SubscribeResponse as u8);
assert_eq!(subs_resp.subs_id, 1); assert_eq!(subs_resp.subs_id, 1);
} }

View file

@ -22,7 +22,6 @@ use matter::{
messages::ib::{AttrData, AttrPath, AttrStatus}, messages::ib::{AttrData, AttrPath, AttrStatus},
messages::{ib::CmdData, ib::CmdPath, GenericPath}, messages::{ib::CmdData, ib::CmdPath, GenericPath},
}, },
mdns::DummyMdns,
tlv::TLVWriter, tlv::TLVWriter,
}; };
@ -31,7 +30,7 @@ use crate::{
commands::*, commands::*,
echo_cluster, echo_cluster,
handlers::{TimedInvResponse, WriteResponse}, handlers::{TimedInvResponse, WriteResponse},
im_engine::{matter, ImEngine}, im_engine::ImEngine,
init_env_logger, init_env_logger,
}, },
echo_req, echo_resp, echo_req, echo_resp,
@ -75,25 +74,20 @@ fn test_timed_write_fail_and_success() {
]; ];
// Test with incorrect handling // Test with incorrect handling
ImEngine::new_with_timed_write_reqs( ImEngine::timed_write_reqs(input, &WriteResponse::TransactionError, 100, 500);
&matter(&mut DummyMdns),
input,
&WriteResponse::TransactionError,
400,
500,
);
// Test with correct handling // Test with correct handling
let mut mdns = DummyMdns; let im = ImEngine::new_default();
let matter = matter(&mut mdns); let handler = im.handler();
let im = ImEngine::new_with_timed_write_reqs( im.add_default_acl();
&matter, im.handle_timed_write_reqs(
&handler,
input, input,
&WriteResponse::TransactionSuccess(expected), &WriteResponse::TransactionSuccess(expected),
400, 400,
0, 0,
); );
assert_eq!(val0, im.echo_cluster(0).att_write); assert_eq!(val0, handler.echo_cluster(0).att_write.get());
} }
#[test] #[test]
@ -103,8 +97,7 @@ fn test_timed_cmd_success() {
let input = &[echo_req!(0, 5), echo_req!(1, 10)]; let input = &[echo_req!(0, 5), echo_req!(1, 10)];
let expected = &[echo_resp!(0, 10), echo_resp!(1, 30)]; let expected = &[echo_resp!(0, 10), echo_resp!(1, 30)];
ImEngine::new_with_timed_commands( ImEngine::timed_commands(
&matter(&mut DummyMdns),
input, input,
&TimedInvResponse::TransactionSuccess(expected), &TimedInvResponse::TransactionSuccess(expected),
400, 400,
@ -119,11 +112,10 @@ fn test_timed_cmd_timeout() {
init_env_logger(); init_env_logger();
let input = &[echo_req!(0, 5), echo_req!(1, 10)]; let input = &[echo_req!(0, 5), echo_req!(1, 10)];
ImEngine::new_with_timed_commands( ImEngine::timed_commands(
&matter(&mut DummyMdns),
input, input,
&TimedInvResponse::TransactionError(IMStatusCode::Timeout), &TimedInvResponse::TransactionError(IMStatusCode::Timeout),
400, 100,
500, 500,
true, true,
); );
@ -135,8 +127,7 @@ fn test_timed_cmd_timedout_mismatch() {
init_env_logger(); init_env_logger();
let input = &[echo_req!(0, 5), echo_req!(1, 10)]; let input = &[echo_req!(0, 5), echo_req!(1, 10)];
ImEngine::new_with_timed_commands( ImEngine::timed_commands(
&matter(&mut DummyMdns),
input, input,
&TimedInvResponse::TransactionError(IMStatusCode::TimedRequestMisMatch), &TimedInvResponse::TransactionError(IMStatusCode::TimedRequestMisMatch),
400, 400,
@ -145,8 +136,7 @@ fn test_timed_cmd_timedout_mismatch() {
); );
let input = &[echo_req!(0, 5), echo_req!(1, 10)]; let input = &[echo_req!(0, 5), echo_req!(1, 10)];
ImEngine::new_with_timed_commands( ImEngine::timed_commands(
&matter(&mut DummyMdns),
input, input,
&TimedInvResponse::TransactionError(IMStatusCode::TimedRequestMisMatch), &TimedInvResponse::TransactionError(IMStatusCode::TimedRequestMisMatch),
0, 0,

View file

@ -1,152 +0,0 @@
/*
*
* Copyright (c) 2020-2022 Project CHIP Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
use matter::data_model::core::DataHandler;
use matter::error::Error;
use matter::interaction_model::core::Interaction;
use matter::interaction_model::core::InteractionModel;
use matter::interaction_model::core::OpCode;
use matter::interaction_model::core::Transaction;
use matter::transport::exchange::Exchange;
use matter::transport::exchange::ExchangeCtx;
use matter::transport::network::Address;
use matter::transport::network::IpAddr;
use matter::transport::network::Ipv4Addr;
use matter::transport::network::SocketAddr;
use matter::transport::packet::Packet;
use matter::transport::packet::MAX_RX_BUF_SIZE;
use matter::transport::packet::MAX_TX_BUF_SIZE;
use matter::transport::proto_ctx::ProtoCtx;
use matter::transport::session::SessionMgr;
use matter::utils::epoch::dummy_epoch;
use matter::utils::rand::dummy_rand;
struct Node {
pub endpoint: u16,
pub cluster: u32,
pub command: u16,
pub variable: u8,
}
struct DataModel {
node: Node,
}
impl DataModel {
pub fn new(node: Node) -> Self {
DataModel { node }
}
}
impl DataHandler for DataModel {
fn handle(
&mut self,
interaction: Interaction,
_tx: &mut Packet,
_transaction: &mut Transaction,
) -> Result<bool, Error> {
if let Interaction::Invoke(req) = interaction {
if let Some(inv_requests) = &req.inv_requests {
for i in inv_requests.iter() {
let data = if let Some(data) = i.data.unwrap_tlv() {
data
} else {
continue;
};
let cmd_path_ib = i.path;
let common_data = &mut self.node;
common_data.endpoint = cmd_path_ib.path.endpoint.unwrap_or(1);
common_data.cluster = cmd_path_ib.path.cluster.unwrap_or(0);
common_data.command = cmd_path_ib.path.leaf.unwrap_or(0) as u16;
data.confirm_struct().unwrap();
common_data.variable = data.find_tag(0).unwrap().u8().unwrap();
}
}
}
Ok(false)
}
}
fn handle_data(action: OpCode, data_in: &[u8], data_out: &mut [u8]) -> (DataModel, usize) {
let data_model = DataModel::new(Node {
endpoint: 0,
cluster: 0,
command: 0,
variable: 0,
});
let mut interaction_model = InteractionModel(data_model);
let mut exch: Exchange = Default::default();
let mut sess_mgr = SessionMgr::new(dummy_epoch, dummy_rand);
let sess_idx = sess_mgr
.get_or_add(
0,
Address::Udp(SocketAddr::new(
IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
5542,
)),
None,
false,
)
.unwrap();
let sess = sess_mgr.get_session_handle(sess_idx);
let exch_ctx = ExchangeCtx {
exch: &mut exch,
sess,
epoch: dummy_epoch,
};
let mut rx_buf = [0; MAX_RX_BUF_SIZE];
let mut tx_buf = [0; MAX_TX_BUF_SIZE];
let mut rx = Packet::new_rx(&mut rx_buf);
let mut tx = Packet::new_tx(&mut tx_buf);
// Create fake rx packet
rx.set_proto_id(0x01);
rx.set_proto_opcode(action as u8);
rx.peer = Address::default();
let in_data_len = data_in.len();
let rx_buf = rx.as_mut_slice();
rx_buf[..in_data_len].copy_from_slice(data_in);
let mut ctx = ProtoCtx::new(exch_ctx, &rx, &mut tx);
interaction_model.handle(&mut ctx).unwrap();
let out_len = ctx.tx.as_mut_slice().len();
data_out[..out_len].copy_from_slice(ctx.tx.as_mut_slice());
(interaction_model.0, out_len)
}
#[test]
fn test_valid_invoke_cmd() -> Result<(), Error> {
// An invoke command for endpoint 0, cluster 49, command 12 and a u8 variable value of 0x05
let b = [
0x15, 0x28, 0x00, 0x28, 0x01, 0x36, 0x02, 0x15, 0x37, 0x00, 0x25, 0x00, 0x00, 0x00, 0x26,
0x01, 0x31, 0x00, 0x00, 0x00, 0x26, 0x02, 0x0c, 0x00, 0x00, 0x00, 0x18, 0x35, 0x01, 0x24,
0x00, 0x05, 0x18, 0x18, 0x18, 0x18,
];
let mut out_buf: [u8; 20] = [0; 20];
let (data_model, _) = handle_data(OpCode::InvokeRequest, &b, &mut out_buf);
let data = &data_model.node;
assert_eq!(data.endpoint, 0);
assert_eq!(data.cluster, 49);
assert_eq!(data.command, 12);
assert_eq!(data.variable, 5);
Ok(())
}

6
sdkconfig.defaults Normal file
View file

@ -0,0 +1,6 @@
# Workaround for https://github.com/espressif/esp-idf/issues/7631
CONFIG_MBEDTLS_CERTIFICATE_BUNDLE=n
CONFIG_MBEDTLS_CERTIFICATE_BUNDLE_DEFAULT_FULL=n
# Examples often require a larger than the default stack size for the main thread.
CONFIG_ESP_MAIN_TASK_STACK_SIZE=10000