diff --git a/examples/onoff_light/src/main.rs b/examples/onoff_light/src/main.rs index a5340f6..ecfc71e 100644 --- a/examples/onoff_light/src/main.rs +++ b/examples/onoff_light/src/main.rs @@ -23,20 +23,15 @@ use log::info; use matter::core::{CommissioningData, Matter}; use matter::data_model::cluster_basic_information::BasicInfoConfig; use matter::data_model::cluster_on_off; -use matter::data_model::core::DataModel; use matter::data_model::device_types::DEV_TYPE_ON_OFF_LIGHT; use matter::data_model::objects::*; use matter::data_model::root_endpoint; use matter::data_model::system_model::descriptor; use matter::error::Error; -use matter::interaction_model::core::InteractionModel; use matter::mdns::{DefaultMdns, DefaultMdnsRunner}; use matter::secure_channel::spake2p::VerifierData; -use matter::transport::network::{Address, IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; -use matter::transport::{ - core::RecvAction, core::Transport, packet::MAX_RX_BUF_SIZE, packet::MAX_TX_BUF_SIZE, - udp::UdpListener, -}; +use matter::transport::network::{Ipv4Addr, Ipv6Addr}; +use matter::transport::runner::{RxBuf, TransportRunner, TxBuf}; use matter::utils::select::EitherUnwrap; mod dev_att; @@ -44,7 +39,7 @@ mod dev_att; #[cfg(feature = "std")] fn main() -> Result<(), Error> { let thread = std::thread::Builder::new() - .stack_size(120 * 1024) + .stack_size(140 * 1024) .spawn(run) .unwrap(); @@ -62,10 +57,10 @@ fn run() -> Result<(), Error> { initialize_logger(); info!( - "Matter memory: mDNS={}, Matter={}, Transport={}", + "Matter memory: mDNS={}, Matter={}, TransportRunner={}", core::mem::size_of::(), core::mem::size_of::(), - core::mem::size_of::(), + core::mem::size_of::(), ); let dev_det = BasicInfoConfig { @@ -92,6 +87,8 @@ fn run() -> Result<(), Error> { let mut mdns_runner = DefaultMdnsRunner::new(&mdns); + info!("mDNS initialized: {:p}, {:p}", &mdns, &mdns_runner); + let dev_att = dev_att::HardCodedDevAtt::new(); #[cfg(feature = "std")] @@ -118,36 +115,25 @@ fn run() -> Result<(), Error> { matter::MATTER_PORT, ); - let psm_path = std::env::temp_dir().join("matter-iot"); - info!("Persisting from/to {}", psm_path.display()); + info!("Matter initialized: {:p}", &matter); - #[cfg(all(feature = "std", not(target_os = "espidf")))] - let psm = matter::persist::FilePsm::new(psm_path)?; + let mut runner = TransportRunner::new(&matter); - let mut buf = [0; 4096]; - let buf = &mut buf; + info!("Transport Runner initialized: {:p}", &runner); - #[cfg(all(feature = "std", not(target_os = "espidf")))] - { - if let Some(data) = psm.load("acls", buf)? { - matter.load_acls(data)?; - } + let mut tx_buf = TxBuf::uninit(); + let mut rx_buf = RxBuf::uninit(); - if let Some(data) = psm.load("fabrics", buf)? { - matter.load_fabrics(data)?; - } - } + // #[cfg(all(feature = "std", not(target_os = "espidf")))] + // { + // if let Some(data) = psm.load("acls", buf)? { + // matter.load_acls(data)?; + // } - let mut transport = Transport::new(&matter); - - transport.start( - CommissioningData { - // TODO: Hard-coded for now - verifier: VerifierData::new_with_pw(123456, *matter.borrow()), - discriminator: 250, - }, - buf, - )?; + // if let Some(data) = psm.load("fabrics", buf)? { + // matter.load_fabrics(data)?; + // } + // } let node = Node { id: 0, @@ -161,69 +147,48 @@ fn run() -> Result<(), Error> { ], }; - let mut handler = handler(&matter); + let handler = HandlerCompat(handler(&matter)); - let mut im = InteractionModel(DataModel::new(matter.borrow(), &node, &mut handler)); - - let mut rx_buf = [0; MAX_RX_BUF_SIZE]; - let mut tx_buf = [0; MAX_TX_BUF_SIZE]; - - let im = &mut im; - let mdns_runner = &mut mdns_runner; - let transport = &mut transport; - let rx_buf = &mut rx_buf; + let matter = &matter; + let node = &node; + let handler = &handler; + let runner = &mut runner; let tx_buf = &mut tx_buf; + let rx_buf = &mut rx_buf; - let mut io_fut = pin!(async move { - // NOTE (no_std): On no_std, the `UdpListener` implementation is a no-op so you might want to - // replace it with your own UDP stack - let udp = UdpListener::new(SocketAddr::new( - IpAddr::V6(Ipv6Addr::UNSPECIFIED), - matter::MATTER_PORT, - )) - .await?; + info!( + "About to run wth node {:p}, handler {:p}, transport runner {:p}, mdns_runner {:p}", + node, handler, runner, &mdns_runner + ); - loop { - let (len, addr) = udp.recv(rx_buf).await?; + let mut fut = pin!(async move { + // NOTE (no_std): On no_std, the `run_udp` is a no-op so you might want to replace it with `run` and + // connect the pipes of the `run` method with your own UDP stack + let mut transport = pin!(runner.run_udp( + tx_buf, + rx_buf, + CommissioningData { + // TODO: Hard-coded for now + verifier: VerifierData::new_with_pw(123456, *matter.borrow()), + discriminator: 250, + }, + &handler, + )); - let mut completion = transport.recv(Address::Udp(addr), &mut rx_buf[..len], tx_buf); + // NOTE (no_std): On no_std, the `run_udp` is a no-op so you might want to replace it with `run` and + // connect the pipes of the `run` method with your own UDP stack + let mut mdns = pin!(mdns_runner.run_udp()); - while let Some(action) = completion.next_action()? { - match action { - RecvAction::Send(addr, buf) => { - udp.send(addr.unwrap_udp(), buf).await?; - } - RecvAction::Interact(mut ctx) => { - if im.handle(&mut ctx)? && ctx.send()? { - udp.send(ctx.tx.peer.unwrap_udp(), ctx.tx.as_slice()) - .await?; - } - } - } - } - - #[cfg(all(feature = "std", not(target_os = "espidf")))] - { - if let Some(data) = transport.matter().store_fabrics(buf)? { - psm.store("fabrics", data)?; - } - - if let Some(data) = transport.matter().store_acls(buf)? { - psm.store("acls", data)?; - } - } - } - - #[allow(unreachable_code)] - Ok::<_, matter::error::Error>(()) + select( + &mut transport, + &mut mdns, + //save(transport, &psm), + ) + .await + .unwrap() }); - // NOTE (no_std): On no_std, the `run_udp` is a no-op so you might want to replace it with `run` and - // connect the pipes of the `run` method with your own UDP stack - let mut mdns_fut = pin!(async move { mdns_runner.run_udp().await }); - - let mut fut = pin!(async move { select(&mut io_fut, &mut mdns_fut).await.unwrap() }); - + // NOTE: For no_std, replace with your own no_std way of polling the future #[cfg(feature = "std")] smol::block_on(&mut fut)?; @@ -235,18 +200,33 @@ fn run() -> Result<(), Error> { Ok(()) } -fn handler<'a>(matter: &'a Matter<'a>) -> impl Handler + 'a { - root_endpoint::handler(0, matter) - .chain( - 1, - descriptor::ID, - descriptor::DescriptorCluster::new(*matter.borrow()), - ) - .chain( - 1, - cluster_on_off::ID, - cluster_on_off::OnOffCluster::new(*matter.borrow()), - ) +const NODE: Node<'static> = Node { + id: 0, + endpoints: &[ + root_endpoint::endpoint(0), + Endpoint { + id: 1, + device_type: DEV_TYPE_ON_OFF_LIGHT, + clusters: &[descriptor::CLUSTER, cluster_on_off::CLUSTER], + }, + ], +}; + +fn handler<'a>(matter: &'a Matter<'a>) -> impl Metadata + NonBlockingHandler + 'a { + ( + NODE, + root_endpoint::handler(0, matter) + .chain( + 1, + descriptor::ID, + descriptor::DescriptorCluster::new(*matter.borrow()), + ) + .chain( + 1, + cluster_on_off::ID, + cluster_on_off::OnOffCluster::new(*matter.borrow()), + ), + ) } // NOTE (no_std): For no_std, implement here your own way of initializing the logger diff --git a/matter/src/data_model/cluster_on_off.rs b/matter/src/data_model/cluster_on_off.rs index 1a26522..8d03d9b 100644 --- a/matter/src/data_model/cluster_on_off.rs +++ b/matter/src/data_model/cluster_on_off.rs @@ -15,12 +15,12 @@ * limitations under the License. */ -use core::convert::TryInto; +use core::{cell::Cell, convert::TryInto}; use super::objects::*; use crate::{ - attribute_enum, cmd_enter, command_enum, error::Error, interaction_model::core::Transaction, - tlv::TLVElement, utils::rand::Rand, + attribute_enum, cmd_enter, command_enum, error::Error, tlv::TLVElement, + transport::exchange::Exchange, utils::rand::Rand, }; use log::info; use strum::{EnumDiscriminants, FromRepr}; @@ -66,20 +66,20 @@ pub const CLUSTER: Cluster<'static> = Cluster { pub struct OnOffCluster { data_ver: Dataver, - on: bool, + on: Cell, } impl OnOffCluster { pub fn new(rand: Rand) -> Self { Self { data_ver: Dataver::new(rand), - on: false, + on: Cell::new(false), } } - pub fn set(&mut self, on: bool) { - if self.on != on { - self.on = on; + pub fn set(&self, on: bool) { + if self.on.get() != on { + self.on.set(on); self.data_ver.changed(); } } @@ -90,7 +90,7 @@ impl OnOffCluster { CLUSTER.read(attr.attr_id, writer) } else { match attr.attr_id.try_into()? { - Attributes::OnOff(codec) => codec.encode(writer, self.on), + Attributes::OnOff(codec) => codec.encode(writer, self.on.get()), } } } else { @@ -98,7 +98,7 @@ impl OnOffCluster { } } - pub fn write(&mut self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { + pub fn write(&self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { let data = data.with_dataver(self.data_ver.get())?; match attr.attr_id.try_into()? { @@ -111,8 +111,8 @@ impl OnOffCluster { } pub fn invoke( - &mut self, - transaction: &mut Transaction, + &self, + _exchange: &Exchange, cmd: &CmdDetails, _data: &TLVElement, _encoder: CmdDataEncoder, @@ -128,12 +128,10 @@ impl OnOffCluster { } Commands::Toggle => { cmd_enter!("Toggle"); - self.set(!self.on); + self.set(!self.on.get()); } } - transaction.complete(); - self.data_ver.changed(); Ok(()) @@ -145,18 +143,18 @@ impl Handler for OnOffCluster { OnOffCluster::read(self, attr, encoder) } - fn write(&mut self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { + fn write(&self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { OnOffCluster::write(self, attr, data) } fn invoke( - &mut self, - transaction: &mut Transaction, + &self, + exchange: &Exchange, cmd: &CmdDetails, data: &TLVElement, encoder: CmdDataEncoder, ) -> Result<(), Error> { - OnOffCluster::invoke(self, transaction, cmd, data, encoder) + OnOffCluster::invoke(self, exchange, cmd, data, encoder) } } diff --git a/matter/src/data_model/core.rs b/matter/src/data_model/core.rs index 20efeb7..69935c5 100644 --- a/matter/src/data_model/core.rs +++ b/matter/src/data_model/core.rs @@ -15,287 +15,127 @@ * limitations under the License. */ -use core::cell::RefCell; +use core::sync::atomic::{AtomicU32, Ordering}; use super::objects::*; use crate::{ - acl::{Accessor, AclMgr}, + alloc, error::*, - interaction_model::core::{Interaction, Transaction}, - tlv::TLVWriter, - transport::packet::Packet, + interaction_model::core::Interaction, + transport::{exchange::Exchange, packet::Packet}, }; -pub struct DataModel<'a, T> { - pub acl_mgr: &'a RefCell, - pub node: &'a Node<'a>, - pub handler: T, -} +// TODO: For now... +static SUBS_ID: AtomicU32 = AtomicU32::new(1); -impl<'a, T> DataModel<'a, T> { - pub const fn new(acl_mgr: &'a RefCell, node: &'a Node<'a>, handler: T) -> Self { - Self { - acl_mgr, - node, - handler, - } +pub struct DataModel(T); + +impl DataModel { + pub fn new(handler: T) -> Self { + Self(handler) } - pub fn handle( - &mut self, - interaction: Interaction, - tx: &mut Packet, - transaction: &mut Transaction, - ) -> Result + pub async fn handle<'r, 'p>( + &self, + exchange: &'r mut Exchange<'_>, + rx: &'r mut Packet<'p>, + tx: &'r mut Packet<'p>, + rx_status: &'r mut Packet<'p>, + ) -> Result<(), Error> where - T: Handler, + T: DataModelHandler, { - let accessor = Accessor::for_session(transaction.session(), self.acl_mgr); - let mut tw = TLVWriter::new(tx.get_writebuf()?); + let timeout = Interaction::timeout(exchange, rx, tx).await?; - match interaction { - Interaction::Read(req) => { - let mut resume_path = None; + let mut interaction = alloc!(Interaction::new( + exchange, + rx, + tx, + rx_status, + || SUBS_ID.fetch_add(1, Ordering::SeqCst), + timeout, + )?); - for item in self.node.read(&req, &accessor) { - if let Some(path) = AttrDataEncoder::handle_read(item, &self.handler, &mut tw)? - { - resume_path = Some(path); - break; + #[cfg(feature = "alloc")] + let interaction = &mut *interaction; + + #[cfg(not(feature = "alloc"))] + let interaction = &mut interaction; + + #[cfg(feature = "nightly")] + let metadata = self.0.lock().await; + + #[cfg(not(feature = "nightly"))] + let metadata = self.0.lock(); + + if interaction.start().await? { + match interaction { + Interaction::Read { + req, + ref mut driver, + } => { + let accessor = driver.accessor()?; + + 'outer: for item in metadata.node().read(req, None, &accessor) { + while !AttrDataEncoder::handle_read(&item, &self.0, &mut driver.writer()?) + .await? + { + if !driver.send_chunk(req).await? { + break 'outer; + } + } } + + driver.complete(req).await?; } + Interaction::Write { + req, + ref mut driver, + } => { + let accessor = driver.accessor()?; - req.complete(tx, transaction, resume_path) - } - Interaction::Write(req) => { - for item in self.node.write(&req, &accessor) { - AttrDataEncoder::handle_write(item, &mut self.handler, &mut tw)?; - } - - req.complete(tx, transaction) - } - Interaction::Invoke(req) => { - for item in self.node.invoke(&req, &accessor) { - CmdDataEncoder::handle(item, &mut self.handler, transaction, &mut tw)?; - } - - req.complete(tx, transaction) - } - Interaction::Subscribe(req) => { - let mut resume_path = None; - - for item in self.node.subscribing_read(&req, &accessor) { - if let Some(path) = AttrDataEncoder::handle_read(item, &self.handler, &mut tw)? - { - resume_path = Some(path); - break; + for item in metadata.node().write(req, &accessor) { + AttrDataEncoder::handle_write(&item, &self.0, &mut driver.writer()?) + .await?; } + + driver.complete(req).await?; } + Interaction::Invoke { + req, + ref mut driver, + } => { + let accessor = driver.accessor()?; - req.complete(tx, transaction, resume_path) - } - Interaction::Timed(_) => Ok(false), - Interaction::ResumeRead(req) => { - let mut resume_path = None; + for item in metadata.node().invoke(req, &accessor) { + let (mut tw, exchange) = driver.writer_exchange()?; - for item in self.node.resume_read(&req, &accessor) { - if let Some(path) = AttrDataEncoder::handle_read(item, &self.handler, &mut tw)? - { - resume_path = Some(path); - break; + CmdDataEncoder::handle(&item, &self.0, &mut tw, exchange).await?; } + + driver.complete(req).await?; } + Interaction::Subscribe { + req, + ref mut driver, + } => { + let accessor = driver.accessor()?; - req.complete(tx, transaction, resume_path) - } - Interaction::ResumeSubscribe(req) => { - let mut resume_path = None; - - for item in self.node.resume_subscribing_read(&req, &accessor) { - if let Some(path) = AttrDataEncoder::handle_read(item, &self.handler, &mut tw)? - { - resume_path = Some(path); - break; + 'outer: for item in metadata.node().subscribing_read(req, None, &accessor) { + while !AttrDataEncoder::handle_read(&item, &self.0, &mut driver.writer()?) + .await? + { + if !driver.send_chunk(req).await? { + break 'outer; + } + } } - } - req.complete(tx, transaction, resume_path) + driver.complete(req).await?; + } } } - } - #[cfg(feature = "nightly")] - pub async fn handle_async<'p>( - &mut self, - interaction: Interaction<'_>, - tx: &'p mut Packet<'_>, - transaction: &mut Transaction<'_, '_>, - ) -> Result - where - T: super::objects::asynch::AsyncHandler, - { - let accessor = Accessor::for_session(transaction.session(), self.acl_mgr); - let mut tw = TLVWriter::new(tx.get_writebuf()?); - - match interaction { - Interaction::Read(req) => { - let mut resume_path = None; - - for item in self.node.read(&req, &accessor) { - if let Some(path) = - AttrDataEncoder::handle_read_async(item, &self.handler, &mut tw).await? - { - resume_path = Some(path); - break; - } - } - - req.complete(tx, transaction, resume_path) - } - Interaction::Write(req) => { - for item in self.node.write(&req, &accessor) { - AttrDataEncoder::handle_write_async(item, &mut self.handler, &mut tw).await?; - } - - req.complete(tx, transaction) - } - Interaction::Invoke(req) => { - for item in self.node.invoke(&req, &accessor) { - CmdDataEncoder::handle_async(item, &mut self.handler, transaction, &mut tw) - .await?; - } - - req.complete(tx, transaction) - } - Interaction::Subscribe(req) => { - let mut resume_path = None; - - for item in self.node.subscribing_read(&req, &accessor) { - if let Some(path) = - AttrDataEncoder::handle_read_async(item, &self.handler, &mut tw).await? - { - resume_path = Some(path); - break; - } - } - - req.complete(tx, transaction, resume_path) - } - Interaction::Timed(_) => Ok(false), - Interaction::ResumeRead(req) => { - let mut resume_path = None; - - for item in self.node.resume_read(&req, &accessor) { - if let Some(path) = - AttrDataEncoder::handle_read_async(item, &self.handler, &mut tw).await? - { - resume_path = Some(path); - break; - } - } - - req.complete(tx, transaction, resume_path) - } - Interaction::ResumeSubscribe(req) => { - let mut resume_path = None; - - for item in self.node.resume_subscribing_read(&req, &accessor) { - if let Some(path) = - AttrDataEncoder::handle_read_async(item, &self.handler, &mut tw).await? - { - resume_path = Some(path); - break; - } - } - - req.complete(tx, transaction, resume_path) - } - } - } -} - -pub trait DataHandler { - fn handle( - &mut self, - interaction: Interaction, - tx: &mut Packet, - transaction: &mut Transaction, - ) -> Result; -} - -impl DataHandler for &mut T -where - T: DataHandler, -{ - fn handle( - &mut self, - interaction: Interaction, - tx: &mut Packet, - transaction: &mut Transaction, - ) -> Result { - (**self).handle(interaction, tx, transaction) - } -} - -impl<'a, T> DataHandler for DataModel<'a, T> -where - T: Handler, -{ - fn handle( - &mut self, - interaction: Interaction, - tx: &mut Packet, - transaction: &mut Transaction, - ) -> Result { - DataModel::handle(self, interaction, tx, transaction) - } -} - -#[cfg(feature = "nightly")] -pub mod asynch { - use crate::{ - data_model::objects::asynch::AsyncHandler, - error::Error, - interaction_model::core::{Interaction, Transaction}, - transport::packet::Packet, - }; - - use super::DataModel; - - pub trait AsyncDataHandler { - async fn handle( - &mut self, - interaction: Interaction<'_>, - tx: &mut Packet, - transaction: &mut Transaction, - ) -> Result; - } - - impl AsyncDataHandler for &mut T - where - T: AsyncDataHandler, - { - async fn handle( - &mut self, - interaction: Interaction<'_>, - tx: &mut Packet<'_>, - transaction: &mut Transaction<'_, '_>, - ) -> Result { - (**self).handle(interaction, tx, transaction).await - } - } - - impl<'a, T> AsyncDataHandler for DataModel<'a, T> - where - T: AsyncHandler, - { - async fn handle( - &mut self, - interaction: Interaction<'_>, - tx: &mut Packet<'_>, - transaction: &mut Transaction<'_, '_>, - ) -> Result { - DataModel::handle_async(self, interaction, tx, transaction).await - } + Ok(()) } } diff --git a/matter/src/data_model/objects/dataver.rs b/matter/src/data_model/objects/dataver.rs index f05a383..dcdb42d 100644 --- a/matter/src/data_model/objects/dataver.rs +++ b/matter/src/data_model/objects/dataver.rs @@ -15,11 +15,13 @@ * limitations under the License. */ +use core::cell::Cell; + use crate::utils::rand::Rand; pub struct Dataver { - ver: u32, - changed: bool, + ver: Cell, + changed: Cell, } impl Dataver { @@ -28,25 +30,25 @@ impl Dataver { rand(&mut buf); Self { - ver: u32::from_be_bytes(buf), - changed: false, + ver: Cell::new(u32::from_be_bytes(buf)), + changed: Cell::new(false), } } pub fn get(&self) -> u32 { - self.ver + self.ver.get() } - pub fn changed(&mut self) -> u32 { - (self.ver, _) = self.ver.overflowing_add(1); - self.changed = true; + pub fn changed(&self) -> u32 { + self.ver.set(self.ver.get().overflowing_add(1).0); + self.changed.set(true); self.get() } - pub fn consume_change(&mut self, change: T) -> Option { - if self.changed { - self.changed = false; + pub fn consume_change(&self, change: T) -> Option { + if self.changed.get() { + self.changed.set(false); Some(change) } else { None diff --git a/matter/src/data_model/objects/encoder.rs b/matter/src/data_model/objects/encoder.rs index 70e0db7..73f610b 100644 --- a/matter/src/data_model/objects/encoder.rs +++ b/matter/src/data_model/objects/encoder.rs @@ -19,12 +19,12 @@ use core::fmt::{Debug, Formatter}; use core::marker::PhantomData; use core::ops::{Deref, DerefMut}; -use crate::interaction_model::core::{IMStatusCode, Transaction}; +use crate::interaction_model::core::IMStatusCode; use crate::interaction_model::messages::ib::{ AttrPath, AttrResp, AttrStatus, CmdDataTag, CmdPath, CmdStatus, InvResp, InvRespTag, }; -use crate::interaction_model::messages::GenericPath; use crate::tlv::UtfStr; +use crate::transport::exchange::Exchange; use crate::{ error::{Error, ErrorCode}, interaction_model::messages::ib::{AttrDataTag, AttrRespTag}, @@ -32,7 +32,7 @@ use crate::{ }; use log::error; -use super::{AttrDetails, CmdDetails, Handler}; +use super::{AttrDetails, CmdDetails, DataModelHandler}; // TODO: Should this return an IMStatusCode Error? But if yes, the higher layer // may have already started encoding the 'success' headers, we might not want to manage @@ -124,47 +124,75 @@ pub struct AttrDataEncoder<'a, 'b, 'c> { } impl<'a, 'b, 'c> AttrDataEncoder<'a, 'b, 'c> { - pub fn handle_read( - item: Result, + pub async fn handle_read( + item: &Result, AttrStatus>, handler: &T, - tw: &mut TLVWriter, - ) -> Result, Error> { + tw: &mut TLVWriter<'_, '_>, + ) -> Result { let status = match item { Ok(attr) => { - let encoder = AttrDataEncoder::new(&attr, tw); + let encoder = AttrDataEncoder::new(attr, tw); - match handler.read(&attr, encoder) { + let result = { + #[cfg(not(feature = "nightly"))] + { + handler.read(attr, encoder) + } + + #[cfg(feature = "nightly")] + { + handler.read(&attr, encoder).await + } + }; + + match result { Ok(()) => None, Err(e) => { if e.code() == ErrorCode::NoSpace { - return Ok(Some(attr.path().to_gp())); + return Ok(false); } else { attr.status(e.into())? } } } } - Err(status) => Some(status), + Err(status) => Some(status.clone()), }; if let Some(status) = status { AttrResp::Status(status).to_tlv(tw, TagType::Anonymous)?; } - Ok(None) + Ok(true) } - pub fn handle_write( - item: Result<(AttrDetails, TLVElement), AttrStatus>, - handler: &mut T, - tw: &mut TLVWriter, + pub async fn handle_write( + item: &Result<(AttrDetails<'_>, TLVElement<'_>), AttrStatus>, + handler: &T, + tw: &mut TLVWriter<'_, '_>, ) -> Result<(), Error> { let status = match item { - Ok((attr, data)) => match handler.write(&attr, AttrData::new(attr.dataver, &data)) { - Ok(()) => attr.status(IMStatusCode::Success)?, - Err(error) => attr.status(error.into())?, - }, - Err(status) => Some(status), + Ok((attr, data)) => { + let result = { + #[cfg(not(feature = "nightly"))] + { + handler.write(attr, AttrData::new(attr.dataver, data)) + } + + #[cfg(feature = "nightly")] + { + handler + .write(&attr, AttrData::new(attr.dataver, &data)) + .await + } + }; + + match result { + Ok(()) => attr.status(IMStatusCode::Success)?, + Err(error) => attr.status(error.into())?, + } + } + Err(status) => Some(status.clone()), }; if let Some(status) = status { @@ -174,61 +202,6 @@ impl<'a, 'b, 'c> AttrDataEncoder<'a, 'b, 'c> { Ok(()) } - #[cfg(feature = "nightly")] - pub async fn handle_read_async( - item: Result, AttrStatus>, - handler: &T, - tw: &mut TLVWriter<'_, '_>, - ) -> Result, Error> { - let status = match item { - Ok(attr) => { - let encoder = AttrDataEncoder::new(&attr, tw); - - match handler.read(&attr, encoder).await { - Ok(()) => None, - Err(e) => { - if e.code() == ErrorCode::NoSpace { - return Ok(Some(attr.path().to_gp())); - } else { - attr.status(e.into())? - } - } - } - } - Err(status) => Some(status), - }; - - if let Some(status) = status { - AttrResp::Status(status).to_tlv(tw, TagType::Anonymous)?; - } - - Ok(None) - } - - #[cfg(feature = "nightly")] - pub async fn handle_write_async( - item: Result<(AttrDetails<'_>, TLVElement<'_>), AttrStatus>, - handler: &mut T, - tw: &mut TLVWriter<'_, '_>, - ) -> Result<(), Error> { - let status = match item { - Ok((attr, data)) => match handler - .write(&attr, AttrData::new(attr.dataver, &data)) - .await - { - Ok(()) => attr.status(IMStatusCode::Success)?, - Err(error) => attr.status(error.into())?, - }, - Err(status) => Some(status), - }; - - if let Some(status) = status { - AttrResp::Status(status).to_tlv(tw, TagType::Anonymous)?; - } - - Ok(()) - } - pub fn new(attr: &AttrDetails, tw: &'a mut TLVWriter<'b, 'c>) -> Self { Self { dataver_filter: attr.dataver, @@ -365,18 +338,30 @@ pub struct CmdDataEncoder<'a, 'b, 'c> { } impl<'a, 'b, 'c> CmdDataEncoder<'a, 'b, 'c> { - pub fn handle( - item: Result<(CmdDetails, TLVElement), CmdStatus>, - handler: &mut T, - transaction: &mut Transaction, - tw: &mut TLVWriter, + pub async fn handle( + item: &Result<(CmdDetails<'_>, TLVElement<'_>), CmdStatus>, + handler: &T, + tw: &mut TLVWriter<'_, '_>, + exchange: &Exchange<'_>, ) -> Result<(), Error> { let status = match item { Ok((cmd, data)) => { let mut tracker = CmdDataTracker::new(); - let encoder = CmdDataEncoder::new(&cmd, &mut tracker, tw); + let encoder = CmdDataEncoder::new(cmd, &mut tracker, tw); - match handler.invoke(transaction, &cmd, &data, encoder) { + let result = { + #[cfg(not(feature = "nightly"))] + { + handler.invoke(exchange, cmd, data, encoder) + } + + #[cfg(feature = "nightly")] + { + handler.invoke(exchange, &cmd, &data, encoder).await + } + }; + + match result { Ok(()) => cmd.success(&tracker), Err(error) => { error!("Error invoking command: {}", error); @@ -386,7 +371,7 @@ impl<'a, 'b, 'c> CmdDataEncoder<'a, 'b, 'c> { } Err(status) => { error!("Error invoking command: {:?}", status); - Some(status) + Some(status.clone()) } }; @@ -397,33 +382,6 @@ impl<'a, 'b, 'c> CmdDataEncoder<'a, 'b, 'c> { Ok(()) } - #[cfg(feature = "nightly")] - pub async fn handle_async( - item: Result<(CmdDetails<'_>, TLVElement<'_>), CmdStatus>, - handler: &mut T, - transaction: &mut Transaction<'_, '_>, - tw: &mut TLVWriter<'_, '_>, - ) -> Result<(), Error> { - let status = match item { - Ok((cmd, data)) => { - let mut tracker = CmdDataTracker::new(); - let encoder = CmdDataEncoder::new(&cmd, &mut tracker, tw); - - match handler.invoke(transaction, &cmd, &data, encoder).await { - Ok(()) => cmd.success(&tracker), - Err(error) => cmd.status(error.into()), - } - } - Err(status) => Some(status), - }; - - if let Some(status) = status { - InvResp::Status(status).to_tlv(tw, TagType::Anonymous)?; - } - - Ok(()) - } - pub fn new( cmd: &CmdDetails, tracker: &'a mut CmdDataTracker, diff --git a/matter/src/data_model/objects/handler.rs b/matter/src/data_model/objects/handler.rs index 143cad8..03cac3f 100644 --- a/matter/src/data_model/objects/handler.rs +++ b/matter/src/data_model/objects/handler.rs @@ -17,12 +17,25 @@ use crate::{ error::{Error, ErrorCode}, - interaction_model::core::Transaction, tlv::TLVElement, + transport::exchange::Exchange, }; use super::{AttrData, AttrDataEncoder, AttrDetails, CmdDataEncoder, CmdDetails}; +#[cfg(feature = "nightly")] +pub use asynch::*; + +#[cfg(not(feature = "nightly"))] +pub trait DataModelHandler: super::Metadata + Handler {} +#[cfg(not(feature = "nightly"))] +impl DataModelHandler for T where T: super::Metadata + Handler {} + +#[cfg(feature = "nightly")] +pub trait DataModelHandler: super::asynch::AsyncMetadata + asynch::AsyncHandler {} +#[cfg(feature = "nightly")] +impl DataModelHandler for T where T: super::asynch::AsyncMetadata + asynch::AsyncHandler {} + pub trait ChangeNotifier { fn consume_change(&mut self) -> Option; } @@ -30,13 +43,13 @@ pub trait ChangeNotifier { pub trait Handler { fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error>; - fn write(&mut self, _attr: &AttrDetails, _data: AttrData) -> Result<(), Error> { + fn write(&self, _attr: &AttrDetails, _data: AttrData) -> Result<(), Error> { Err(ErrorCode::AttributeNotFound.into()) } fn invoke( - &mut self, - _transaction: &mut Transaction, + &self, + _exchange: &Exchange, _cmd: &CmdDetails, _data: &TLVElement, _encoder: CmdDataEncoder, @@ -45,6 +58,29 @@ pub trait Handler { } } +impl Handler for &T +where + T: Handler, +{ + fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> { + (**self).read(attr, encoder) + } + + fn write(&self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { + (**self).write(attr, data) + } + + fn invoke( + &self, + exchange: &Exchange, + cmd: &CmdDetails, + data: &TLVElement, + encoder: CmdDataEncoder, + ) -> Result<(), Error> { + (**self).invoke(exchange, cmd, data, encoder) + } +} + impl Handler for &mut T where T: Handler, @@ -53,25 +89,52 @@ where (**self).read(attr, encoder) } - fn write(&mut self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { + fn write(&self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { (**self).write(attr, data) } fn invoke( - &mut self, - transaction: &mut Transaction, + &self, + exchange: &Exchange, cmd: &CmdDetails, data: &TLVElement, encoder: CmdDataEncoder, ) -> Result<(), Error> { - (**self).invoke(transaction, cmd, data, encoder) + (**self).invoke(exchange, cmd, data, encoder) } } pub trait NonBlockingHandler: Handler {} +impl NonBlockingHandler for &T where T: NonBlockingHandler {} + impl NonBlockingHandler for &mut T where T: NonBlockingHandler {} +impl Handler for (M, H) +where + H: Handler, +{ + fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> { + self.1.read(attr, encoder) + } + + fn write(&self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { + self.1.write(attr, data) + } + + fn invoke( + &self, + exchange: &Exchange, + cmd: &CmdDetails, + data: &TLVElement, + encoder: CmdDataEncoder, + ) -> Result<(), Error> { + self.1.invoke(exchange, cmd, data, encoder) + } +} + +impl NonBlockingHandler for (M, H) where H: NonBlockingHandler {} + pub struct EmptyHandler; impl EmptyHandler { @@ -140,7 +203,7 @@ where } } - fn write(&mut self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { + fn write(&self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { if self.handler_endpoint == attr.endpoint_id && self.handler_cluster == attr.cluster_id { self.handler.write(attr, data) } else { @@ -149,16 +212,16 @@ where } fn invoke( - &mut self, - transaction: &mut Transaction, + &self, + exchange: &Exchange, cmd: &CmdDetails, data: &TLVElement, encoder: CmdDataEncoder, ) -> Result<(), Error> { if self.handler_endpoint == cmd.endpoint_id && self.handler_cluster == cmd.cluster_id { - self.handler.invoke(transaction, cmd, data, encoder) + self.handler.invoke(exchange, cmd, data, encoder) } else { - self.next.invoke(transaction, cmd, data, encoder) + self.next.invoke(exchange, cmd, data, encoder) } } } @@ -184,6 +247,35 @@ where } } +/// Wrap your `NonBlockingHandler` or `AsyncHandler` implementation in this struct +/// to get your code compilable with and without the `nightly` feature +pub struct HandlerCompat(pub T); + +impl Handler for HandlerCompat +where + T: Handler, +{ + fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> { + self.0.read(attr, encoder) + } + + fn write(&self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { + self.0.write(attr, data) + } + + fn invoke( + &self, + exchange: &Exchange, + cmd: &CmdDetails, + data: &TLVElement, + encoder: CmdDataEncoder, + ) -> Result<(), Error> { + self.0.invoke(exchange, cmd, data, encoder) + } +} + +impl NonBlockingHandler for HandlerCompat where T: NonBlockingHandler {} + #[allow(unused_macros)] #[macro_export] macro_rules! handler_chain_type { @@ -203,15 +295,15 @@ macro_rules! handler_chain_type { } #[cfg(feature = "nightly")] -pub mod asynch { +mod asynch { use crate::{ data_model::objects::{AttrData, AttrDataEncoder, AttrDetails, CmdDataEncoder, CmdDetails}, error::{Error, ErrorCode}, - interaction_model::core::Transaction, tlv::TLVElement, + transport::exchange::Exchange, }; - use super::{ChainedHandler, EmptyHandler, Handler, NonBlockingHandler}; + use super::{ChainedHandler, EmptyHandler, Handler, HandlerCompat, NonBlockingHandler}; pub trait AsyncHandler { async fn read<'a>( @@ -221,7 +313,7 @@ pub mod asynch { ) -> Result<(), Error>; async fn write<'a>( - &'a mut self, + &'a self, _attr: &'a AttrDetails<'_>, _data: AttrData<'a>, ) -> Result<(), Error> { @@ -229,8 +321,8 @@ pub mod asynch { } async fn invoke<'a>( - &'a mut self, - _transaction: &'a mut Transaction<'_, '_>, + &'a self, + _exchange: &'a Exchange<'_>, _cmd: &'a CmdDetails<'_>, _data: &'a TLVElement<'_>, _encoder: CmdDataEncoder<'a, '_, '_>, @@ -252,7 +344,7 @@ pub mod asynch { } async fn write<'a>( - &'a mut self, + &'a self, attr: &'a AttrDetails<'_>, data: AttrData<'a>, ) -> Result<(), Error> { @@ -260,19 +352,79 @@ pub mod asynch { } async fn invoke<'a>( - &'a mut self, - transaction: &'a mut Transaction<'_, '_>, + &'a self, + exchange: &'a Exchange<'_>, cmd: &'a CmdDetails<'_>, data: &'a TLVElement<'_>, encoder: CmdDataEncoder<'a, '_, '_>, ) -> Result<(), Error> { - (**self).invoke(transaction, cmd, data, encoder).await + (**self).invoke(exchange, cmd, data, encoder).await } } - pub struct Asyncify(pub T); + impl AsyncHandler for &T + where + T: AsyncHandler, + { + async fn read<'a>( + &'a self, + attr: &'a AttrDetails<'_>, + encoder: AttrDataEncoder<'a, '_, '_>, + ) -> Result<(), Error> { + (**self).read(attr, encoder).await + } - impl AsyncHandler for Asyncify + async fn write<'a>( + &'a self, + attr: &'a AttrDetails<'_>, + data: AttrData<'a>, + ) -> Result<(), Error> { + (**self).write(attr, data).await + } + + async fn invoke<'a>( + &'a self, + exchange: &'a Exchange<'_>, + cmd: &'a CmdDetails<'_>, + data: &'a TLVElement<'_>, + encoder: CmdDataEncoder<'a, '_, '_>, + ) -> Result<(), Error> { + (**self).invoke(exchange, cmd, data, encoder).await + } + } + + impl AsyncHandler for (M, H) + where + H: AsyncHandler, + { + async fn read<'a>( + &'a self, + attr: &'a AttrDetails<'_>, + encoder: AttrDataEncoder<'a, '_, '_>, + ) -> Result<(), Error> { + self.1.read(attr, encoder).await + } + + async fn write<'a>( + &'a self, + attr: &'a AttrDetails<'_>, + data: AttrData<'a>, + ) -> Result<(), Error> { + self.1.write(attr, data).await + } + + async fn invoke<'a>( + &'a self, + exchange: &'a Exchange<'_>, + cmd: &'a CmdDetails<'_>, + data: &'a TLVElement<'_>, + encoder: CmdDataEncoder<'a, '_, '_>, + ) -> Result<(), Error> { + self.1.invoke(exchange, cmd, data, encoder).await + } + } + + impl AsyncHandler for HandlerCompat where T: NonBlockingHandler, { @@ -285,21 +437,21 @@ pub mod asynch { } async fn write<'a>( - &'a mut self, + &'a self, attr: &'a AttrDetails<'_>, data: AttrData<'a>, ) -> Result<(), Error> { - Handler::write(&mut self.0, attr, data) + Handler::write(&self.0, attr, data) } async fn invoke<'a>( - &'a mut self, - transaction: &'a mut Transaction<'_, '_>, + &'a self, + exchange: &'a Exchange<'_>, cmd: &'a CmdDetails<'_>, data: &'a TLVElement<'_>, encoder: CmdDataEncoder<'a, '_, '_>, ) -> Result<(), Error> { - Handler::invoke(&mut self.0, transaction, cmd, data, encoder) + Handler::invoke(&self.0, exchange, cmd, data, encoder) } } @@ -332,7 +484,7 @@ pub mod asynch { } async fn write<'a>( - &'a mut self, + &'a self, attr: &'a AttrDetails<'_>, data: AttrData<'a>, ) -> Result<(), Error> { @@ -345,16 +497,16 @@ pub mod asynch { } async fn invoke<'a>( - &'a mut self, - transaction: &'a mut Transaction<'_, '_>, + &'a self, + exchange: &'a Exchange<'_>, cmd: &'a CmdDetails<'_>, data: &'a TLVElement<'_>, encoder: CmdDataEncoder<'a, '_, '_>, ) -> Result<(), Error> { if self.handler_endpoint == cmd.endpoint_id && self.handler_cluster == cmd.cluster_id { - self.handler.invoke(transaction, cmd, data, encoder).await + self.handler.invoke(exchange, cmd, data, encoder).await } else { - self.next.invoke(transaction, cmd, data, encoder).await + self.next.invoke(exchange, cmd, data, encoder).await } } } diff --git a/matter/src/data_model/objects/metadata.rs b/matter/src/data_model/objects/metadata.rs new file mode 100644 index 0000000..368ff9b --- /dev/null +++ b/matter/src/data_model/objects/metadata.rs @@ -0,0 +1,178 @@ +use crate::data_model::objects::Node; + +#[cfg(feature = "nightly")] +pub use asynch::*; + +use super::HandlerCompat; + +pub trait MetadataGuard { + fn node(&self) -> Node<'_>; +} + +impl MetadataGuard for &T +where + T: MetadataGuard, +{ + fn node(&self) -> Node<'_> { + (**self).node() + } +} + +impl MetadataGuard for &mut T +where + T: MetadataGuard, +{ + fn node(&self) -> Node<'_> { + (**self).node() + } +} + +pub trait Metadata { + type MetadataGuard<'a>: MetadataGuard + where + Self: 'a; + + fn lock(&self) -> Self::MetadataGuard<'_>; +} + +impl Metadata for &T +where + T: Metadata, +{ + type MetadataGuard<'a> = T::MetadataGuard<'a> where Self: 'a; + + fn lock(&self) -> Self::MetadataGuard<'_> { + (**self).lock() + } +} + +impl Metadata for &mut T +where + T: Metadata, +{ + type MetadataGuard<'a> = T::MetadataGuard<'a> where Self: 'a; + + fn lock(&self) -> Self::MetadataGuard<'_> { + (**self).lock() + } +} + +impl<'a> MetadataGuard for Node<'a> { + fn node(&self) -> Node<'_> { + Node { + id: self.id, + endpoints: self.endpoints, + } + } +} + +impl<'a> Metadata for Node<'a> { + type MetadataGuard<'g> = Node<'g> where Self: 'g; + + fn lock(&self) -> Self::MetadataGuard<'_> { + Node { + id: self.id, + endpoints: self.endpoints, + } + } +} + +impl Metadata for (M, H) +where + M: Metadata, +{ + type MetadataGuard<'a> = M::MetadataGuard<'a> + where + Self: 'a; + + fn lock(&self) -> Self::MetadataGuard<'_> { + self.0.lock() + } +} + +impl Metadata for HandlerCompat +where + T: Metadata, +{ + type MetadataGuard<'a> = T::MetadataGuard<'a> + where + Self: 'a; + + fn lock(&self) -> Self::MetadataGuard<'_> { + self.0.lock() + } +} + +#[cfg(feature = "nightly")] +pub mod asynch { + use crate::data_model::objects::{HandlerCompat, Node}; + + use super::{Metadata, MetadataGuard}; + + pub trait AsyncMetadata { + type MetadataGuard<'a>: MetadataGuard + where + Self: 'a; + + async fn lock(&self) -> Self::MetadataGuard<'_>; + } + + impl AsyncMetadata for &T + where + T: AsyncMetadata, + { + type MetadataGuard<'a> = T::MetadataGuard<'a> where Self: 'a; + + async fn lock(&self) -> Self::MetadataGuard<'_> { + (**self).lock().await + } + } + + impl AsyncMetadata for &mut T + where + T: AsyncMetadata, + { + type MetadataGuard<'a> = T::MetadataGuard<'a> where Self: 'a; + + async fn lock(&self) -> Self::MetadataGuard<'_> { + (**self).lock().await + } + } + + impl<'a> AsyncMetadata for Node<'a> { + type MetadataGuard<'g> = Node<'g> where Self: 'g; + + async fn lock(&self) -> Self::MetadataGuard<'_> { + Node { + id: self.id, + endpoints: self.endpoints, + } + } + } + + impl AsyncMetadata for (M, H) + where + M: AsyncMetadata, + { + type MetadataGuard<'a> = M::MetadataGuard<'a> + where + Self: 'a; + + async fn lock(&self) -> Self::MetadataGuard<'_> { + self.0.lock().await + } + } + + impl AsyncMetadata for HandlerCompat + where + T: Metadata, + { + type MetadataGuard<'a> = T::MetadataGuard<'a> + where + Self: 'a; + + async fn lock(&self) -> Self::MetadataGuard<'_> { + self.0.lock() + } + } +} diff --git a/matter/src/data_model/objects/mod.rs b/matter/src/data_model/objects/mod.rs index 1bd326e..b8b0511 100644 --- a/matter/src/data_model/objects/mod.rs +++ b/matter/src/data_model/objects/mod.rs @@ -41,6 +41,9 @@ pub use handler::*; mod dataver; pub use dataver::*; +mod metadata; +pub use metadata::*; + pub type EndptId = u16; pub type ClusterId = u32; pub type AttrId = u16; diff --git a/matter/src/data_model/objects/node.rs b/matter/src/data_model/objects/node.rs index 41720b6..1ffa896 100644 --- a/matter/src/data_model/objects/node.rs +++ b/matter/src/data_model/objects/node.rs @@ -17,9 +17,10 @@ use crate::{ acl::Accessor, + alloc, data_model::objects::Endpoint, interaction_model::{ - core::{IMStatusCode, ResumeReadReq, ResumeSubscribeReq}, + core::IMStatusCode, messages::{ ib::{AttrPath, AttrStatus, CmdStatus, DataVersionFilter}, msg::{InvReq, ReadReq, SubscribeReq, WriteReq}, @@ -27,7 +28,7 @@ use crate::{ }, }, // TODO: This layer shouldn't really depend on the TLV layer, should create an abstraction layer - tlv::{TLVArray, TLVArrayIter, TLVElement}, + tlv::{TLVArray, TLVElement}, }; use core::{ fmt, @@ -57,41 +58,6 @@ where } } -pub trait Iterable { - type Item; - - type Iterator<'a>: Iterator - where - Self: 'a; - - fn iter(&self) -> Self::Iterator<'_>; -} - -impl<'a> Iterable for Option<&'a TLVArray<'a, DataVersionFilter>> { - type Item = DataVersionFilter; - - type Iterator<'i> = WildcardIter, DataVersionFilter> where Self: 'i; - - fn iter(&self) -> Self::Iterator<'_> { - if let Some(filters) = self { - WildcardIter::Wildcard(filters.iter()) - } else { - WildcardIter::None - } - } -} - -impl<'a> Iterable for &'a [DataVersionFilter] { - type Item = DataVersionFilter; - - type Iterator<'i> = core::iter::Cloned> where Self: 'i; - - fn iter(&self) -> Self::Iterator<'_> { - let slice: &[DataVersionFilter] = self; - slice.iter().cloned() - } -} - #[derive(Debug, Clone)] pub struct Node<'a> { pub id: u16, @@ -102,6 +68,7 @@ impl<'a> Node<'a> { pub fn read<'s, 'm>( &'s self, req: &'m ReadReq, + from: Option, accessor: &'m Accessor<'m>, ) -> impl Iterator> + 'm where @@ -114,30 +81,14 @@ impl<'a> Node<'a> { req.dataver_filters.as_ref(), req.fabric_filtered, accessor, - None, - ) - } - - pub fn resume_read<'s, 'm>( - &'s self, - req: &'m ResumeReadReq, - accessor: &'m Accessor<'m>, - ) -> impl Iterator> + 'm - where - 's: 'm, - { - self.read_attr_requests( - req.paths.iter().cloned(), - req.filters.as_slice(), - req.fabric_filtered, - accessor, - Some(req.resume_path.clone()), + from, ) } pub fn subscribing_read<'s, 'm>( &'s self, req: &'m SubscribeReq, + from: Option, accessor: &'m Accessor<'m>, ) -> impl Iterator> + 'm where @@ -150,31 +101,14 @@ impl<'a> Node<'a> { req.dataver_filters.as_ref(), req.fabric_filtered, accessor, - None, + from, ) } - pub fn resume_subscribing_read<'s, 'm>( - &'s self, - req: &'m ResumeSubscribeReq, - accessor: &'m Accessor<'m>, - ) -> impl Iterator> + 'm - where - 's: 'm, - { - self.read_attr_requests( - req.paths.iter().cloned(), - req.filters.as_slice(), - req.fabric_filtered, - accessor, - Some(req.resume_path.clone().unwrap()), - ) - } - - fn read_attr_requests<'s, 'm, P, D>( + fn read_attr_requests<'s, 'm, P>( &'s self, attr_requests: P, - dataver_filters: D, + dataver_filters: Option<&'m TLVArray>, fabric_filtered: bool, accessor: &'m Accessor<'m>, from: Option, @@ -182,11 +116,9 @@ impl<'a> Node<'a> { where 's: 'm, P: Iterator + 'm, - D: Iterable + Clone + 'm, { - attr_requests.flat_map(move |path| { + alloc!(attr_requests.flat_map(move |path| { if path.to_gp().is_wildcard() { - let dataver_filters = dataver_filters.clone(); let from = from.clone(); let iter = self @@ -204,10 +136,14 @@ impl<'a> Node<'a> { .is_ok() }) .map(move |(ep, cl, attr)| { - let dataver = dataver_filters.iter().find_map(|filter| { - (filter.path.endpoint == ep.id && filter.path.cluster == cl.id) - .then_some(filter.data_ver) - }); + let dataver = if let Some(dataver_filters) = dataver_filters { + dataver_filters.iter().find_map(|filter| { + (filter.path.endpoint == ep.id && filter.path.cluster == cl.id) + .then_some(filter.data_ver) + }) + } else { + None + }; Ok(AttrDetails { node: self, @@ -230,10 +166,14 @@ impl<'a> Node<'a> { let result = match self.check_attribute(accessor, ep, cl, attr, false) { Ok(()) => { - let dataver = dataver_filters.iter().find_map(|filter| { - (filter.path.endpoint == ep && filter.path.cluster == cl) - .then_some(filter.data_ver) - }); + let dataver = if let Some(dataver_filters) = dataver_filters { + dataver_filters.iter().find_map(|filter| { + (filter.path.endpoint == ep && filter.path.cluster == cl) + .then_some(filter.data_ver) + }) + } else { + None + }; Ok(AttrDetails { node: self, @@ -252,7 +192,7 @@ impl<'a> Node<'a> { WildcardIter::Single(once(result)) } - }) + })) } pub fn write<'m>( @@ -260,7 +200,7 @@ impl<'a> Node<'a> { req: &'m WriteReq, accessor: &'m Accessor<'m>, ) -> impl Iterator), AttrStatus>> + 'm { - req.write_requests.iter().flat_map(move |attr_data| { + alloc!(req.write_requests.iter().flat_map(move |attr_data| { if attr_data.path.cluster.is_none() { WildcardIter::Single(once(Err(AttrStatus::new( &attr_data.path.to_gp(), @@ -332,7 +272,7 @@ impl<'a> Node<'a> { WildcardIter::Single(once(result)) } - }) + })) } pub fn invoke<'m>( @@ -340,7 +280,8 @@ impl<'a> Node<'a> { req: &'m InvReq, accessor: &'m Accessor<'m>, ) -> impl Iterator), CmdStatus>> + 'm { - req.inv_requests + alloc!(req + .inv_requests .iter() .flat_map(|inv_requests| inv_requests.iter()) .flat_map(move |cmd_data| { @@ -393,7 +334,7 @@ impl<'a> Node<'a> { WildcardIter::Single(once(result)) } - }) + })) } fn matches(path: Option<&GenericPath>, ep: EndptId, cl: ClusterId, leaf: u32) -> bool { diff --git a/matter/src/data_model/root_endpoint.rs b/matter/src/data_model/root_endpoint.rs index 21691cd..69df3bd 100644 --- a/matter/src/data_model/root_endpoint.rs +++ b/matter/src/data_model/root_endpoint.rs @@ -46,7 +46,7 @@ pub const CLUSTERS: [Cluster<'static>; 7] = [ access_control::CLUSTER, ]; -pub fn endpoint(id: EndptId) -> Endpoint<'static> { +pub const fn endpoint(id: EndptId) -> Endpoint<'static> { Endpoint { id, device_type: super::device_types::DEV_TYPE_ROOT_NODE, diff --git a/matter/src/data_model/sdm/admin_commissioning.rs b/matter/src/data_model/sdm/admin_commissioning.rs index 15c803f..3cce0f7 100644 --- a/matter/src/data_model/sdm/admin_commissioning.rs +++ b/matter/src/data_model/sdm/admin_commissioning.rs @@ -19,11 +19,11 @@ use core::cell::RefCell; use core::convert::TryInto; use crate::data_model::objects::*; -use crate::interaction_model::core::Transaction; use crate::mdns::Mdns; use crate::secure_channel::pake::PaseMgr; use crate::secure_channel::spake2p::VerifierData; use crate::tlv::{FromTLV, Nullable, OctetStr, TLVElement}; +use crate::transport::exchange::Exchange; use crate::utils::rand::Rand; use crate::{attribute_enum, cmd_enter}; use crate::{command_enum, error::*}; @@ -84,8 +84,8 @@ pub const CLUSTER: Cluster<'static> = Cluster { ], commands: &[ Commands::OpenCommWindow as _, - Commands::OpenBasicCommWindow as _, - Commands::RevokeComm as _, + // Commands::OpenBasicCommWindow as _, + // Commands::RevokeComm as _, ], }; @@ -133,7 +133,7 @@ impl<'a> AdminCommCluster<'a> { } pub fn invoke( - &mut self, + &self, cmd: &CmdDetails, data: &TLVElement, _encoder: CmdDataEncoder, @@ -148,7 +148,7 @@ impl<'a> AdminCommCluster<'a> { Ok(()) } - fn handle_command_opencomm_win(&mut self, data: &TLVElement) -> Result<(), Error> { + fn handle_command_opencomm_win(&self, data: &TLVElement) -> Result<(), Error> { cmd_enter!("Open Commissioning Window"); let req = OpenCommWindowReq::from_tlv(data)?; let verifier = VerifierData::new(req.verifier.0, req.iterations, req.salt.0); @@ -166,8 +166,8 @@ impl<'a> Handler for AdminCommCluster<'a> { } fn invoke( - &mut self, - _transaction: &mut Transaction, + &self, + _exchange: &Exchange, cmd: &CmdDetails, data: &TLVElement, encoder: CmdDataEncoder, diff --git a/matter/src/data_model/sdm/general_commissioning.rs b/matter/src/data_model/sdm/general_commissioning.rs index 78c3bef..0784bae 100644 --- a/matter/src/data_model/sdm/general_commissioning.rs +++ b/matter/src/data_model/sdm/general_commissioning.rs @@ -20,8 +20,8 @@ use core::convert::TryInto; use crate::data_model::objects::*; use crate::data_model::sdm::failsafe::FailSafe; -use crate::interaction_model::core::Transaction; use crate::tlv::{FromTLV, TLVElement, TLVWriter, TagType, ToTLV, UtfStr}; +use crate::transport::exchange::Exchange; use crate::utils::rand::Rand; use crate::{attribute_enum, cmd_enter}; use crate::{command_enum, error::*}; @@ -171,19 +171,19 @@ impl<'a> GenCommCluster<'a> { } pub fn invoke( - &mut self, - transaction: &mut Transaction, + &self, + exchange: &Exchange, cmd: &CmdDetails, data: &TLVElement, encoder: CmdDataEncoder, ) -> Result<(), Error> { match cmd.cmd_id.try_into()? { - Commands::ArmFailsafe => self.handle_command_armfailsafe(transaction, data, encoder)?, + Commands::ArmFailsafe => self.handle_command_armfailsafe(exchange, data, encoder)?, Commands::SetRegulatoryConfig => { - self.handle_command_setregulatoryconfig(transaction, data, encoder)? + self.handle_command_setregulatoryconfig(exchange, data, encoder)? } Commands::CommissioningComplete => { - self.handle_command_commissioningcomplete(transaction, encoder)?; + self.handle_command_commissioningcomplete(exchange, encoder)?; } } @@ -193,8 +193,8 @@ impl<'a> GenCommCluster<'a> { } fn handle_command_armfailsafe( - &mut self, - transaction: &mut Transaction, + &self, + exchange: &Exchange, data: &TLVElement, encoder: CmdDataEncoder, ) -> Result<(), Error> { @@ -207,7 +207,7 @@ impl<'a> GenCommCluster<'a> { .borrow_mut() .arm( p.expiry_len, - transaction.session().get_session_mode().clone(), + exchange.with_session(|sess| Ok(sess.get_session_mode().clone()))?, ) .is_err() { @@ -225,13 +225,12 @@ impl<'a> GenCommCluster<'a> { .with_command(RespCommands::ArmFailsafeResp as _)? .set(cmd_data)?; - transaction.complete(); Ok(()) } fn handle_command_setregulatoryconfig( - &mut self, - transaction: &mut Transaction, + &self, + _exchange: &Exchange, data: &TLVElement, encoder: CmdDataEncoder, ) -> Result<(), Error> { @@ -252,20 +251,22 @@ impl<'a> GenCommCluster<'a> { .with_command(RespCommands::SetRegulatoryConfigResp as _)? .set(cmd_data)?; - transaction.complete(); Ok(()) } fn handle_command_commissioningcomplete( - &mut self, - transaction: &mut Transaction, + &self, + exchange: &Exchange, encoder: CmdDataEncoder, ) -> Result<(), Error> { cmd_enter!("Commissioning Complete"); let mut status: u8 = CommissioningError::Ok as u8; // Has to be a Case Session - if transaction.session().get_local_fabric_idx().is_none() { + if exchange + .with_session(|sess| Ok(sess.get_local_fabric_idx()))? + .is_none() + { status = CommissioningError::ErrInvalidAuth as u8; } @@ -274,7 +275,7 @@ impl<'a> GenCommCluster<'a> { if self .failsafe .borrow_mut() - .disarm(transaction.session().get_session_mode().clone()) + .disarm(exchange.with_session(|sess| Ok(sess.get_session_mode().clone()))?) .is_err() { status = CommissioningError::ErrInvalidAuth as u8; @@ -289,7 +290,6 @@ impl<'a> GenCommCluster<'a> { .with_command(RespCommands::CommissioningCompleteResp as _)? .set(cmd_data)?; - transaction.complete(); Ok(()) } } @@ -300,13 +300,13 @@ impl<'a> Handler for GenCommCluster<'a> { } fn invoke( - &mut self, - transaction: &mut Transaction, + &self, + exchange: &Exchange, cmd: &CmdDetails, data: &TLVElement, encoder: CmdDataEncoder, ) -> Result<(), Error> { - GenCommCluster::invoke(self, transaction, cmd, data, encoder) + GenCommCluster::invoke(self, exchange, cmd, data, encoder) } } diff --git a/matter/src/data_model/sdm/noc.rs b/matter/src/data_model/sdm/noc.rs index 7fb1e37..8b66cb4 100644 --- a/matter/src/data_model/sdm/noc.rs +++ b/matter/src/data_model/sdm/noc.rs @@ -24,9 +24,9 @@ use crate::crypto::{self, KeyPair}; use crate::data_model::objects::*; use crate::data_model::sdm::dev_att; use crate::fabric::{Fabric, FabricMgr, MAX_SUPPORTED_FABRICS}; -use crate::interaction_model::core::Transaction; use crate::mdns::Mdns; use crate::tlv::{FromTLV, OctetStr, TLVElement, TLVWriter, TagType, ToTLV, UtfStr}; +use crate::transport::exchange::Exchange; use crate::transport::session::SessionMode; use crate::utils::epoch::Epoch; use crate::utils::rand::Rand; @@ -289,26 +289,26 @@ impl<'a> NocCluster<'a> { } pub fn invoke( - &mut self, - transaction: &mut Transaction, + &self, + exchange: &Exchange, cmd: &CmdDetails, data: &TLVElement, encoder: CmdDataEncoder, ) -> Result<(), Error> { match cmd.cmd_id.try_into()? { - Commands::AddNOC => self.handle_command_addnoc(transaction, data, encoder)?, - Commands::CSRReq => self.handle_command_csrrequest(transaction, data, encoder)?, + Commands::AddNOC => self.handle_command_addnoc(exchange, data, encoder)?, + Commands::CSRReq => self.handle_command_csrrequest(exchange, data, encoder)?, Commands::AddTrustedRootCert => { - self.handle_command_addtrustedrootcert(transaction, data)? + self.handle_command_addtrustedrootcert(exchange, data)? } - Commands::AttReq => self.handle_command_attrequest(transaction, data, encoder)?, + Commands::AttReq => self.handle_command_attrequest(exchange, data, encoder)?, Commands::CertChainReq => { - self.handle_command_certchainrequest(transaction, data, encoder)? + self.handle_command_certchainrequest(exchange, data, encoder)? } Commands::UpdateFabricLabel => { - self.handle_command_updatefablabel(transaction, data, encoder)?; + self.handle_command_updatefablabel(exchange, data, encoder)?; } - Commands::RemoveFabric => self.handle_command_rmfabric(transaction, data, encoder)?, + Commands::RemoveFabric => self.handle_command_rmfabric(exchange, data, encoder)?, } self.data_ver.changed(); @@ -323,13 +323,12 @@ impl<'a> NocCluster<'a> { } fn _handle_command_addnoc( - &mut self, - transaction: &mut Transaction, + &self, + exchange: &Exchange, data: &TLVElement, ) -> Result { - let noc_data = transaction - .session_mut() - .take_noc_data() + let noc_data = exchange + .with_session_mut(|sess| Ok(sess.take_noc_data()))? .ok_or(NocStatus::MissingCsr)?; if !self @@ -411,42 +410,42 @@ impl<'a> NocCluster<'a> { } fn handle_command_updatefablabel( - &mut self, - transaction: &mut Transaction, + &self, + exchange: &Exchange, data: &TLVElement, encoder: CmdDataEncoder, ) -> Result<(), Error> { cmd_enter!("Update Fabric Label"); let req = UpdateFabricLabelReq::from_tlv(data).map_err(Error::map_invalid_data_type)?; - let (result, fab_idx) = - if let SessionMode::Case(c) = transaction.session().get_session_mode() { - if self - .fabric_mgr - .borrow_mut() - .set_label( - c.fab_idx, - req.label.as_str().map_err(Error::map_invalid_data_type)?, - ) - .is_err() - { - (NocStatus::LabelConflict, c.fab_idx) - } else { - (NocStatus::Ok, c.fab_idx) - } + let (result, fab_idx) = if let SessionMode::Case(c) = + exchange.with_session(|sess| Ok(sess.get_session_mode().clone()))? + { + if self + .fabric_mgr + .borrow_mut() + .set_label( + c.fab_idx, + req.label.as_str().map_err(Error::map_invalid_data_type)?, + ) + .is_err() + { + (NocStatus::LabelConflict, c.fab_idx) } else { - // Update Fabric Label not allowed - (NocStatus::InvalidFabricIndex, 0) - }; + (NocStatus::Ok, c.fab_idx) + } + } else { + // Update Fabric Label not allowed + (NocStatus::InvalidFabricIndex, 0) + }; Self::create_nocresponse(encoder, result, fab_idx, "")?; - transaction.complete(); Ok(()) } fn handle_command_rmfabric( - &mut self, - transaction: &mut Transaction, + &self, + _exchange: &Exchange, data: &TLVElement, encoder: CmdDataEncoder, ) -> Result<(), Error> { @@ -459,7 +458,7 @@ impl<'a> NocCluster<'a> { .is_ok() { let _ = self.acl_mgr.borrow_mut().delete_for_fabric(req.fab_idx); - transaction.terminate(); + // TODO: transaction.terminate(); Ok(()) } else { Self::create_nocresponse(encoder, NocStatus::InvalidFabricIndex, req.fab_idx, "") @@ -467,28 +466,27 @@ impl<'a> NocCluster<'a> { } fn handle_command_addnoc( - &mut self, - transaction: &mut Transaction, + &self, + exchange: &Exchange, data: &TLVElement, encoder: CmdDataEncoder, ) -> Result<(), Error> { cmd_enter!("AddNOC"); - let (status, fab_idx) = match self._handle_command_addnoc(transaction, data) { + let (status, fab_idx) = match self._handle_command_addnoc(exchange, data) { Ok(fab_idx) => (NocStatus::Ok, fab_idx), Err(NocError::Status(status)) => (status, 0), Err(NocError::Error(error)) => Err(error)?, }; Self::create_nocresponse(encoder, status, fab_idx, "")?; - transaction.complete(); Ok(()) } fn handle_command_attrequest( - &mut self, - transaction: &mut Transaction, + &self, + exchange: &Exchange, data: &TLVElement, encoder: CmdDataEncoder, ) -> Result<(), Error> { @@ -498,7 +496,10 @@ impl<'a> NocCluster<'a> { info!("Received Attestation Nonce:{:?}", req.str); let mut attest_challenge = [0u8; crypto::SYMM_KEY_LEN_BYTES]; - attest_challenge.copy_from_slice(transaction.session().get_att_challenge()); + exchange.with_session(|sess| { + attest_challenge.copy_from_slice(sess.get_att_challenge()); + Ok(()) + })?; let mut writer = encoder.with_command(RespCommands::AttReqResp as _)?; @@ -522,13 +523,12 @@ impl<'a> NocCluster<'a> { writer.complete()?; - transaction.complete(); Ok(()) } fn handle_command_certchainrequest( - &mut self, - transaction: &mut Transaction, + &self, + _exchange: &Exchange, data: &TLVElement, encoder: CmdDataEncoder, ) -> Result<(), Error> { @@ -549,13 +549,12 @@ impl<'a> NocCluster<'a> { .with_command(RespCommands::CertChainResp as _)? .set(cmd_data)?; - transaction.complete(); Ok(()) } fn handle_command_csrrequest( - &mut self, - transaction: &mut Transaction, + &self, + exchange: &Exchange, data: &TLVElement, encoder: CmdDataEncoder, ) -> Result<(), Error> { @@ -570,7 +569,10 @@ impl<'a> NocCluster<'a> { let noc_keypair = KeyPair::new(self.rand)?; let mut attest_challenge = [0u8; crypto::SYMM_KEY_LEN_BYTES]; - attest_challenge.copy_from_slice(transaction.session().get_att_challenge()); + exchange.with_session(|sess| { + attest_challenge.copy_from_slice(sess.get_att_challenge()); + Ok(()) + })?; let mut writer = encoder.with_command(RespCommands::CSRResp as _)?; @@ -591,15 +593,17 @@ impl<'a> NocCluster<'a> { let noc_data = NocData::new(noc_keypair); // Store this in the session data instead of cluster data, so it gets cleared // if the session goes away for some reason - transaction.session_mut().set_noc_data(noc_data); + exchange.with_session_mut(|sess| { + sess.set_noc_data(noc_data); + Ok(()) + })?; - transaction.complete(); Ok(()) } fn handle_command_addtrustedrootcert( - &mut self, - transaction: &mut Transaction, + &self, + exchange: &Exchange, data: &TLVElement, ) -> Result<(), Error> { cmd_enter!("AddTrustedRootCert"); @@ -608,25 +612,26 @@ impl<'a> NocCluster<'a> { } // This may happen on CASE or PASE. For PASE, the existence of NOC Data is necessary - match transaction.session().get_session_mode() { + match exchange.with_session(|sess| Ok(sess.get_session_mode().clone()))? { SessionMode::Case(_) => error!("CASE: AddTrustedRootCert handling pending"), // For a CASE Session, we just return success for now, SessionMode::Pase => { - let noc_data = transaction - .session_mut() - .get_noc_data() - .ok_or(ErrorCode::NoSession)?; + exchange.with_session_mut(|sess| { + let noc_data = sess.get_noc_data().ok_or(ErrorCode::NoSession)?; - let req = CommonReq::from_tlv(data).map_err(Error::map_invalid_command)?; - info!("Received Trusted Cert:{:x?}", req.str); + let req = CommonReq::from_tlv(data).map_err(Error::map_invalid_command)?; + info!("Received Trusted Cert:{:x?}", req.str); + + noc_data.root_ca = heapless::Vec::from_slice(req.str.0) + .map_err(|_| ErrorCode::BufferTooSmall)?; + + Ok(()) + })?; - noc_data.root_ca = - heapless::Vec::from_slice(req.str.0).map_err(|_| ErrorCode::BufferTooSmall)?; // TODO } _ => (), } - transaction.complete(); Ok(()) } } @@ -637,13 +642,13 @@ impl<'a> Handler for NocCluster<'a> { } fn invoke( - &mut self, - transaction: &mut Transaction, + &self, + exchange: &Exchange, cmd: &CmdDetails, data: &TLVElement, encoder: CmdDataEncoder, ) -> Result<(), Error> { - NocCluster::invoke(self, transaction, cmd, data, encoder) + NocCluster::invoke(self, exchange, cmd, data, encoder) } } diff --git a/matter/src/data_model/system_model/access_control.rs b/matter/src/data_model/system_model/access_control.rs index 17c88e3..8301b46 100644 --- a/matter/src/data_model/system_model/access_control.rs +++ b/matter/src/data_model/system_model/access_control.rs @@ -132,7 +132,7 @@ impl<'a> AccessControlCluster<'a> { } } - pub fn write(&mut self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { + pub fn write(&self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { match attr.attr_id.try_into()? { Attributes::Acl(_) => { attr_list_write(attr, data.with_dataver(self.data_ver.get())?, |op, data| { @@ -151,7 +151,7 @@ impl<'a> AccessControlCluster<'a> { /// This takes care of 4 things, add item, edit item, delete item, delete list. /// Care about fabric-scoped behaviour is taken fn write_acl_attr( - &mut self, + &self, op: &ListOperation, data: &TLVElement, fab_idx: u8, @@ -185,7 +185,7 @@ impl<'a> Handler for AccessControlCluster<'a> { AccessControlCluster::read(self, attr, encoder) } - fn write(&mut self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { + fn write(&self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { AccessControlCluster::write(self, attr, data) } } @@ -220,7 +220,7 @@ mod tests { let mut tw = TLVWriter::new(&mut writebuf); let acl_mgr = RefCell::new(AclMgr::new()); - let mut acl = AccessControlCluster::new(&acl_mgr, dummy_rand); + let acl = AccessControlCluster::new(&acl_mgr, dummy_rand); let new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); new.to_tlv(&mut tw, TagType::Anonymous).unwrap(); @@ -258,7 +258,7 @@ mod tests { for i in &verifier { acl_mgr.borrow_mut().add(i.clone()).unwrap(); } - let mut acl = AccessControlCluster::new(&acl_mgr, dummy_rand); + let acl = AccessControlCluster::new(&acl_mgr, dummy_rand); let new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); new.to_tlv(&mut tw, TagType::Anonymous).unwrap(); @@ -295,7 +295,7 @@ mod tests { for i in &input { acl_mgr.borrow_mut().add(i.clone()).unwrap(); } - let mut acl = AccessControlCluster::new(&acl_mgr, dummy_rand); + let acl = AccessControlCluster::new(&acl_mgr, dummy_rand); // data is don't-care actually let data = TLVElement::new(TagType::Anonymous, ElementType::True); diff --git a/matter/src/interaction_model/core.rs b/matter/src/interaction_model/core.rs index cc763a8..4ce3583 100644 --- a/matter/src/interaction_model/core.rs +++ b/matter/src/interaction_model/core.rs @@ -15,36 +15,28 @@ * limitations under the License. */ -use core::sync::atomic::{AtomicU32, Ordering}; use core::time::Duration; use crate::{ - data_model::core::DataHandler, + acl::Accessor, error::*, - tlv::{get_root_node_struct, print_tlv_list, FromTLV, TLVElement, TLVWriter, TagType, ToTLV}, - transport::{ - exchange::{Exchange, ExchangeCtx}, - packet::Packet, - proto_ctx::ProtoCtx, - session::Session, - }, + tlv::{get_root_node_struct, FromTLV, TLVElement, TLVWriter, TagType, ToTLV}, + transport::{exchange::Exchange, packet::Packet}, + utils::epoch::Epoch, }; -use log::{error, info}; -use num; +use log::error; +use num::{self, FromPrimitive}; use num_derive::FromPrimitive; -use owo_colors::OwoColorize; -use super::messages::{ - ib::{AttrPath, DataVersionFilter}, - msg::{self, InvReq, ReadReq, StatusResp, SubscribeReq, SubscribeResp, TimedReq, WriteReq}, - GenericPath, +use super::messages::msg::{ + self, InvReq, ReadReq, StatusResp, SubscribeReq, SubscribeResp, TimedReq, WriteReq, }; #[macro_export] macro_rules! cmd_enter { ($e:expr) => {{ use owo_colors::OwoColorize; - info! {"{} {}", "Handling Command".cyan(), $e.cyan()} + info! {"{} {}", "Handling command".cyan(), $e.cyan()} }}; } @@ -104,7 +96,7 @@ impl From for IMStatusCode { impl FromTLV<'_> for IMStatusCode { fn from_tlv(t: &TLVElement) -> Result { - num::FromPrimitive::from_u16(t.u16()?).ok_or_else(|| ErrorCode::Invalid.into()) + FromPrimitive::from_u16(t.u16()?).ok_or_else(|| ErrorCode::Invalid.into()) } } @@ -114,7 +106,7 @@ impl ToTLV for IMStatusCode { } } -#[derive(FromPrimitive, Debug, Copy, Clone, PartialEq)] +#[derive(FromPrimitive, Debug, Copy, Clone, Eq, PartialEq)] pub enum OpCode { Reserved = 0, StatusResponse = 1, @@ -129,208 +121,16 @@ pub enum OpCode { TimedRequest = 10, } -#[derive(PartialEq)] -pub enum TransactionState { - Ongoing, - Complete, - Terminate, -} -pub struct Transaction<'a, 'b> { - state: TransactionState, - ctx: &'a mut ExchangeCtx<'b>, -} - -impl<'a, 'b> Transaction<'a, 'b> { - pub fn new(ctx: &'a mut ExchangeCtx<'b>) -> Self { - Self { - state: TransactionState::Ongoing, - ctx, - } - } - - pub fn exch(&self) -> &Exchange { - self.ctx.exch - } - - pub fn exch_mut(&mut self) -> &mut Exchange { - self.ctx.exch - } - - pub fn session(&self) -> &Session { - self.ctx.sess.session() - } - - pub fn session_mut(&mut self) -> &mut Session { - self.ctx.sess.session_mut() - } - - /// Terminates the transaction, no communication (even ACKs) happens hence forth - pub fn terminate(&mut self) { - self.state = TransactionState::Terminate - } - - pub fn is_terminate(&self) -> bool { - self.state == TransactionState::Terminate - } - /// Marks the transaction as completed from the application's perspective - pub fn complete(&mut self) { - self.state = TransactionState::Complete - } - - pub fn is_complete(&self) -> bool { - self.state == TransactionState::Complete - } - - pub fn set_timeout(&mut self, timeout: u64) { - let now = (self.ctx.epoch)(); - - self.ctx - .exch - .set_data_time(now.checked_add(Duration::from_millis(timeout))); - } - - pub fn get_timeout(&mut self) -> Option { - self.ctx.exch.get_data_time() - } - - pub fn has_timed_out(&self) -> bool { - if let Some(timeout) = self.ctx.exch.get_data_time() { - if (self.ctx.epoch)() > timeout { - return true; - } - } - false - } -} - /* Interaction Model ID as per the Matter Spec */ pub const PROTO_ID_INTERACTION_MODEL: u16 = 0x01; -const MAX_RESUME_PATHS: usize = 32; -const MAX_RESUME_DATAVER_FILTERS: usize = 32; - // This is the amount of space we reserve for other things to be attached towards // the end of long reads. const LONG_READS_TLV_RESERVE_SIZE: usize = 24; -// TODO: For now... -static SUBS_ID: AtomicU32 = AtomicU32::new(1); - -pub enum Interaction<'a> { - Read(ReadReq<'a>), - Write(WriteReq<'a>), - Invoke(InvReq<'a>), - Subscribe(SubscribeReq<'a>), - Timed(TimedReq), - ResumeRead(ResumeReadReq), - ResumeSubscribe(ResumeSubscribeReq), -} - -impl<'a> Interaction<'a> { - fn new(rx: &'a Packet, transaction: &mut Transaction) -> Result, Error> { - let opcode: OpCode = rx.get_proto_opcode()?; - - let rx_data = rx.as_slice(); - - info!("{} {:?}", "Received command".cyan(), opcode); - print_tlv_list(rx_data); - - match opcode { - OpCode::ReadRequest => Ok(Some(Self::Read(ReadReq::from_tlv(&get_root_node_struct( - rx_data, - )?)?))), - OpCode::WriteRequest => Ok(Some(Self::Write(WriteReq::from_tlv( - &get_root_node_struct(rx_data)?, - )?))), - OpCode::InvokeRequest => Ok(Some(Self::Invoke(InvReq::from_tlv( - &get_root_node_struct(rx_data)?, - )?))), - OpCode::SubscribeRequest => Ok(Some(Self::Subscribe(SubscribeReq::from_tlv( - &get_root_node_struct(rx_data)?, - )?))), - OpCode::StatusResponse => { - let resp = StatusResp::from_tlv(&get_root_node_struct(rx_data)?)?; - - if resp.status == IMStatusCode::Success { - if let Some(req) = transaction.exch_mut().take_suspended_read_req() { - Ok(Some(Self::ResumeRead(req))) - } else if let Some(req) = transaction.exch_mut().take_suspended_subscribe_req() - { - Ok(Some(Self::ResumeSubscribe(req))) - } else { - Ok(None) - } - } else { - Ok(None) - } - } - OpCode::TimedRequest => Ok(Some(Self::Timed(TimedReq::from_tlv( - &get_root_node_struct(rx_data)?, - )?))), - _ => { - error!("Opcode not handled: {:?}", opcode); - Err(ErrorCode::InvalidOpcode.into()) - } - } - } - - pub fn initiate( - rx: &'a Packet, - tx: &mut Packet, - transaction: &mut Transaction, - ) -> Result, Error> { - if let Some(interaction) = Self::new(rx, transaction)? { - tx.reset(); - - let initiated = match &interaction { - Interaction::Read(req) => req.initiate(tx, transaction)?, - Interaction::Write(req) => req.initiate(tx, transaction)?, - Interaction::Invoke(req) => req.initiate(tx, transaction)?, - Interaction::Subscribe(req) => req.initiate(tx, transaction)?, - Interaction::Timed(req) => { - req.process(tx, transaction)?; - false - } - Interaction::ResumeRead(req) => req.initiate(tx, transaction)?, - Interaction::ResumeSubscribe(req) => req.initiate(tx, transaction)?, - }; - - Ok(initiated.then_some(interaction)) - } else { - Ok(None) - } - } - - fn create_status_response(tx: &mut Packet, status: IMStatusCode) -> Result<(), Error> { - tx.set_proto_id(PROTO_ID_INTERACTION_MODEL); - tx.set_proto_opcode(OpCode::StatusResponse as u8); - - let mut tw = TLVWriter::new(tx.get_writebuf()?); - - let status = StatusResp { status }; - status.to_tlv(&mut tw, TagType::Anonymous) - } -} - impl<'a> ReadReq<'a> { - fn suspend(self, resume_path: GenericPath) -> ResumeReadReq { - ResumeReadReq { - paths: self - .attr_requests - .iter() - .flat_map(|attr_requests| attr_requests.iter()) - .collect(), - filters: self - .dataver_filters - .iter() - .flat_map(|filters| filters.iter()) - .collect(), - fabric_filtered: self.fabric_filtered, - resume_path, - } - } - - fn initiate(&self, tx: &mut Packet, _transaction: &mut Transaction) -> Result { + pub fn tx_start<'r, 'p>(&self, tx: &'r mut Packet<'p>) -> Result, Error> { + tx.reset(); tx.set_proto_id(PROTO_ID_INTERACTION_MODEL); tx.set_proto_opcode(OpCode::ReportData as u8); @@ -342,47 +142,37 @@ impl<'a> ReadReq<'a> { tw.start_array(TagType::Context(msg::ReportDataTag::AttributeReports as u8))?; } - Ok(true) + Ok(tw) } - pub fn complete( - self, - tx: &mut Packet, - transaction: &mut Transaction, - resume_path: Option, - ) -> Result { + pub fn tx_finish_chunk(&self, tx: &mut Packet) -> Result<(), Error> { + self.complete(tx, true) + } + + pub fn tx_finish(&self, tx: &mut Packet) -> Result<(), Error> { + self.complete(tx, false) + } + + fn complete(&self, tx: &mut Packet<'_>, more_chunks: bool) -> Result<(), Error> { let mut tw = Self::restore_long_read_space(tx)?; if self.attr_requests.is_some() { tw.end_container()?; } - let more_chunks = if let Some(resume_path) = resume_path { + if more_chunks { tw.bool( TagType::Context(msg::ReportDataTag::MoreChunkedMsgs as u8), true, )?; - - transaction - .exch_mut() - .set_suspended_read_req(self.suspend(resume_path)); - true - } else { - false - }; + } tw.bool( TagType::Context(msg::ReportDataTag::SupressResponse as u8), !more_chunks, )?; - tw.end_container()?; - - if !more_chunks { - transaction.complete(); - } - - Ok(true) + tw.end_container() } fn reserve_long_read_space<'p, 'b>(tx: &'p mut Packet<'b>) -> Result, Error> { @@ -401,14 +191,18 @@ impl<'a> ReadReq<'a> { } impl<'a> WriteReq<'a> { - fn initiate(&self, tx: &mut Packet, transaction: &mut Transaction) -> Result { - if transaction.has_timed_out() { - Interaction::create_status_response(tx, IMStatusCode::Timeout)?; + pub fn tx_start<'r, 'p>( + &self, + tx: &'r mut Packet<'p>, + epoch: Epoch, + timeout: Option, + ) -> Result>, Error> { + if has_timed_out(epoch, timeout) { + Interaction::status_response(tx, IMStatusCode::Timeout)?; - transaction.complete(); - - Ok(false) + Ok(None) } else { + tx.reset(); tx.set_proto_id(PROTO_ID_INTERACTION_MODEL); tx.set_proto_opcode(OpCode::WriteResponse as u8); @@ -417,47 +211,40 @@ impl<'a> WriteReq<'a> { tw.start_struct(TagType::Anonymous)?; tw.start_array(TagType::Context(msg::WriteRespTag::WriteResponses as u8))?; - Ok(true) + Ok(Some(tw)) } } - pub fn complete(self, tx: &mut Packet, transaction: &mut Transaction) -> Result { - let suppress = self.supress_response.unwrap_or_default(); - + pub fn tx_finish(&self, tx: &mut Packet<'_>) -> Result<(), Error> { let mut tw = TLVWriter::new(tx.get_writebuf()?); tw.end_container()?; - tw.end_container()?; - - transaction.complete(); - - Ok(if suppress { - error!("Supress response is set, is this the expected handling?"); - false - } else { - true - }) + tw.end_container() } } impl<'a> InvReq<'a> { - fn initiate(&self, tx: &mut Packet, transaction: &mut Transaction) -> Result { - if transaction.has_timed_out() { - Interaction::create_status_response(tx, IMStatusCode::Timeout)?; + pub fn tx_start<'r, 'p>( + &self, + tx: &'r mut Packet<'p>, + epoch: Epoch, + timeout: Option, + ) -> Result>, Error> { + if has_timed_out(epoch, timeout) { + Interaction::status_response(tx, IMStatusCode::Timeout)?; - transaction.complete(); - - Ok(false) + Ok(None) } else { - let timed_tx = transaction.get_timeout().map(|_| true); + let timed_tx = timeout.map(|_| true); let timed_request = self.timed_request.filter(|a| *a); // Either both should be None, or both should be Some(true) if timed_tx != timed_request { - Interaction::create_status_response(tx, IMStatusCode::TimedRequestMisMatch)?; + Interaction::status_response(tx, IMStatusCode::TimedRequestMisMatch)?; - Ok(false) + Ok(None) } else { + tx.reset(); tx.set_proto_id(PROTO_ID_INTERACTION_MODEL); tx.set_proto_opcode(OpCode::InvokeResponse as u8); @@ -475,77 +262,45 @@ impl<'a> InvReq<'a> { tw.start_array(TagType::Context(msg::InvRespTag::InvokeResponses as u8))?; } - Ok(true) + Ok(Some(tw)) } } } - pub fn complete(self, tx: &mut Packet, _transaction: &mut Transaction) -> Result { - let suppress = self.suppress_response.unwrap_or_default(); - + pub fn tx_finish(&self, tx: &mut Packet<'_>) -> Result<(), Error> { let mut tw = TLVWriter::new(tx.get_writebuf()?); if self.inv_requests.is_some() { tw.end_container()?; } - tw.end_container()?; - - Ok(if suppress { - error!("Supress response is set, is this the expected handling?"); - false - } else { - true - }) + tw.end_container() } } impl TimedReq { - pub fn process(&self, tx: &mut Packet, transaction: &mut Transaction) -> Result<(), Error> { - tx.set_proto_id(PROTO_ID_INTERACTION_MODEL); - tx.set_proto_opcode(OpCode::StatusResponse as u8); + pub fn timeout(&self, epoch: Epoch) -> Duration { + epoch() + .checked_add(Duration::from_millis(self.timeout as _)) + .unwrap() + } - let mut tw = TLVWriter::new(tx.get_writebuf()?); + pub fn tx_process(self, tx: &mut Packet<'_>, epoch: Epoch) -> Result { + Interaction::status_response(tx, IMStatusCode::Success)?; - transaction.set_timeout(self.timeout.into()); - - let status = StatusResp { - status: IMStatusCode::Success, - }; - - status.to_tlv(&mut tw, TagType::Anonymous)?; - - Ok(()) + Ok(epoch() + .checked_add(Duration::from_millis(self.timeout as _)) + .unwrap()) } } impl<'a> SubscribeReq<'a> { - fn suspend( + pub fn tx_start<'r, 'p>( &self, - resume_path: Option, + tx: &'r mut Packet<'p>, subscription_id: u32, - ) -> ResumeSubscribeReq { - ResumeSubscribeReq { - subscription_id, - paths: self - .attr_requests - .iter() - .flat_map(|attr_requests| attr_requests.iter()) - .collect(), - filters: self - .dataver_filters - .iter() - .flat_map(|filters| filters.iter()) - .collect(), - fabric_filtered: self.fabric_filtered, - resume_path, - keep_subs: self.keep_subs, - min_int_floor: self.min_int_floor, - max_int_ceil: self.max_int_ceil, - } - } - - fn initiate(&self, tx: &mut Packet, transaction: &mut Transaction) -> Result { + ) -> Result, Error> { + tx.reset(); tx.set_proto_id(PROTO_ID_INTERACTION_MODEL); tx.set_proto_opcode(OpCode::ReportData as u8); @@ -553,9 +308,6 @@ impl<'a> SubscribeReq<'a> { tw.start_struct(TagType::Anonymous)?; - let subscription_id = SUBS_ID.fetch_add(1, Ordering::SeqCst); - transaction.exch_mut().set_subscription_id(subscription_id); - tw.u32( TagType::Context(msg::ReportDataTag::SubscriptionId as u8), subscription_id, @@ -565,282 +317,417 @@ impl<'a> SubscribeReq<'a> { tw.start_array(TagType::Context(msg::ReportDataTag::AttributeReports as u8))?; } - Ok(true) + Ok(tw) } - pub fn complete( - self, - tx: &mut Packet, - transaction: &mut Transaction, - resume_path: Option, - ) -> Result { + pub fn tx_finish_chunk(&self, tx: &mut Packet<'_>, more_chunks: bool) -> Result<(), Error> { let mut tw = ReadReq::restore_long_read_space(tx)?; if self.attr_requests.is_some() { tw.end_container()?; } - if resume_path.is_some() { + if more_chunks { tw.bool( TagType::Context(msg::ReportDataTag::MoreChunkedMsgs as u8), true, )?; } - let subscription_id = transaction.exch_mut().take_subscription_id().unwrap(); - - transaction - .exch_mut() - .set_suspended_subscribe_req(self.suspend(resume_path, subscription_id)); - tw.bool( TagType::Context(msg::ReportDataTag::SupressResponse as u8), false, )?; - tw.end_container()?; - - Ok(true) + tw.end_container() } -} -#[derive(Debug)] -pub struct ResumeReadReq { - pub paths: heapless::Vec, - pub filters: heapless::Vec, - pub fabric_filtered: bool, - pub resume_path: GenericPath, -} - -impl ResumeReadReq { - fn initiate(&self, tx: &mut Packet, _transaction: &mut Transaction) -> Result { + pub fn tx_process_final(&self, tx: &mut Packet, subscription_id: u32) -> Result<(), Error> { + tx.reset(); tx.set_proto_id(PROTO_ID_INTERACTION_MODEL); - tx.set_proto_opcode(OpCode::ReportData as u8); + tx.set_proto_opcode(OpCode::SubscribeResponse as u8); - let mut tw = ReadReq::reserve_long_read_space(tx)?; + let mut tw = TLVWriter::new(tx.get_writebuf()?); - tw.start_struct(TagType::Anonymous)?; - - tw.start_array(TagType::Context(msg::ReportDataTag::AttributeReports as u8))?; - - Ok(true) + let resp = SubscribeResp::new(subscription_id, 40); + resp.to_tlv(&mut tw, TagType::Anonymous) } +} - pub fn complete( - mut self, - tx: &mut Packet, - transaction: &mut Transaction, - resume_path: Option, - ) -> Result { - let mut tw = ReadReq::restore_long_read_space(tx)?; +pub struct ReadDriver<'a, 'r, 'p> { + exchange: &'r mut Exchange<'a>, + tx: &'r mut Packet<'p>, + rx: &'r mut Packet<'p>, + completed: bool, +} - tw.end_container()?; - - let continue_interaction = if let Some(resume_path) = resume_path { - tw.bool( - TagType::Context(msg::ReportDataTag::MoreChunkedMsgs as u8), - true, - )?; - - self.resume_path = resume_path; - transaction.exch_mut().set_suspended_read_req(self); - true - } else { - false - }; - - tw.bool( - TagType::Context(msg::ReportDataTag::SupressResponse as u8), - !continue_interaction, - )?; - - tw.end_container()?; - - if !continue_interaction { - transaction.complete(); +impl<'a, 'r, 'p> ReadDriver<'a, 'r, 'p> { + fn new(exchange: &'r mut Exchange<'a>, tx: &'r mut Packet<'p>, rx: &'r mut Packet<'p>) -> Self { + Self { + exchange, + tx, + rx, + completed: false, } - - Ok(true) } -} -#[derive(Debug)] -pub struct ResumeSubscribeReq { - pub subscription_id: u32, - pub paths: heapless::Vec, - pub filters: heapless::Vec, - pub fabric_filtered: bool, - pub resume_path: Option, - pub keep_subs: bool, - pub min_int_floor: u16, - pub max_int_ceil: u16, -} + fn start(&mut self, req: &ReadReq) -> Result<(), Error> { + req.tx_start(self.tx)?; -impl ResumeSubscribeReq { - fn initiate(&self, tx: &mut Packet, _transaction: &mut Transaction) -> Result { - tx.set_proto_id(PROTO_ID_INTERACTION_MODEL); + Ok(()) + } - if self.resume_path.is_some() { - tx.set_proto_opcode(OpCode::ReportData as u8); + pub fn accessor(&self) -> Result, Error> { + self.exchange.accessor() + } - let mut tw = ReadReq::reserve_long_read_space(tx)?; + pub fn writer(&mut self) -> Result, Error> { + if self.completed { + Err(ErrorCode::Invalid.into()) // TODO + } else { + Ok(TLVWriter::new(self.tx.get_writebuf()?)) + } + } - tw.start_struct(TagType::Anonymous)?; + pub async fn send_chunk(&mut self, req: &ReadReq<'_>) -> Result { + req.tx_finish_chunk(self.tx)?; - tw.u32( - TagType::Context(msg::ReportDataTag::SubscriptionId as u8), - self.subscription_id, - )?; - - tw.start_array(TagType::Context(msg::ReportDataTag::AttributeReports as u8))?; + if exchange_confirm(self.exchange, self.tx, self.rx).await? != IMStatusCode::Success { + self.completed = true; + Ok(false) + } else { + req.tx_start(self.tx)?; + Ok(true) + } + } + + pub async fn complete(&mut self, req: &ReadReq<'_>) -> Result<(), Error> { + req.tx_finish(self.tx)?; + + self.exchange.send_complete(self.tx).await + } +} + +pub struct WriteDriver<'a, 'r, 'p> { + exchange: &'r mut Exchange<'a>, + tx: &'r mut Packet<'p>, + epoch: Epoch, + timeout: Option, +} + +impl<'a, 'r, 'p> WriteDriver<'a, 'r, 'p> { + fn new( + exchange: &'r mut Exchange<'a>, + epoch: Epoch, + timeout: Option, + tx: &'r mut Packet<'p>, + ) -> Self { + Self { + exchange, + tx, + epoch, + timeout, + } + } + + async fn start(&mut self, req: &WriteReq<'_>) -> Result { + if req.tx_start(self.tx, self.epoch, self.timeout)?.is_some() { Ok(true) } else { - tx.set_proto_opcode(OpCode::SubscribeResponse as u8); - - let mut tw = TLVWriter::new(tx.get_writebuf()?); - - let resp = SubscribeResp::new(self.subscription_id, 40); - resp.to_tlv(&mut tw, TagType::Anonymous)?; + self.exchange.send_complete(self.tx).await?; Ok(false) } } - pub fn complete( - mut self, - tx: &mut Packet, - transaction: &mut Transaction, - resume_path: Option, - ) -> Result { - if self.resume_path.is_none() { - // Should not get here as initiate() should've sent the subscribe response already - panic!("Subscription was already processed"); + pub fn accessor(&self) -> Result, Error> { + self.exchange.accessor() + } + + pub fn writer(&mut self) -> Result, Error> { + Ok(TLVWriter::new(self.tx.get_writebuf()?)) + } + + pub async fn complete(&mut self, req: &WriteReq<'_>) -> Result<(), Error> { + if !req.supress_response.unwrap_or_default() { + req.tx_finish(self.tx)?; + self.exchange.send_complete(self.tx).await?; } - // Completing a ReportData message - - let mut tw = ReadReq::restore_long_read_space(tx)?; - - tw.end_container()?; - - if resume_path.is_some() { - tw.bool( - TagType::Context(msg::ReportDataTag::MoreChunkedMsgs as u8), - true, - )?; - } - - tw.bool( - TagType::Context(msg::ReportDataTag::SupressResponse as u8), - false, - )?; - - tw.end_container()?; - - self.resume_path = resume_path; - transaction.exch_mut().set_suspended_subscribe_req(self); - - Ok(true) + Ok(()) } } -pub trait InteractionHandler { - fn handle(&mut self, ctx: &mut ProtoCtx) -> Result; +pub struct InvokeDriver<'a, 'r, 'p> { + exchange: &'r mut Exchange<'a>, + tx: &'r mut Packet<'p>, + epoch: Epoch, + timeout: Option, } -impl InteractionHandler for &mut T -where - T: InteractionHandler, -{ - fn handle(&mut self, ctx: &mut ProtoCtx) -> Result { - (**self).handle(ctx) +impl<'a, 'r, 'p> InvokeDriver<'a, 'r, 'p> { + fn new( + exchange: &'r mut Exchange<'a>, + epoch: Epoch, + timeout: Option, + tx: &'r mut Packet<'p>, + ) -> Self { + Self { + exchange, + tx, + epoch, + timeout, + } + } + + async fn start(&mut self, req: &InvReq<'_>) -> Result { + if req.tx_start(self.tx, self.epoch, self.timeout)?.is_some() { + Ok(true) + } else { + self.exchange.send_complete(self.tx).await?; + + Ok(false) + } + } + + pub fn accessor(&self) -> Result, Error> { + self.exchange.accessor() + } + + pub fn writer(&mut self) -> Result, Error> { + Ok(TLVWriter::new(self.tx.get_writebuf()?)) + } + + pub fn writer_exchange(&mut self) -> Result<(TLVWriter<'_, 'p>, &Exchange<'a>), Error> { + Ok((TLVWriter::new(self.tx.get_writebuf()?), (self.exchange))) + } + + pub async fn complete(&mut self, req: &InvReq<'_>) -> Result<(), Error> { + if !req.suppress_response.unwrap_or_default() { + req.tx_finish(self.tx)?; + self.exchange.send_complete(self.tx).await?; + } + + Ok(()) } } -pub struct InteractionModel(pub T); +pub struct SubscribeDriver<'a, 'r, 'p> { + exchange: &'r mut Exchange<'a>, + tx: &'r mut Packet<'p>, + rx: &'r mut Packet<'p>, + subscription_id: u32, + completed: bool, +} -impl InteractionModel -where - T: DataHandler, -{ - pub fn handle(&mut self, ctx: &mut ProtoCtx) -> Result { - let mut transaction = Transaction::new(&mut ctx.exch_ctx); +impl<'a, 'r, 'p> SubscribeDriver<'a, 'r, 'p> { + fn new( + exchange: &'r mut Exchange<'a>, + subscription_id: u32, + tx: &'r mut Packet<'p>, + rx: &'r mut Packet<'p>, + ) -> Self { + Self { + exchange, + tx, + rx, + subscription_id, + completed: false, + } + } - let reply = - if let Some(interaction) = Interaction::initiate(ctx.rx, ctx.tx, &mut transaction)? { - self.0.handle(interaction, ctx.tx, &mut transaction)? + fn start(&mut self, req: &SubscribeReq) -> Result<(), Error> { + req.tx_start(self.tx, self.subscription_id)?; + + Ok(()) + } + + pub fn accessor(&self) -> Result, Error> { + self.exchange.accessor() + } + + pub fn writer(&mut self) -> Result, Error> { + if self.completed { + Err(ErrorCode::Invalid.into()) // TODO + } else { + Ok(TLVWriter::new(self.tx.get_writebuf()?)) + } + } + + pub async fn send_chunk(&mut self, req: &SubscribeReq<'_>) -> Result { + req.tx_finish_chunk(self.tx, true)?; + + if exchange_confirm(self.exchange, self.tx, self.rx).await? != IMStatusCode::Success { + self.completed = true; + Ok(false) + } else { + req.tx_start(self.tx, self.subscription_id)?; + + Ok(true) + } + } + + pub async fn complete(&mut self, req: &SubscribeReq<'_>) -> Result<(), Error> { + if !self.completed { + req.tx_finish_chunk(self.tx, false)?; + + if exchange_confirm(self.exchange, self.tx, self.rx).await? != IMStatusCode::Success { + self.completed = true; } else { - true - }; - - if transaction.is_complete() { - transaction.exch_mut().close(); + req.tx_process_final(self.tx, self.subscription_id)?; + self.exchange.send_complete(self.tx).await?; + } } - Ok(reply) + Ok(()) } } -#[cfg(feature = "nightly")] -impl InteractionModel -where - T: crate::data_model::core::asynch::AsyncDataHandler, -{ - pub async fn handle_async<'a>(&mut self, ctx: &mut ProtoCtx<'_, '_>) -> Result { - let mut transaction = Transaction::new(&mut ctx.exch_ctx); +pub enum Interaction<'a, 'r, 'p> { + Read { + req: ReadReq<'r>, + driver: ReadDriver<'a, 'r, 'p>, + }, + Write { + req: WriteReq<'r>, + driver: WriteDriver<'a, 'r, 'p>, + }, + Invoke { + req: InvReq<'r>, + driver: InvokeDriver<'a, 'r, 'p>, + }, + Subscribe { + req: SubscribeReq<'r>, + driver: SubscribeDriver<'a, 'r, 'p>, + }, +} - let reply = - if let Some(interaction) = Interaction::initiate(ctx.rx, ctx.tx, &mut transaction)? { - self.0.handle(interaction, ctx.tx, &mut transaction).await? - } else { - true - }; +impl<'a, 'r, 'p> Interaction<'a, 'r, 'p> { + pub async fn timeout( + exchange: &mut Exchange<'_>, + rx: &mut Packet<'_>, + tx: &mut Packet<'_>, + ) -> Result, Error> { + let epoch = exchange.transport().matter().epoch; - if transaction.is_complete() { - transaction.exch_mut().close(); + let mut opcode: OpCode = rx.get_proto_opcode()?; + + let mut timeout = None; + + while opcode == OpCode::TimedRequest { + let rx_data = rx.as_slice(); + let req = TimedReq::from_tlv(&get_root_node_struct(rx_data)?)?; + + timeout = Some(req.tx_process(tx, epoch)?); + + exchange.exchange(tx, rx).await?; + + opcode = rx.get_proto_opcode()?; } - Ok(reply) - } -} - -impl InteractionHandler for InteractionModel -where - T: DataHandler, -{ - fn handle(&mut self, ctx: &mut ProtoCtx) -> Result { - InteractionModel::handle(self, ctx) - } -} - -#[cfg(feature = "nightly")] -pub mod asynch { - use crate::{ - data_model::core::asynch::AsyncDataHandler, error::Error, transport::proto_ctx::ProtoCtx, - }; - - use super::InteractionModel; - - pub trait AsyncInteractionHandler { - async fn handle(&mut self, ctx: &mut ProtoCtx<'_, '_>) -> Result; + Ok(timeout) } - impl AsyncInteractionHandler for &mut T + #[inline(always)] + pub fn new( + exchange: &'r mut Exchange<'a>, + rx: &'r mut Packet<'p>, + tx: &'r mut Packet<'p>, + rx_status: &'r mut Packet<'p>, + subscription_id: S, + timeout: Option, + ) -> Result, Error> where - T: AsyncInteractionHandler, + S: FnOnce() -> u32, { - async fn handle(&mut self, ctx: &mut ProtoCtx<'_, '_>) -> Result { - (**self).handle(ctx).await + let epoch = exchange.transport().matter().epoch; + + let opcode = rx.get_proto_opcode()?; + let rx_data = rx.as_slice(); + + match opcode { + OpCode::ReadRequest => { + let req = ReadReq::from_tlv(&get_root_node_struct(rx_data)?)?; + let driver = ReadDriver::new(exchange, tx, rx_status); + + Ok(Self::Read { req, driver }) + } + OpCode::WriteRequest => { + let req = WriteReq::from_tlv(&get_root_node_struct(rx_data)?)?; + let driver = WriteDriver::new(exchange, epoch, timeout, tx); + + Ok(Self::Write { req, driver }) + } + OpCode::InvokeRequest => { + let req = InvReq::from_tlv(&get_root_node_struct(rx_data)?)?; + let driver = InvokeDriver::new(exchange, epoch, timeout, tx); + + Ok(Self::Invoke { req, driver }) + } + OpCode::SubscribeRequest => { + let req = SubscribeReq::from_tlv(&get_root_node_struct(rx_data)?)?; + let driver = SubscribeDriver::new(exchange, subscription_id(), tx, rx_status); + + Ok(Self::Subscribe { req, driver }) + } + _ => { + error!("Opcode not handled: {:?}", opcode); + Err(ErrorCode::InvalidOpcode.into()) + } } } - impl AsyncInteractionHandler for InteractionModel - where - T: AsyncDataHandler, - { - async fn handle(&mut self, ctx: &mut ProtoCtx<'_, '_>) -> Result { - InteractionModel::handle_async(self, ctx).await - } + pub async fn start(&mut self) -> Result { + let started = match self { + Self::Read { req, driver } => { + driver.start(req)?; + true + } + Self::Write { req, driver } => driver.start(req).await?, + Self::Invoke { req, driver } => driver.start(req).await?, + Self::Subscribe { req, driver } => { + driver.start(req)?; + true + } + }; + + Ok(started) + } + + fn status_response(tx: &mut Packet, status: IMStatusCode) -> Result<(), Error> { + tx.reset(); + tx.set_proto_id(PROTO_ID_INTERACTION_MODEL); + tx.set_proto_opcode(OpCode::StatusResponse as u8); + + let mut tw = TLVWriter::new(tx.get_writebuf()?); + + let status = StatusResp { status }; + status.to_tlv(&mut tw, TagType::Anonymous) } } + +async fn exchange_confirm( + exchange: &mut Exchange<'_>, + tx: &mut Packet<'_>, + rx: &mut Packet<'_>, +) -> Result { + exchange.exchange(tx, rx).await?; + + let opcode: OpCode = rx.get_proto_opcode()?; + + if opcode == OpCode::StatusResponse { + let resp = StatusResp::from_tlv(&get_root_node_struct(rx.as_slice())?)?; + Ok(resp.status) + } else { + Interaction::status_response(tx, IMStatusCode::Busy)?; // TODO + + exchange.send_complete(tx).await?; + + Err(ErrorCode::Invalid.into()) // TODO + } +} + +fn has_timed_out(epoch: Epoch, timeout: Option) -> bool { + timeout.map(|timeout| epoch() > timeout).unwrap_or(false) +} diff --git a/matter/src/lib.rs b/matter/src/lib.rs index 1d7e5d4..b80a62c 100644 --- a/matter/src/lib.rs +++ b/matter/src/lib.rs @@ -69,6 +69,7 @@ //! Start off exploring by going to the [Matter] object. #![cfg_attr(not(feature = "std"), no_std)] #![cfg_attr(feature = "nightly", feature(async_fn_in_trait))] +#![cfg_attr(feature = "nightly", feature(impl_trait_projections))] #![cfg_attr(feature = "nightly", allow(incomplete_features))] pub mod acl; @@ -90,3 +91,22 @@ pub mod transport; pub mod utils; pub use crate::core::*; + +#[cfg(feature = "alloc")] +extern crate alloc; + +#[cfg(feature = "alloc")] +#[macro_export] +macro_rules! alloc { + ($val:expr) => { + alloc::boxed::Box::new($val) + }; +} + +#[cfg(not(feature = "alloc"))] +#[macro_export] +macro_rules! alloc { + ($val:expr) => { + $val + }; +} diff --git a/matter/src/secure_channel/case.rs b/matter/src/secure_channel/case.rs index 63d5e56..28f4508 100644 --- a/matter/src/secure_channel/case.rs +++ b/matter/src/secure_channel/case.rs @@ -20,30 +20,25 @@ use core::cell::RefCell; use log::{error, trace}; use crate::{ + alloc, cert::Cert, crypto::{self, KeyPair, Sha256}, error::{Error, ErrorCode}, fabric::{Fabric, FabricMgr}, - secure_channel::common::SCStatusCodes, - secure_channel::common::{self, OpCode}, + secure_channel::common::{self, OpCode, PROTO_ID_SECURE_CHANNEL}, + secure_channel::common::{complete_with_status, SCStatusCodes}, tlv::{get_root_node_struct, FromTLV, OctetStr, TLVWriter, TagType}, transport::{ + exchange::Exchange, network::Address, - proto_ctx::ProtoCtx, + packet::Packet, session::{CaseDetails, CloneData, NocCatIds, SessionMode}, }, utils::{rand::Rand, writebuf::WriteBuf}, }; -#[derive(Debug, Copy, Clone, Eq, PartialEq)] -enum State { - Sigma1Rx, - Sigma3Rx, -} - #[derive(Debug, Clone)] -pub struct CaseSession { - state: State, +struct CaseSession { peer_sessid: u16, local_sessid: u16, tt_hash: Sha256, @@ -54,11 +49,11 @@ pub struct CaseSession { } impl CaseSession { - pub fn new(peer_sessid: u16, local_sessid: u16) -> Result { + #[inline(always)] + pub fn new() -> Result { Ok(Self { - state: State::Sigma1Rx, - peer_sessid, - local_sessid, + peer_sessid: 0, + local_sessid: 0, tt_hash: Sha256::new()?, shared_secret: [0; crypto::ECDH_SHARED_SECRET_LEN_BYTES], our_pub_key: [0; crypto::EC_POINT_LEN_BYTES], @@ -79,39 +74,50 @@ impl<'a> Case<'a> { Self { fabric_mgr, rand } } - pub fn casesigma3_handler( + pub async fn handle( &mut self, - ctx: &mut ProtoCtx, - ) -> Result<(bool, Option), Error> { - let mut case_session = ctx - .exch_ctx - .exch - .take_case_session() - .ok_or(ErrorCode::InvalidState)?; - if case_session.state != State::Sigma1Rx { - Err(ErrorCode::Invalid)?; - } - case_session.state = State::Sigma3Rx; + exchange: &mut Exchange<'_>, + rx: &mut Packet<'_>, + tx: &mut Packet<'_>, + ) -> Result<(), Error> { + let mut session = alloc!(CaseSession::new()?); + + self.handle_casesigma1(exchange, rx, tx, &mut session) + .await?; + self.handle_casesigma3(exchange, rx, tx, &mut session).await + } + + #[allow(clippy::await_holding_refcell_ref)] + async fn handle_casesigma3( + &mut self, + exchange: &mut Exchange<'_>, + rx: &mut Packet<'_>, + tx: &mut Packet<'_>, + case_session: &mut CaseSession, + ) -> Result<(), Error> { + rx.check_proto_opcode(OpCode::CASESigma3 as _)?; let fabric_mgr = self.fabric_mgr.borrow(); let fabric = fabric_mgr.get_fabric(case_session.local_fabric_idx)?; if fabric.is_none() { - common::create_sc_status_report( - ctx.tx, + drop(fabric_mgr); + complete_with_status( + exchange, + tx, common::SCStatusCodes::NoSharedTrustRoots, None, - )?; - ctx.exch_ctx.exch.close(); - return Ok((true, None)); + ) + .await?; + return Ok(()); } // Safe to unwrap here let fabric = fabric.unwrap(); - let root = get_root_node_struct(ctx.rx.as_slice())?; + let root = get_root_node_struct(rx.as_slice())?; let encrypted = root.find_tag(1)?.slice()?; - let mut decrypted: [u8; 800] = [0; 800]; + let mut decrypted = alloc!([0; 800]); if encrypted.len() > decrypted.len() { error!("Data too large"); Err(ErrorCode::NoSpace)?; @@ -119,22 +125,29 @@ impl<'a> Case<'a> { let decrypted = &mut decrypted[..encrypted.len()]; decrypted.copy_from_slice(encrypted); - let len = Case::get_sigma3_decryption(fabric.ipk.op_key(), &case_session, decrypted)?; + let len = Case::get_sigma3_decryption(fabric.ipk.op_key(), case_session, decrypted)?; let decrypted = &decrypted[..len]; let root = get_root_node_struct(decrypted)?; let d = Sigma3Decrypt::from_tlv(&root)?; - let initiator_noc = Cert::new(d.initiator_noc.0)?; + let initiator_noc = alloc!(Cert::new(d.initiator_noc.0)?); let mut initiator_icac = None; if let Some(icac) = d.initiator_icac { - initiator_icac = Some(Cert::new(icac.0)?); + initiator_icac = Some(alloc!(Cert::new(icac.0)?)); } - if let Err(e) = Case::validate_certs(fabric, &initiator_noc, &initiator_icac) { + + #[cfg(feature = "alloc")] + let initiator_icac_mut = initiator_icac.as_deref(); + + #[cfg(not(feature = "alloc"))] + let initiator_icac_mut = initiator_icac.as_ref(); + + if let Err(e) = Case::validate_certs(fabric, &initiator_noc, initiator_icac_mut) { error!("Certificate Chain doesn't match: {}", e); - common::create_sc_status_report(ctx.tx, common::SCStatusCodes::InvalidParameter, None)?; - ctx.exch_ctx.exch.close(); - return Ok((true, None)); + complete_with_status(exchange, tx, common::SCStatusCodes::InvalidParameter, None) + .await?; + return Ok(()); } if Case::validate_sigma3_sign( @@ -142,39 +155,52 @@ impl<'a> Case<'a> { d.initiator_icac.map(|a| a.0), &initiator_noc, d.signature.0, - &case_session, + case_session, ) .is_err() { error!("Sigma3 Signature doesn't match"); - common::create_sc_status_report(ctx.tx, common::SCStatusCodes::InvalidParameter, None)?; - ctx.exch_ctx.exch.close(); - return Ok((true, None)); + complete_with_status(exchange, tx, common::SCStatusCodes::InvalidParameter, None) + .await?; + return Ok(()); } // Only now do we add this message to the TT Hash let mut peer_catids: NocCatIds = Default::default(); initiator_noc.get_cat_ids(&mut peer_catids); - case_session.tt_hash.update(ctx.rx.as_slice())?; + case_session.tt_hash.update(rx.as_slice())?; let clone_data = Case::get_session_clone_data( fabric.ipk.op_key(), fabric.get_node_id(), initiator_noc.get_node_id()?, - ctx.exch_ctx.sess.get_peer_addr(), - &case_session, + exchange.with_session(|sess| Ok(sess.get_peer_addr()))?, + case_session, &peer_catids, )?; - common::create_sc_status_report(ctx.tx, SCStatusCodes::SessionEstablishmentSuccess, None)?; - ctx.exch_ctx.exch.clear_data(); - ctx.exch_ctx.exch.close(); - Ok((true, Some(clone_data))) + // TODO: Handle NoSpace + exchange.with_session_mgr_mut(|sess_mgr| sess_mgr.clone_session(&clone_data))?; + + complete_with_status( + exchange, + tx, + SCStatusCodes::SessionEstablishmentSuccess, + None, + ) + .await } - pub fn casesigma1_handler(&mut self, ctx: &mut ProtoCtx) -> Result { - ctx.tx.set_proto_opcode(OpCode::CASESigma2 as u8); + #[allow(clippy::await_holding_refcell_ref)] + async fn handle_casesigma1( + &mut self, + exchange: &mut Exchange<'_>, + rx: &mut Packet<'_>, + tx: &mut Packet<'_>, + case_session: &mut CaseSession, + ) -> Result<(), Error> { + rx.check_proto_opcode(OpCode::CASESigma1 as _)?; - let rx_buf = ctx.rx.as_slice(); + let rx_buf = rx.as_slice(); let root = get_root_node_struct(rx_buf)?; let r = Sigma1Req::from_tlv(&root)?; @@ -184,17 +210,20 @@ impl<'a> Case<'a> { .match_dest_id(r.initiator_random.0, r.dest_id.0); if local_fabric_idx.is_err() { error!("Fabric Index mismatch"); - common::create_sc_status_report( - ctx.tx, + complete_with_status( + exchange, + tx, common::SCStatusCodes::NoSharedTrustRoots, None, - )?; - ctx.exch_ctx.exch.close(); - return Ok(true); + ) + .await?; + + return Ok(()); } - let local_sessid = ctx.exch_ctx.sess.reserve_new_sess_id(); - let mut case_session = CaseSession::new(r.initiator_sessid, local_sessid)?; + let local_sessid = exchange.with_session_mgr_mut(|mgr| Ok(mgr.get_next_sess_id()))?; + case_session.peer_sessid = r.initiator_sessid; + case_session.local_sessid = local_sessid; case_session.tt_hash.update(rx_buf)?; case_session.local_fabric_idx = local_fabric_idx?; if r.peer_pub_key.0.len() != crypto::EC_POINT_LEN_BYTES { @@ -225,52 +254,71 @@ impl<'a> Case<'a> { // Derive the Encrypted Part const MAX_ENCRYPTED_SIZE: usize = 800; - let mut encrypted: [u8; MAX_ENCRYPTED_SIZE] = [0; MAX_ENCRYPTED_SIZE]; + let mut encrypted = alloc!([0; MAX_ENCRYPTED_SIZE]); let encrypted_len = { - let mut signature = [0u8; crypto::EC_SIGNATURE_LEN_BYTES]; + let mut signature = alloc!([0u8; crypto::EC_SIGNATURE_LEN_BYTES]); let fabric_mgr = self.fabric_mgr.borrow(); let fabric = fabric_mgr.get_fabric(case_session.local_fabric_idx)?; if fabric.is_none() { - common::create_sc_status_report( - ctx.tx, + drop(fabric_mgr); + complete_with_status( + exchange, + tx, common::SCStatusCodes::NoSharedTrustRoots, None, - )?; - ctx.exch_ctx.exch.close(); - return Ok(true); + ) + .await?; + return Ok(()); } + #[cfg(feature = "alloc")] + let signature_mut = &mut *signature; + + #[cfg(not(feature = "alloc"))] + let signature_mut = &mut signature; + let sign_len = Case::get_sigma2_sign( fabric.unwrap(), &case_session.our_pub_key, &case_session.peer_pub_key, - &mut signature, + signature_mut, )?; let signature = &signature[..sign_len]; + #[cfg(feature = "alloc")] + let encrypted_mut = &mut *encrypted; + + #[cfg(not(feature = "alloc"))] + let encrypted_mut = &mut encrypted; + Case::get_sigma2_encryption( fabric.unwrap(), self.rand, &our_random, - &mut case_session, + case_session, signature, - &mut encrypted, + encrypted_mut, )? }; let encrypted = &encrypted[0..encrypted_len]; // Generate our Response Body - let mut tw = TLVWriter::new(ctx.tx.get_writebuf()?); + tx.reset(); + tx.set_proto_id(PROTO_ID_SECURE_CHANNEL); + tx.set_proto_opcode(OpCode::CASESigma2 as u8); + + let mut tw = TLVWriter::new(tx.get_writebuf()?); tw.start_struct(TagType::Anonymous)?; tw.str8(TagType::Context(1), &our_random)?; tw.u16(TagType::Context(2), local_sessid)?; tw.str8(TagType::Context(3), &case_session.our_pub_key)?; tw.str16(TagType::Context(4), encrypted)?; tw.end_container()?; - case_session.tt_hash.update(ctx.tx.as_mut_slice())?; - ctx.exch_ctx.exch.set_case_session(case_session); - Ok(true) + + case_session.tt_hash.update(tx.as_mut_slice())?; + + exchange.exchange(tx, rx).await } fn get_session_clone_data( @@ -334,7 +382,7 @@ impl<'a> Case<'a> { Ok(()) } - fn validate_certs(fabric: &Fabric, noc: &Cert, icac: &Option) -> Result<(), Error> { + fn validate_certs(fabric: &Fabric, noc: &Cert, icac: Option<&Cert>) -> Result<(), Error> { let mut verifier = noc.verify_chain_start(); if fabric.get_fabric_id() != noc.get_fabric_id()? { diff --git a/matter/src/secure_channel/common.rs b/matter/src/secure_channel/common.rs index 80fb7b5..2f00ed4 100644 --- a/matter/src/secure_channel/common.rs +++ b/matter/src/secure_channel/common.rs @@ -17,7 +17,10 @@ use num_derive::FromPrimitive; -use crate::{error::Error, transport::packet::Packet}; +use crate::{ + error::Error, + transport::{exchange::Exchange, packet::Packet}, +}; use super::status_report::{create_status_report, GeneralCode}; @@ -51,6 +54,17 @@ pub enum SCStatusCodes { SessionNotFound = 5, } +pub async fn complete_with_status( + exchange: &mut Exchange<'_>, + tx: &mut Packet<'_>, + status_code: SCStatusCodes, + proto_data: Option<&[u8]>, +) -> Result<(), Error> { + create_sc_status_report(tx, status_code, proto_data)?; + + exchange.send_complete(tx).await +} + pub fn create_sc_status_report( proto_tx: &mut Packet, status_code: SCStatusCodes, diff --git a/matter/src/secure_channel/core.rs b/matter/src/secure_channel/core.rs index 0ad17ed..b20ea9a 100644 --- a/matter/src/secure_channel/core.rs +++ b/matter/src/secure_channel/core.rs @@ -15,18 +15,19 @@ * limitations under the License. */ -use core::{borrow::Borrow, cell::RefCell}; +use core::borrow::Borrow; +use core::cell::RefCell; + +use log::error; use crate::{ error::*, fabric::FabricMgr, mdns::Mdns, - secure_channel::common::*, - tlv, - transport::{proto_ctx::ProtoCtx, session::CloneData}, + secure_channel::{common::*, pake::Pake}, + transport::{exchange::Exchange, packet::Packet}, utils::{epoch::Epoch, rand::Rand}, }; -use log::{error, info}; use super::{case::Case, pake::PaseMgr}; @@ -34,9 +35,10 @@ use super::{case::Case, pake::PaseMgr}; */ pub struct SecureChannel<'a> { - case: Case<'a>, pase: &'a RefCell, + fabric: &'a RefCell, mdns: &'a dyn Mdns, + rand: Rand, } impl<'a> SecureChannel<'a> { @@ -66,45 +68,34 @@ impl<'a> SecureChannel<'a> { rand: Rand, ) -> Self { Self { - case: Case::new(fabric, rand), + fabric, pase, mdns, + rand, } } - pub fn handle(&mut self, ctx: &mut ProtoCtx) -> Result<(bool, Option), Error> { - let proto_opcode: OpCode = ctx.rx.get_proto_opcode()?; - - ctx.tx.set_proto_id(PROTO_ID_SECURE_CHANNEL); - info!("Received Opcode: {:?}", proto_opcode); - info!("Received Data:"); - tlv::print_tlv_list(ctx.rx.as_slice()); - let (reply, clone_data) = match proto_opcode { - OpCode::MRPStandAloneAck => Ok((false, None)), - OpCode::PBKDFParamRequest => self - .pase - .borrow_mut() - .pbkdfparamreq_handler(ctx) - .map(|reply| (reply, None)), - OpCode::PASEPake1 => self - .pase - .borrow_mut() - .pasepake1_handler(ctx) - .map(|reply| (reply, None)), - OpCode::PASEPake3 => self.pase.borrow_mut().pasepake3_handler(ctx, self.mdns), - OpCode::CASESigma1 => self.case.casesigma1_handler(ctx).map(|reply| (reply, None)), - OpCode::CASESigma3 => self.case.casesigma3_handler(ctx), - _ => { - error!("OpCode Not Handled: {:?}", proto_opcode); + pub async fn handle( + &self, + exchange: &mut Exchange<'_>, + rx: &mut Packet<'_>, + tx: &mut Packet<'_>, + ) -> Result<(), Error> { + match rx.get_proto_opcode()? { + OpCode::PBKDFParamRequest => { + Pake::new(self.pase) + .handle(exchange, rx, tx, self.mdns) + .await + } + OpCode::CASESigma1 => { + Case::new(self.fabric, self.rand) + .handle(exchange, rx, tx) + .await + } + proto_opcode => { + error!("OpCode not handled: {:?}", proto_opcode); Err(ErrorCode::InvalidOpcode.into()) } - }?; - - if reply { - info!("Sending response"); - tlv::print_tlv_list(ctx.tx.as_mut_slice()); } - - Ok((reply, clone_data)) } } diff --git a/matter/src/secure_channel/pake.rs b/matter/src/secure_channel/pake.rs index 79f7d2c..ea2b98c 100644 --- a/matter/src/secure_channel/pake.rs +++ b/matter/src/secure_channel/pake.rs @@ -15,36 +15,35 @@ * limitations under the License. */ -use core::{fmt::Write, time::Duration}; +use core::{cell::RefCell, fmt::Write, time::Duration}; use super::{ - common::{create_sc_status_report, SCStatusCodes}, + common::{SCStatusCodes, PROTO_ID_SECURE_CHANNEL}, spake2p::{Spake2P, VerifierData}, }; use crate::{ - crypto, + alloc, crypto, error::{Error, ErrorCode}, mdns::{Mdns, ServiceMode}, - secure_channel::common::OpCode, + secure_channel::common::{complete_with_status, OpCode}, tlv::{self, get_root_node_struct, FromTLV, OctetStr, TLVWriter, TagType, ToTLV}, transport::{ - exchange::ExchangeCtx, - network::Address, - proto_ctx::ProtoCtx, + exchange::{Exchange, ExchangeId}, + packet::Packet, session::{CloneData, SessionMode}, }, utils::{epoch::Epoch, rand::Rand}, }; use log::{error, info}; -#[allow(clippy::large_enum_variant)] -enum PaseMgrState { - Enabled(Pake, heapless::String<16>), - Disabled, +struct PaseSession { + mdns_service_name: heapless::String<16>, + verifier: VerifierData, } pub struct PaseMgr { - state: PaseMgrState, + session: Option, + timeout: Option, epoch: Epoch, rand: Rand, } @@ -53,14 +52,15 @@ impl PaseMgr { #[inline(always)] pub fn new(epoch: Epoch, rand: Rand) -> Self { Self { - state: PaseMgrState::Disabled, + session: None, + timeout: None, epoch, rand, } } pub fn is_pase_session_enabled(&self) -> bool { - matches!(&self.state, PaseMgrState::Enabled(_, _)) + self.session.is_some() } pub fn enable_pase_session( @@ -80,62 +80,24 @@ impl PaseMgr { &mdns_service_name, ServiceMode::Commissionable(discriminator), )?; - self.state = PaseMgrState::Enabled( - Pake::new(verifier, self.epoch, self.rand), + + self.session = Some(PaseSession { mdns_service_name, - ); + verifier, + }); Ok(()) } pub fn disable_pase_session(&mut self, mdns: &dyn Mdns) -> Result<(), Error> { - if let PaseMgrState::Enabled(_, mdns_service_name) = &self.state { - mdns.remove(mdns_service_name)?; + if let Some(session) = self.session.as_ref() { + mdns.remove(&session.mdns_service_name)?; } - self.state = PaseMgrState::Disabled; + self.session = None; Ok(()) } - - /// If the PASE Session is enabled, execute the closure, - /// if not enabled, generate SC Status Report - fn if_enabled(&mut self, ctx: &mut ProtoCtx, f: F) -> Result, Error> - where - F: FnOnce(&mut Pake, &mut ProtoCtx) -> Result, - { - if let PaseMgrState::Enabled(pake, _) = &mut self.state { - let data = f(pake, ctx)?; - - Ok(Some(data)) - } else { - error!("PASE Not enabled"); - create_sc_status_report(ctx.tx, SCStatusCodes::InvalidParameter, None)?; - Ok(None) - } - } - - pub fn pbkdfparamreq_handler(&mut self, ctx: &mut ProtoCtx) -> Result { - ctx.tx.set_proto_opcode(OpCode::PBKDFParamResponse as u8); - self.if_enabled(ctx, |pake, ctx| pake.handle_pbkdfparamrequest(ctx))?; - Ok(true) - } - - pub fn pasepake1_handler(&mut self, ctx: &mut ProtoCtx) -> Result { - ctx.tx.set_proto_opcode(OpCode::PASEPake2 as u8); - self.if_enabled(ctx, |pake, ctx| pake.handle_pasepake1(ctx))?; - Ok(true) - } - - pub fn pasepake3_handler( - &mut self, - ctx: &mut ProtoCtx, - mdns: &dyn Mdns, - ) -> Result<(bool, Option), Error> { - let clone_data = self.if_enabled(ctx, |pake, ctx| pake.handle_pasepake3(ctx))?; - self.disable_pase_session(mdns)?; - Ok((true, clone_data.flatten())) - } } // This file basically deals with the handlers for the PASE secure channel protocol @@ -147,96 +109,65 @@ const PASE_DISCARD_TIMEOUT_SECS: Duration = Duration::from_secs(60); const SPAKE2_SESSION_KEYS_INFO: [u8; 11] = *b"SessionKeys"; -struct SessionData { +struct Timeout { start_time: Duration, - exch_id: u16, - peer_addr: Address, - spake2p: Spake2P, + exch_id: ExchangeId, } -impl SessionData { - fn is_sess_expired(&self, epoch: Epoch) -> Result { - Ok(epoch() - self.start_time > PASE_DISCARD_TIMEOUT_SECS) - } -} - -#[allow(clippy::large_enum_variant)] -enum PakeState { - Idle, - InProgress(SessionData), -} - -impl PakeState { - const fn new() -> Self { - Self::Idle - } - - fn take(&mut self) -> Result { - let new = core::mem::replace(self, PakeState::Idle); - if let PakeState::InProgress(s) = new { - Ok(s) - } else { - Err(ErrorCode::InvalidSignature.into()) - } - } - - fn is_idle(&self) -> bool { - core::mem::discriminant(self) == core::mem::discriminant(&PakeState::Idle) - } - - fn take_sess_data(&mut self, exch_ctx: &ExchangeCtx) -> Result { - let sd = self.take()?; - if sd.exch_id != exch_ctx.exch.get_id() || sd.peer_addr != exch_ctx.sess.get_peer_addr() { - Err(ErrorCode::InvalidState.into()) - } else { - Ok(sd) - } - } - - fn make_in_progress(&mut self, epoch: Epoch, spake2p: Spake2P, exch_ctx: &ExchangeCtx) { - *self = PakeState::InProgress(SessionData { - start_time: epoch(), - spake2p, - exch_id: exch_ctx.exch.get_id(), - peer_addr: exch_ctx.sess.get_peer_addr(), - }); - } - - fn set_sess_data(&mut self, sd: SessionData) { - *self = PakeState::InProgress(sd); - } -} - -impl Default for PakeState { - fn default() -> Self { - Self::new() - } -} - -struct Pake { - verifier: VerifierData, - state: PakeState, - epoch: Epoch, - rand: Rand, -} - -impl Pake { - pub fn new(verifier: VerifierData, epoch: Epoch, rand: Rand) -> Self { - // TODO: Can any PBKDF2 calculation be pre-computed here +impl Timeout { + fn new(exchange: &Exchange, epoch: Epoch) -> Self { Self { - verifier, - state: PakeState::new(), - epoch, - rand, + start_time: epoch(), + exch_id: exchange.id().clone(), } } + fn is_sess_expired(&self, epoch: Epoch) -> bool { + epoch() - self.start_time > PASE_DISCARD_TIMEOUT_SECS + } +} + +pub struct Pake<'a> { + pase: &'a RefCell, +} + +impl<'a> Pake<'a> { + pub const fn new(pase: &'a RefCell) -> Self { + // TODO: Can any PBKDF2 calculation be pre-computed here + Self { pase } + } + + pub async fn handle( + &mut self, + exchange: &mut Exchange<'_>, + rx: &mut Packet<'_>, + tx: &mut Packet<'_>, + mdns: &dyn Mdns, + ) -> Result<(), Error> { + let mut spake2p = alloc!(Spake2P::new()); + + self.handle_pbkdfparamrequest(exchange, rx, tx, &mut spake2p) + .await?; + self.handle_pasepake1(exchange, rx, tx, &mut spake2p) + .await?; + self.handle_pasepake3(exchange, rx, tx, mdns, &mut spake2p) + .await + } + #[allow(non_snake_case)] - pub fn handle_pasepake3(&mut self, ctx: &mut ProtoCtx) -> Result, Error> { - let mut sd = self.state.take_sess_data(&ctx.exch_ctx)?; + async fn handle_pasepake3( + &mut self, + exchange: &mut Exchange<'_>, + rx: &mut Packet<'_>, + tx: &mut Packet<'_>, + mdns: &dyn Mdns, + spake2p: &mut Spake2P, + ) -> Result<(), Error> { + rx.check_proto_opcode(OpCode::PASEPake3 as _)?; + self.update_timeout(exchange, tx, true).await?; - let cA = extract_pasepake_1_or_3_params(ctx.rx.as_slice())?; - let (status_code, ke) = sd.spake2p.handle_cA(cA); + let cA = extract_pasepake_1_or_3_params(rx.as_slice())?; + let (status_code, ke) = spake2p.handle_cA(cA); let clone_data = if status_code == SCStatusCodes::SessionEstablishmentSuccess { // Get the keys @@ -246,7 +177,7 @@ impl Pake { .map_err(|_x| ErrorCode::NoSpace)?; // Create a session - let data = sd.spake2p.get_app_data(); + let data = spake2p.get_app_data(); let peer_sessid: u16 = (data & 0xffff) as u16; let local_sessid: u16 = ((data >> 16) & 0xffff) as u16; let mut clone_data = CloneData::new( @@ -254,7 +185,7 @@ impl Pake { 0, peer_sessid, local_sessid, - ctx.exch_ctx.sess.get_peer_addr(), + exchange.with_session(|sess| Ok(sess.get_peer_addr()))?, SessionMode::Pase, ); clone_data.dec_key.copy_from_slice(&session_keys[0..16]); @@ -269,48 +200,70 @@ impl Pake { None }; - create_sc_status_report(ctx.tx, status_code, None)?; - ctx.exch_ctx.exch.close(); - Ok(clone_data) + if let Some(clone_data) = clone_data { + // TODO: Handle NoSpace + exchange.with_session_mgr_mut(|sess_mgr| sess_mgr.clone_session(&clone_data))?; + + self.pase.borrow_mut().disable_pase_session(mdns)?; + } + + complete_with_status(exchange, tx, status_code, None).await?; + + Ok(()) } #[allow(non_snake_case)] - pub fn handle_pasepake1(&mut self, ctx: &mut ProtoCtx) -> Result<(), Error> { - let mut sd = self.state.take_sess_data(&ctx.exch_ctx)?; + #[allow(clippy::await_holding_refcell_ref)] + async fn handle_pasepake1( + &mut self, + exchange: &mut Exchange<'_>, + rx: &mut Packet<'_>, + tx: &mut Packet<'_>, + spake2p: &mut Spake2P, + ) -> Result<(), Error> { + rx.check_proto_opcode(OpCode::PASEPake1 as _)?; + self.update_timeout(exchange, tx, false).await?; - let pA = extract_pasepake_1_or_3_params(ctx.rx.as_slice())?; + let pase = self.pase.borrow(); + let session = pase.session.as_ref().ok_or(ErrorCode::NoSession)?; + + let pA = extract_pasepake_1_or_3_params(rx.as_slice())?; let mut pB: [u8; 65] = [0; 65]; let mut cB: [u8; 32] = [0; 32]; - sd.spake2p.start_verifier(&self.verifier)?; - sd.spake2p.handle_pA(pA, &mut pB, &mut cB, self.rand)?; + spake2p.start_verifier(&session.verifier)?; + spake2p.handle_pA(pA, &mut pB, &mut cB, pase.rand)?; - let mut tw = TLVWriter::new(ctx.tx.get_writebuf()?); + // Generate response + tx.reset(); + tx.set_proto_id(PROTO_ID_SECURE_CHANNEL); + tx.set_proto_opcode(OpCode::PASEPake2 as u8); + + let mut tw = TLVWriter::new(tx.get_writebuf()?); let resp = Pake1Resp { pb: OctetStr(&pB), cb: OctetStr(&cB), }; resp.to_tlv(&mut tw, TagType::Anonymous)?; - self.state.set_sess_data(sd); - - Ok(()) + drop(pase); + exchange.exchange(tx, rx).await } - pub fn handle_pbkdfparamrequest(&mut self, ctx: &mut ProtoCtx) -> Result<(), Error> { - if !self.state.is_idle() { - let sd = self.state.take()?; - if sd.is_sess_expired(self.epoch)? { - info!("Previous session expired, clearing it"); - self.state = PakeState::Idle; - } else { - info!("Previous session in-progress, denying new request"); - // little-endian timeout (here we've hardcoded 500ms) - create_sc_status_report(ctx.tx, SCStatusCodes::Busy, Some(&[0xf4, 0x01]))?; - return Ok(()); - } - } + #[allow(clippy::await_holding_refcell_ref)] + async fn handle_pbkdfparamrequest( + &mut self, + exchange: &mut Exchange<'_>, + rx: &mut Packet<'_>, + tx: &mut Packet<'_>, + spake2p: &mut Spake2P, + ) -> Result<(), Error> { + rx.check_proto_opcode(OpCode::PBKDFParamRequest as _)?; + self.update_timeout(exchange, tx, true).await?; - let root = tlv::get_root_node(ctx.rx.as_slice())?; + let pase = self.pase.borrow(); + let session = pase.session.as_ref().ok_or(ErrorCode::NoSession)?; + + let root = tlv::get_root_node(rx.as_slice())?; let a = PBKDFParamReq::from_tlv(&root)?; if a.passcode_id != 0 { error!("Can't yet handle passcode_id != 0"); @@ -318,15 +271,18 @@ impl Pake { } let mut our_random: [u8; 32] = [0; 32]; - (self.rand)(&mut our_random); + (self.pase.borrow().rand)(&mut our_random); - let local_sessid = ctx.exch_ctx.sess.reserve_new_sess_id(); + let local_sessid = exchange.with_session_mgr_mut(|mgr| Ok(mgr.get_next_sess_id()))?; let spake2p_data: u32 = ((local_sessid as u32) << 16) | a.initiator_ssid as u32; - let mut spake2p = Spake2P::new(); spake2p.set_app_data(spake2p_data); // Generate response - let mut tw = TLVWriter::new(ctx.tx.get_writebuf()?); + tx.reset(); + tx.set_proto_id(PROTO_ID_SECURE_CHANNEL); + tx.set_proto_opcode(OpCode::PBKDFParamResponse as u8); + + let mut tw = TLVWriter::new(tx.get_writebuf()?); let mut resp = PBKDFParamResp { init_random: a.initiator_random, our_random: OctetStr(&our_random), @@ -335,18 +291,76 @@ impl Pake { }; if !a.has_params { let params_resp = PBKDFParamRespParams { - count: self.verifier.count, - salt: OctetStr(&self.verifier.salt), + count: session.verifier.count, + salt: OctetStr(&session.verifier.salt), }; resp.params = Some(params_resp); } resp.to_tlv(&mut tw, TagType::Anonymous)?; - spake2p.set_context(ctx.rx.as_slice(), ctx.tx.as_mut_slice())?; - self.state - .make_in_progress(self.epoch, spake2p, &ctx.exch_ctx); + spake2p.set_context(rx.as_slice(), tx.as_mut_slice())?; - Ok(()) + drop(pase); + + exchange.exchange(tx, rx).await + } + + #[allow(clippy::await_holding_refcell_ref)] + async fn update_timeout( + &mut self, + exchange: &mut Exchange<'_>, + tx: &mut Packet<'_>, + new: bool, + ) -> Result<(), Error> { + self.check_session(exchange, tx).await?; + + let mut pase = self.pase.borrow_mut(); + + if pase + .timeout + .as_ref() + .map(|sd| sd.is_sess_expired(pase.epoch)) + .unwrap_or(false) + { + pase.timeout = None; + } + + let status = if let Some(sd) = pase.timeout.as_mut() { + if &sd.exch_id != exchange.id() { + info!("Other PAKE session in progress"); + Some(SCStatusCodes::Busy) + } else { + None + } + } else if new { + None + } else { + error!("PAKE session not found or expired"); + Some(SCStatusCodes::SessionNotFound) + }; + + if let Some(status) = status { + drop(pase); + + complete_with_status(exchange, tx, status, None).await + } else { + pase.timeout = Some(Timeout::new(exchange, pase.epoch)); + + Ok(()) + } + } + + async fn check_session( + &mut self, + exchange: &mut Exchange<'_>, + tx: &mut Packet<'_>, + ) -> Result<(), Error> { + if self.pase.borrow().session.is_none() { + error!("PASE not enabled"); + complete_with_status(exchange, tx, SCStatusCodes::InvalidParameter, None).await + } else { + Ok(()) + } } } diff --git a/matter/src/transport/core.rs b/matter/src/transport/core.rs index 1b169ee..2a54b4a 100644 --- a/matter/src/transport/core.rs +++ b/matter/src/transport/core.rs @@ -15,189 +15,51 @@ * limitations under the License. */ +use core::{borrow::Borrow, cell::RefCell}; + +use crate::{error::ErrorCode, secure_channel::common::OpCode, Matter}; +use embassy_futures::select::select; +use embassy_time::{Duration, Timer}; use log::info; -use crate::{error::*, CommissioningData, Matter}; +use crate::{ + error::Error, secure_channel::common::PROTO_ID_SECURE_CHANNEL, transport::packet::Packet, +}; -use crate::secure_channel::common::PROTO_ID_SECURE_CHANNEL; -use crate::secure_channel::core::SecureChannel; -use crate::transport::mrp::ReliableMessage; -use crate::transport::{exchange, network::Address, packet::Packet}; +use super::{ + exchange::{ + Exchange, ExchangeCtr, ExchangeCtx, ExchangeId, ExchangeState, Notification, Role, + MAX_EXCHANGES, + }, + mrp::ReliableMessage, + session::SessionMgr, +}; -use super::proto_ctx::ProtoCtx; -use super::session::CloneData; - -enum RecvState { - New, - OpenExchange, - AddSession(CloneData), - EvictSession, - EvictSession2(CloneData), - Ack, +#[derive(Debug)] +enum OpCodeDescriptor { + SecureChannel(OpCode), + InteractionModel(crate::interaction_model::core::OpCode), + Unknown(u8), } -pub enum RecvAction<'r, 'p> { - Send(Address, &'r [u8]), - Interact(ProtoCtx<'r, 'p>), -} - -pub struct RecvCompletion<'r, 'a> { - transport: &'r mut Transport<'a>, - rx: Packet<'r>, - tx: Packet<'r>, - state: RecvState, -} - -impl<'r, 'a> RecvCompletion<'r, 'a> { - pub fn next_action(&mut self) -> Result>, Error> { - loop { - // Polonius will remove the need for unsafe one day - let this = unsafe { (self as *mut RecvCompletion).as_mut().unwrap() }; - - if let Some(action) = this.maybe_next_action()? { - return Ok(action); - } +impl From for OpCodeDescriptor { + fn from(value: u8) -> Self { + if let Some(opcode) = num::FromPrimitive::from_u8(value) { + Self::SecureChannel(opcode) + } else if let Some(opcode) = num::FromPrimitive::from_u8(value) { + Self::InteractionModel(opcode) + } else { + Self::Unknown(value) } } - - fn maybe_next_action(&mut self) -> Result>>, Error> { - self.transport.exch_mgr.purge(); - self.tx.reset(); - - let (state, next) = match core::mem::replace(&mut self.state, RecvState::New) { - RecvState::New => { - self.rx.plain_hdr_decode()?; - (RecvState::OpenExchange, None) - } - RecvState::OpenExchange => match self.transport.exch_mgr.recv(&mut self.rx) { - Ok(Some(exch_ctx)) => { - if self.rx.get_proto_id() == PROTO_ID_SECURE_CHANNEL { - let mut proto_ctx = ProtoCtx::new(exch_ctx, &self.rx, &mut self.tx); - - let mut secure_channel = SecureChannel::new(self.transport.matter); - - let (reply, clone_data) = secure_channel.handle(&mut proto_ctx)?; - - let state = if let Some(clone_data) = clone_data { - RecvState::AddSession(clone_data) - } else { - RecvState::Ack - }; - - if reply { - if proto_ctx.send()? { - ( - state, - Some(Some(RecvAction::Send(self.tx.peer, self.tx.as_slice()))), - ) - } else { - (state, None) - } - } else { - (state, None) - } - } else { - let proto_ctx = ProtoCtx::new(exch_ctx, &self.rx, &mut self.tx); - - (RecvState::Ack, Some(Some(RecvAction::Interact(proto_ctx)))) - } - } - Ok(None) => (RecvState::Ack, None), - Err(e) => match e.code() { - ErrorCode::Duplicate => (RecvState::Ack, None), - ErrorCode::NoSpace => (RecvState::EvictSession, None), - _ => Err(e)?, - }, - }, - RecvState::AddSession(clone_data) => { - match self.transport.exch_mgr.add_session(&clone_data) { - Ok(_) => (RecvState::Ack, None), - Err(e) => match e.code() { - ErrorCode::NoSpace => (RecvState::EvictSession2(clone_data), None), - _ => Err(e)?, - }, - } - } - RecvState::EvictSession => { - if self.transport.exch_mgr.evict_session(&mut self.tx)? { - ( - RecvState::OpenExchange, - Some(Some(RecvAction::Send(self.tx.peer, self.tx.as_slice()))), - ) - } else { - (RecvState::EvictSession, None) - } - } - RecvState::EvictSession2(clone_data) => { - if self.transport.exch_mgr.evict_session(&mut self.tx)? { - ( - RecvState::AddSession(clone_data), - Some(Some(RecvAction::Send(self.tx.peer, self.tx.as_slice()))), - ) - } else { - (RecvState::EvictSession2(clone_data), None) - } - } - RecvState::Ack => { - if let Some(exch_id) = self.transport.exch_mgr.pending_ack() { - info!("Sending MRP Standalone ACK for exch {}", exch_id); - - ReliableMessage::prepare_ack(exch_id, &mut self.tx); - - if self.transport.exch_mgr.send(exch_id, &mut self.tx)? { - ( - RecvState::Ack, - Some(Some(RecvAction::Send(self.tx.peer, self.tx.as_slice()))), - ) - } else { - (RecvState::Ack, None) - } - } else { - (RecvState::Ack, Some(None)) - } - } - }; - - self.state = state; - Ok(next) - } -} - -enum NotifyState {} - -pub enum NotifyAction<'r, 'p> { - Send(&'r [u8]), - Notify(ProtoCtx<'r, 'p>), -} - -pub struct NotifyCompletion<'r, 'a> { - // TODO - _transport: &'r mut Transport<'a>, - _rx: Packet<'r>, - _tx: Packet<'r>, - _state: NotifyState, -} - -impl<'r, 'a> NotifyCompletion<'r, 'a> { - pub fn next_action(&mut self) -> Result>, Error> { - loop { - // Polonius will remove the need for unsafe one day - let this = unsafe { (self as *mut NotifyCompletion).as_mut().unwrap() }; - - if let Some(action) = this.maybe_next_action()? { - return Ok(action); - } - } - } - - fn maybe_next_action(&mut self) -> Result>>, Error> { - Ok(Some(None)) // TODO: Future - } } pub struct Transport<'a> { matter: &'a Matter<'a>, - exch_mgr: exchange::ExchangeMgr, + pub(crate) exchanges: RefCell>, + pub(crate) send_notification: Notification, + pub(crate) persist_notification: Notification, + pub session_mgr: RefCell, } impl<'a> Transport<'a> { @@ -208,44 +70,358 @@ impl<'a> Transport<'a> { Self { matter, - exch_mgr: exchange::ExchangeMgr::new(epoch, rand), + exchanges: RefCell::new(heapless::Vec::new()), + send_notification: Notification::new(), + persist_notification: Notification::new(), + session_mgr: RefCell::new(SessionMgr::new(epoch, rand)), } } - pub fn matter(&self) -> &Matter<'a> { + pub fn matter(&self) -> &'a Matter<'a> { self.matter } - pub fn start(&mut self, dev_comm: CommissioningData, buf: &mut [u8]) -> Result<(), Error> { - info!("Starting Matter transport"); + pub async fn initiate(&self, _fabric_id: u64, _node_id: u64) -> Result, Error> { + unimplemented!() + } - if self.matter().start_comissioning(dev_comm, buf)? { - info!("Comissioning started"); + pub fn process_rx<'r>( + &'r self, + construction_notification: &'r Notification, + src_rx: &mut Packet<'_>, + ) -> Result>, Error> { + self.purge()?; + + let mut exchanges = self.exchanges.borrow_mut(); + let (ctx, new) = match self.post_recv(&mut exchanges, src_rx) { + Ok((ctx, new)) => (ctx, new), + Err(e) => match e.code() { + ErrorCode::Duplicate => { + self.send_notification.signal(()); + return Ok(None); + } + _ => Err(e)?, + }, + }; + + src_rx.log("Got packet"); + + if src_rx.proto.is_ack() { + if new { + Err(ErrorCode::Invalid)?; + } else { + let state = &mut ctx.state; + + match state { + ExchangeState::ExchangeRecv { + tx_acknowledged, .. + } => { + *tx_acknowledged = true; + } + ExchangeState::CompleteAcknowledge { notification, .. } => { + unsafe { notification.as_ref() }.unwrap().signal(()); + ctx.state = ExchangeState::Closed; + } + _ => { + // TODO: Error handling + todo!() + } + } + + self.notify_changed(); + } + } + + if new { + let constructor = ExchangeCtr { + exchange: Exchange { + id: ctx.id.clone(), + transport: self, + notification: Notification::new(), + }, + construction_notification, + }; + + self.notify_changed(); + + Ok(Some(constructor)) + } else if src_rx.proto.proto_id == PROTO_ID_SECURE_CHANNEL + && src_rx.proto.proto_opcode == OpCode::MRPStandAloneAck as u8 + { + // Standalone ack, do nothing + Ok(None) + } else { + let state = &mut ctx.state; + + match state { + ExchangeState::ExchangeRecv { + rx, notification, .. + } => { + let rx = unsafe { rx.as_mut() }.unwrap(); + rx.load(src_rx)?; + + unsafe { notification.as_ref() }.unwrap().signal(()); + *state = ExchangeState::Active; + } + _ => { + // TODO: Error handling + todo!() + } + } + + self.notify_changed(); + + Ok(None) + } + } + + pub async fn wait_construction( + &self, + construction_notification: &Notification, + src_rx: &Packet<'_>, + exchange_id: &ExchangeId, + ) -> Result<(), Error> { + construction_notification.wait().await; + + let mut exchanges = self.exchanges.borrow_mut(); + + let ctx = Self::get(&mut exchanges, exchange_id).unwrap(); + + let state = &mut ctx.state; + + match state { + ExchangeState::Construction { rx, notification } => { + let rx = unsafe { rx.as_mut() }.unwrap(); + rx.load(src_rx)?; + + unsafe { notification.as_ref() }.unwrap().signal(()); + *state = ExchangeState::Active; + } + _ => unreachable!(), } Ok(()) } - pub fn recv<'r>( - &'r mut self, - addr: Address, - rx_buf: &'r mut [u8], - tx_buf: &'r mut [u8], - ) -> RecvCompletion<'r, 'a> { - let mut rx = Packet::new_rx(rx_buf); - let tx = Packet::new_tx(tx_buf); + pub async fn wait_tx(&self) -> Result<(), Error> { + select( + self.send_notification.wait(), + Timer::after(Duration::from_millis(100)), + ) + .await; - rx.peer = addr; + Ok(()) + } - RecvCompletion { - transport: self, - rx, - tx, - state: RecvState::New, + pub async fn pull_tx(&self, dest_tx: &mut Packet<'_>) -> Result { + self.purge()?; + + let mut exchanges = self.exchanges.borrow_mut(); + + let ctx = exchanges.iter_mut().find(|ctx| { + matches!( + &ctx.state, + ExchangeState::Acknowledge { .. } + | ExchangeState::ExchangeSend { .. } + // | ExchangeState::ExchangeRecv { + // tx_acknowledged: false, + // .. + // } + | ExchangeState::Complete { .. } // | ExchangeState::CompleteAcknowledge { .. } + ) || ctx.mrp.is_ack_ready(*self.matter.borrow()) + }); + + if let Some(ctx) = ctx { + self.notify_changed(); + + let state = &mut ctx.state; + + let send = match state { + ExchangeState::Acknowledge { notification } => { + ReliableMessage::prepare_ack(ctx.id.id, dest_tx); + + unsafe { notification.as_ref() }.unwrap().signal(()); + *state = ExchangeState::Active; + + true + } + ExchangeState::ExchangeSend { + tx, + rx, + notification, + } => { + let tx = unsafe { tx.as_ref() }.unwrap(); + dest_tx.load(tx)?; + + *state = ExchangeState::ExchangeRecv { + _tx: tx, + tx_acknowledged: false, + rx: *rx, + notification: *notification, + }; + + true + } + // ExchangeState::ExchangeRecv { .. } => { + // // TODO: Re-send the tx package if due + // false + // } + ExchangeState::Complete { tx, notification } => { + let tx = unsafe { tx.as_ref() }.unwrap(); + dest_tx.load(tx)?; + + *state = ExchangeState::CompleteAcknowledge { + _tx: tx as *const _, + notification: *notification, + }; + + true + } + // ExchangeState::CompleteAcknowledge { .. } => { + // // TODO: Re-send the tx package if due + // false + // } + _ => { + ReliableMessage::prepare_ack(ctx.id.id, dest_tx); + true + } + }; + + if send { + dest_tx.log("Sending packet"); + + self.pre_send(ctx, dest_tx)?; + self.notify_changed(); + + return Ok(true); + } + } + + Ok(false) + } + + fn purge(&self) -> Result<(), Error> { + loop { + let mut exchanges = self.exchanges.borrow_mut(); + + if let Some(index) = exchanges.iter_mut().enumerate().find_map(|(index, ctx)| { + matches!(ctx.state, ExchangeState::Closed).then_some(index) + }) { + exchanges.swap_remove(index); + } else { + break; + } + } + + Ok(()) + } + + fn post_recv<'r>( + &self, + exchanges: &'r mut heapless::Vec, + rx: &mut Packet<'_>, + ) -> Result<(&'r mut ExchangeCtx, bool), Error> { + rx.plain_hdr_decode()?; + + // Get the session + + let mut session_mgr = self.session_mgr.borrow_mut(); + + let sess_index = session_mgr.post_recv(rx)?; + let session = session_mgr.mut_by_index(sess_index).unwrap(); + + // Decrypt the message + session.recv(self.matter.epoch, rx)?; + + // Get the exchange + // TODO: Handle out of space + let (exch, new) = Self::register( + exchanges, + ExchangeId::load(rx), + Role::complementary(rx.proto.is_initiator()), + // We create a new exchange, only if the peer is the initiator + rx.proto.is_initiator(), + )?; + + // Message Reliability Protocol + exch.mrp.recv(rx, self.matter.epoch)?; + + Ok((exch, new)) + } + + fn pre_send(&self, ctx: &mut ExchangeCtx, tx: &mut Packet) -> Result<(), Error> { + let mut session_mgr = self.session_mgr.borrow_mut(); + let sess_index = session_mgr + .get( + ctx.id.session_id.id, + ctx.id.session_id.peer_addr, + ctx.id.session_id.peer_nodeid, + ctx.id.session_id.is_encrypted, + ) + .ok_or(ErrorCode::NoSession)?; + + let session = session_mgr.mut_by_index(sess_index).unwrap(); + + tx.proto.exch_id = ctx.id.id; + if ctx.role == Role::Initiator { + tx.proto.set_initiator(); + } + + session.pre_send(tx)?; + ctx.mrp.pre_send(tx)?; + session_mgr.send(sess_index, tx) + } + + fn register( + exchanges: &mut heapless::Vec, + id: ExchangeId, + role: Role, + create_new: bool, + ) -> Result<(&mut ExchangeCtx, bool), Error> { + let exchange_index = exchanges + .iter_mut() + .enumerate() + .find_map(|(index, exchange)| (exchange.id == id).then_some(index)); + + if let Some(exchange_index) = exchange_index { + let exchange = &mut exchanges[exchange_index]; + if exchange.role == role { + Ok((exchange, false)) + } else { + Err(ErrorCode::NoExchange.into()) + } + } else if create_new { + info!("Creating new exchange: {:?}", id); + + let exchange = ExchangeCtx { + id, + role, + mrp: ReliableMessage::new(), + state: ExchangeState::Active, + }; + + exchanges.push(exchange).map_err(|_| ErrorCode::NoSpace)?; + + Ok((exchanges.iter_mut().next_back().unwrap(), true)) + } else { + Err(ErrorCode::NoExchange.into()) } } - pub fn notify(&mut self, _tx: &mut Packet) -> Result { - Ok(false) + pub(crate) fn get<'r>( + exchanges: &'r mut heapless::Vec, + id: &ExchangeId, + ) -> Option<&'r mut ExchangeCtx> { + exchanges.iter_mut().find(|exchange| exchange.id == *id) + } + + pub fn notify_changed(&self) { + if self.matter().is_changed() { + self.persist_notification.signal(()); + } + } + + pub async fn wait_changed(&self) { + self.persist_notification.wait().await } } diff --git a/matter/src/transport/exchange.rs b/matter/src/transport/exchange.rs index 5dbb1bb..fbe3d7a 100644 --- a/matter/src/transport/exchange.rs +++ b/matter/src/transport/exchange.rs @@ -1,625 +1,320 @@ -/* - * - * Copyright (c) 2020-2022 Project CHIP Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +use embassy_sync::blocking_mutex::raw::NoopRawMutex; -use core::fmt; -use core::time::Duration; -use log::{error, info, trace}; -use owo_colors::OwoColorize; +use crate::{ + acl::Accessor, + error::{Error, ErrorCode}, + Matter, +}; -use crate::error::{Error, ErrorCode}; -use crate::interaction_model::core::{ResumeReadReq, ResumeSubscribeReq}; -use crate::secure_channel; -use crate::secure_channel::case::CaseSession; -use crate::utils::epoch::Epoch; -use crate::utils::rand::Rand; +use super::{ + core::Transport, + mrp::ReliableMessage, + network::Address, + packet::Packet, + session::{Session, SessionMgr}, +}; -use heapless::LinearMap; +pub const MAX_EXCHANGES: usize = 8; -use super::session::CloneData; -use super::{mrp::ReliableMessage, packet::Packet, session::SessionHandle, session::SessionMgr}; - -pub struct ExchangeCtx<'a> { - pub exch: &'a mut Exchange, - pub sess: SessionHandle<'a>, - pub epoch: Epoch, -} - -impl<'a> ExchangeCtx<'a> { - pub fn send(&mut self, tx: &mut Packet) -> Result { - self.exch.send(tx, &mut self.sess) - } -} +pub type Notification = embassy_sync::signal::Signal; #[derive(Debug, PartialEq, Eq, Copy, Clone, Default)] -pub enum Role { +pub(crate) enum Role { #[default] Initiator = 0, Responder = 1, } -#[derive(Debug, PartialEq, Default)] -enum State { - /// The exchange is open and active - #[default] - Open, - /// The exchange is closed, but keys are active since retransmissions/acks may be pending - Close, - /// The exchange is terminated, keys are destroyed, no communication can happen - Terminate, -} - -// Instead of just doing an Option<>, we create some special handling -// where the commonly used higher layer data store does't have to do a Box -#[derive(Default)] -pub enum DataOption { - CaseSession(CaseSession), - Time(Duration), - SuspendedReadReq(ResumeReadReq), - SubscriptionId(u32), - SuspendedSubscibeReq(ResumeSubscribeReq), - #[default] - None, -} - -#[derive(Default)] -pub struct Exchange { - id: u16, - sess_idx: usize, - role: Role, - state: State, - mrp: ReliableMessage, - // Currently I see this primarily used in PASE and CASE. If that is the limited use - // of this, we might move this into a separate data structure, so as not to burden - // all 'exchanges'. - data: DataOption, -} - -impl Exchange { - pub fn new(id: u16, sess_idx: usize, role: Role) -> Exchange { - Exchange { - id, - sess_idx, - role, - state: State::Open, - mrp: ReliableMessage::new(), - ..Default::default() - } - } - - pub fn terminate(&mut self) { - self.data = DataOption::None; - self.state = State::Terminate; - } - - pub fn close(&mut self) { - self.data = DataOption::None; - self.state = State::Close; - } - - pub fn is_state_open(&self) -> bool { - self.state == State::Open - } - - pub fn is_purgeable(&self) -> bool { - // No Users, No pending ACKs/Retrans - self.state == State::Terminate || (self.state == State::Close && self.mrp.is_empty()) - } - - pub fn get_id(&self) -> u16 { - self.id - } - - pub fn get_role(&self) -> Role { - self.role - } - - pub fn clear_data(&mut self) { - self.data = DataOption::None; - } - - pub fn set_case_session(&mut self, session: CaseSession) { - self.data = DataOption::CaseSession(session); - } - - pub fn get_case_session(&mut self) -> Option<&mut CaseSession> { - if let DataOption::CaseSession(session) = &mut self.data { - Some(session) +impl Role { + pub fn complementary(is_initiator: bool) -> Self { + if is_initiator { + Self::Responder } else { - None + Self::Initiator } } - - pub fn take_case_session(&mut self) -> Option { - let old = core::mem::replace(&mut self.data, DataOption::None); - if let DataOption::CaseSession(session) = old { - Some(session) - } else { - self.data = old; - None - } - } - - pub fn set_suspended_read_req(&mut self, req: ResumeReadReq) { - self.data = DataOption::SuspendedReadReq(req); - } - - pub fn take_suspended_read_req(&mut self) -> Option { - let old = core::mem::replace(&mut self.data, DataOption::None); - if let DataOption::SuspendedReadReq(req) = old { - Some(req) - } else { - self.data = old; - None - } - } - - pub fn set_subscription_id(&mut self, id: u32) { - self.data = DataOption::SubscriptionId(id); - } - - pub fn take_subscription_id(&mut self) -> Option { - let old = core::mem::replace(&mut self.data, DataOption::None); - if let DataOption::SubscriptionId(id) = old { - Some(id) - } else { - self.data = old; - None - } - } - - pub fn set_suspended_subscribe_req(&mut self, req: ResumeSubscribeReq) { - self.data = DataOption::SuspendedSubscibeReq(req); - } - - pub fn take_suspended_subscribe_req(&mut self) -> Option { - let old = core::mem::replace(&mut self.data, DataOption::None); - if let DataOption::SuspendedSubscibeReq(req) = old { - Some(req) - } else { - self.data = old; - None - } - } - - pub fn set_data_time(&mut self, expiry_ts: Option) { - if let Some(t) = expiry_ts { - self.data = DataOption::Time(t); - } - } - - pub fn get_data_time(&self) -> Option { - match self.data { - DataOption::Time(t) => Some(t), - _ => None, - } - } - - pub(crate) fn send( - &mut self, - tx: &mut Packet, - session: &mut SessionHandle, - ) -> Result { - if self.state == State::Terminate { - info!("Skipping tx for terminated exchange {}", self.id); - return Ok(false); - } - - trace!("payload: {:x?}", tx.as_slice()); - info!( - "{} with proto id: {} opcode: {}, tlv:\n", - "Sending".blue(), - tx.get_proto_id(), - tx.get_proto_raw_opcode(), - ); - - //print_tlv_list(tx.as_slice()); - - tx.proto.exch_id = self.id; - if self.role == Role::Initiator { - tx.proto.set_initiator(); - } - - session.pre_send(tx)?; - self.mrp.pre_send(tx)?; - session.send(tx)?; - - Ok(true) - } } -impl fmt::Display for Exchange { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "exch_id: {:?}, sess_index: {}, role: {:?}, mrp: {:?}, state: {:?}", - self.id, self.sess_idx, self.role, self.mrp, self.state - ) +#[derive(Debug)] +pub(crate) struct ExchangeCtx { + pub(crate) id: ExchangeId, + pub(crate) role: Role, + pub(crate) mrp: ReliableMessage, + pub(crate) state: ExchangeState, +} + +#[derive(Debug, Clone)] +pub(crate) enum ExchangeState { + Construction { + rx: *mut Packet<'static>, + notification: *const Notification, + }, + Active, + Acknowledge { + notification: *const Notification, + }, + ExchangeSend { + tx: *const Packet<'static>, + rx: *mut Packet<'static>, + notification: *const Notification, + }, + ExchangeRecv { + _tx: *const Packet<'static>, + tx_acknowledged: bool, + rx: *mut Packet<'static>, + notification: *const Notification, + }, + Complete { + tx: *const Packet<'static>, + notification: *const Notification, + }, + CompleteAcknowledge { + _tx: *const Packet<'static>, + notification: *const Notification, + }, + Closed, +} + +pub struct ExchangeCtr<'a> { + pub(crate) exchange: Exchange<'a>, + pub(crate) construction_notification: &'a Notification, +} + +impl<'a> ExchangeCtr<'a> { + pub const fn id(&self) -> &ExchangeId { + self.exchange.id() + } + + pub async fn get(mut self, rx: &mut Packet<'_>) -> Result, Error> { + let construction_notification = self.construction_notification; + + self.exchange.with_ctx_mut(move |exchange, ctx| { + if !matches!(ctx.state, ExchangeState::Active) { + Err(ErrorCode::NoExchange)?; + } + + let rx: &'static mut Packet<'static> = unsafe { core::mem::transmute(rx) }; + let notification: &'static Notification = + unsafe { core::mem::transmute(&exchange.notification) }; + + ctx.state = ExchangeState::Construction { rx, notification }; + + construction_notification.signal(()); + + Ok(()) + })?; + + self.exchange.notification.wait().await; + + Ok(self.exchange) } } -pub fn get_role(is_initiator: bool) -> Role { - if is_initiator { - Role::Initiator - } else { - Role::Responder - } +#[derive(Debug, Clone, Eq, PartialEq)] +pub struct ExchangeId { + pub id: u16, + pub session_id: SessionId, } -pub fn get_complementary_role(is_initiator: bool) -> Role { - if is_initiator { - Role::Responder - } else { - Role::Initiator - } -} - -const MAX_EXCHANGES: usize = 8; - -pub struct ExchangeMgr { - // keys: exch-id - exchanges: LinearMap, - sess_mgr: SessionMgr, - epoch: Epoch, -} - -pub const MAX_MRP_ENTRIES: usize = 4; - -impl ExchangeMgr { - #[inline(always)] - pub fn new(epoch: Epoch, rand: Rand) -> Self { +impl ExchangeId { + pub fn load(rx: &Packet) -> Self { Self { - sess_mgr: SessionMgr::new(epoch, rand), - exchanges: LinearMap::new(), - epoch, + id: rx.proto.exch_id, + session_id: SessionId::load(rx), } } +} +#[derive(Debug, Clone, Eq, PartialEq)] +pub struct SessionId { + pub id: u16, + pub peer_addr: Address, + pub peer_nodeid: Option, + pub is_encrypted: bool, +} - pub fn get_sess_mgr(&mut self) -> &mut SessionMgr { - &mut self.sess_mgr +impl SessionId { + pub fn load(rx: &Packet) -> Self { + Self { + id: rx.plain.sess_id, + peer_addr: rx.peer, + peer_nodeid: rx.plain.get_src_u64(), + is_encrypted: rx.plain.is_encrypted(), + } + } +} +pub struct Exchange<'a> { + pub(crate) id: ExchangeId, + pub(crate) transport: &'a Transport<'a>, + pub(crate) notification: Notification, +} + +impl<'a> Exchange<'a> { + pub const fn id(&self) -> &ExchangeId { + &self.id } - pub fn _get_with_id( - exchanges: &mut LinearMap, - exch_id: u16, - ) -> Option<&mut Exchange> { - exchanges.get_mut(&exch_id) + pub fn matter(&self) -> &Matter<'a> { + self.transport.matter() } - pub fn get_with_id(&mut self, exch_id: u16) -> Option<&mut Exchange> { - ExchangeMgr::_get_with_id(&mut self.exchanges, exch_id) + pub fn transport(&self) -> &Transport<'a> { + self.transport } - fn _get( - exchanges: &mut LinearMap, - sess_idx: usize, - id: u16, - role: Role, - create_new: bool, - ) -> Result<&mut Exchange, Error> { - // I don't prefer that we scan the list twice here (once for contains_key and other) - if !exchanges.contains_key(&(id)) { - if create_new { - // If an exchange doesn't exist, create a new one - info!("Creating new exchange"); - let e = Exchange::new(id, sess_idx, role); - if exchanges.insert(id, e).is_err() { - Err(ErrorCode::NoSpace)?; - } + pub fn accessor(&self) -> Result, Error> { + self.with_session(|sess| { + Ok(Accessor::for_session( + sess, + &self.transport.matter().acl_mgr, + )) + }) + } + + pub fn with_session_mut(&self, f: F) -> Result + where + F: FnOnce(&mut Session) -> Result, + { + self.with_ctx(|_self, ctx| { + let mut session_mgr = _self.transport.session_mgr.borrow_mut(); + + let sess_index = session_mgr + .get( + ctx.id.session_id.id, + ctx.id.session_id.peer_addr, + ctx.id.session_id.peer_nodeid, + ctx.id.session_id.is_encrypted, + ) + .ok_or(ErrorCode::NoSession)?; + + f(session_mgr.mut_by_index(sess_index).unwrap()) + }) + } + + pub fn with_session(&self, f: F) -> Result + where + F: FnOnce(&Session) -> Result, + { + self.with_session_mut(|sess| f(sess)) + } + + pub fn with_session_mgr_mut(&self, f: F) -> Result + where + F: FnOnce(&mut SessionMgr) -> Result, + { + let mut session_mgr = self.transport.session_mgr.borrow_mut(); + + f(&mut session_mgr) + } + + pub async fn initiate(&mut self, fabric_id: u64, node_id: u64) -> Result, Error> { + self.transport.initiate(fabric_id, node_id).await + } + + pub async fn acknowledge(&mut self) -> Result<(), Error> { + let wait = self.with_ctx_mut(|_self, ctx| { + if !matches!(ctx.state, ExchangeState::Active) { + Err(ErrorCode::NoExchange)?; + } + + if ctx.mrp.is_empty() { + Ok(false) } else { - Err(ErrorCode::NoSpace)?; + ctx.state = ExchangeState::Acknowledge { + notification: &_self.notification as *const _, + }; + _self.transport.send_notification.signal(()); + + Ok(true) } + })?; + + if wait { + self.notification.wait().await; } - // At this point, we would either have inserted the record if 'create_new' was set - // or it existed already - if let Some(result) = exchanges.get_mut(&id) { - if result.get_role() == role && sess_idx == result.sess_idx { - Ok(result) - } else { - Err(ErrorCode::NoExchange.into()) - } - } else { - error!("This should never happen"); - Err(ErrorCode::NoSpace.into()) - } + Ok(()) } - /// The Exchange Mgr receive is like a big processing function - pub fn recv(&mut self, rx: &mut Packet) -> Result, Error> { - // Get the session - let index = self.sess_mgr.post_recv(rx)?; - let mut session = self.sess_mgr.get_session_handle(index); + pub async fn exchange(&mut self, tx: &Packet<'_>, rx: &mut Packet<'_>) -> Result<(), Error> { + let tx: &Packet<'static> = unsafe { core::mem::transmute(tx) }; + let rx: &mut Packet<'static> = unsafe { core::mem::transmute(rx) }; - // Decrypt the message - session.recv(self.epoch, rx)?; - - // Get the exchange - let exch = ExchangeMgr::_get( - &mut self.exchanges, - index, - rx.proto.exch_id, - get_complementary_role(rx.proto.is_initiator()), - // We create a new exchange, only if the peer is the initiator - rx.proto.is_initiator(), - )?; - - // Message Reliability Protocol - exch.mrp.recv(rx, self.epoch)?; - - if exch.is_state_open() { - Ok(Some(ExchangeCtx { - exch, - sess: session, - epoch: self.epoch, - })) - } else { - // Instead of an error, we send None here, because it is likely that - // we just processed an acknowledgement that cleared the exchange - Ok(None) - } - } - - pub fn send(&mut self, exch_id: u16, tx: &mut Packet) -> Result { - let exchange = - ExchangeMgr::_get_with_id(&mut self.exchanges, exch_id).ok_or(ErrorCode::NoExchange)?; - let mut session = self.sess_mgr.get_session_handle(exchange.sess_idx); - exchange.send(tx, &mut session) - } - - pub fn purge(&mut self) { - let mut to_purge: LinearMap = LinearMap::new(); - - for (exch_id, exchange) in self.exchanges.iter() { - if exchange.is_purgeable() { - let _ = to_purge.insert(*exch_id, ()); - } - } - for (exch_id, _) in to_purge.iter() { - self.exchanges.remove(exch_id); - } - } - - pub fn pending_ack(&mut self) -> Option { - self.exchanges - .iter() - .find(|(_, exchange)| exchange.mrp.is_ack_ready(self.epoch)) - .map(|(exch_id, _)| *exch_id) - } - - pub fn evict_session(&mut self, tx: &mut Packet) -> Result { - if let Some(index) = self.sess_mgr.get_session_for_eviction() { - info!("Sessions full, vacating session with index: {}", index); - // If we enter here, we have an LRU session that needs to be reclaimed - // As per the spec, we need to send a CLOSE here - - let mut session = self.sess_mgr.get_session_handle(index); - secure_channel::common::create_sc_status_report( - tx, - secure_channel::common::SCStatusCodes::CloseSession, - None, - )?; - - if let Some((_, exchange)) = - self.exchanges.iter_mut().find(|(_, e)| e.sess_idx == index) - { - // Send Close_session on this exchange, and then close the session - // Should this be done for all exchanges? - error!("Sending Close Session"); - exchange.send(tx, &mut session)?; - // TODO: This wouldn't actually send it out, because 'transport' isn't owned yet. + self.with_ctx_mut(|_self, ctx| { + if !matches!(ctx.state, ExchangeState::Active) { + Err(ErrorCode::NoExchange)?; } - let remove_exchanges: heapless::Vec = self - .exchanges - .iter() - .filter_map(|(eid, e)| { - if e.sess_idx == index { - Some(*eid) - } else { - None - } - }) - .collect(); - info!( - "Terminating the following exchanges: {:?}", - remove_exchanges - ); - for exch_id in remove_exchanges { - // Remove from exchange list - self.exchanges.remove(&exch_id); - } + ctx.state = ExchangeState::ExchangeSend { + tx: tx as *const _, + rx: rx as *mut _, + notification: &_self.notification as *const _, + }; + _self.transport.send_notification.signal(()); - self.sess_mgr.remove(index); + Ok(()) + })?; - Ok(true) - } else { - Ok(false) - } + self.notification.wait().await; + + Ok(()) } - pub fn add_session(&mut self, clone_data: &CloneData) -> Result { - let sess_idx = self.sess_mgr.clone_session(clone_data)?; + pub async fn complete(mut self, tx: &Packet<'_>) -> Result<(), Error> { + self.send_complete(tx).await + } - Ok(self.sess_mgr.get_session_handle(sess_idx)) + pub async fn send_complete(&mut self, tx: &Packet<'_>) -> Result<(), Error> { + let tx: &Packet<'static> = unsafe { core::mem::transmute(tx) }; + + self.with_ctx_mut(|_self, ctx| { + if !matches!(ctx.state, ExchangeState::Active) { + Err(ErrorCode::NoExchange)?; + } + + ctx.state = ExchangeState::Complete { + tx: tx as *const _, + notification: &_self.notification as *const _, + }; + _self.transport.send_notification.signal(()); + + Ok(()) + })?; + + self.notification.wait().await; + + Ok(()) + } + + fn with_ctx(&self, f: F) -> Result + where + F: FnOnce(&Self, &ExchangeCtx) -> Result, + { + let mut exchanges = self.transport.exchanges.borrow_mut(); + + let exchange = Transport::get(&mut exchanges, &self.id).ok_or(ErrorCode::NoExchange)?; // TODO + + f(self, exchange) + } + + fn with_ctx_mut(&mut self, f: F) -> Result + where + F: FnOnce(&mut Self, &mut ExchangeCtx) -> Result, + { + let mut exchanges = self.transport.exchanges.borrow_mut(); + + let exchange = Transport::get(&mut exchanges, &self.id).ok_or(ErrorCode::NoExchange)?; // TODO + + f(self, exchange) } } -impl fmt::Display for ExchangeMgr { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - writeln!(f, "{{ Session Mgr: {},", self.sess_mgr)?; - writeln!(f, " Exchanges: [")?; - for s in &self.exchanges { - writeln!(f, "{{ {}, }},", s.1)?; - } - writeln!(f, " ]")?; - write!(f, "}}") - } -} - -#[cfg(test)] -#[allow(clippy::bool_assert_comparison)] -mod tests { - use crate::{ - error::ErrorCode, - transport::{ - network::Address, - session::{CloneData, SessionMode}, - }, - utils::{epoch::dummy_epoch, rand::dummy_rand}, - }; - - use super::{ExchangeMgr, Role}; - - #[test] - fn test_purge() { - let mut mgr = ExchangeMgr::new(dummy_epoch, dummy_rand); - let _ = ExchangeMgr::_get(&mut mgr.exchanges, 1, 2, Role::Responder, true).unwrap(); - let _ = ExchangeMgr::_get(&mut mgr.exchanges, 1, 3, Role::Responder, true).unwrap(); - - mgr.purge(); - assert_eq!( - ExchangeMgr::_get(&mut mgr.exchanges, 1, 2, Role::Responder, false).is_ok(), - true - ); - assert_eq!( - ExchangeMgr::_get(&mut mgr.exchanges, 1, 3, Role::Responder, false).is_ok(), - true - ); - - // Close e1 - let e1 = ExchangeMgr::_get(&mut mgr.exchanges, 1, 2, Role::Responder, false).unwrap(); - e1.close(); - mgr.purge(); - assert_eq!( - ExchangeMgr::_get(&mut mgr.exchanges, 1, 2, Role::Responder, false).is_ok(), - false - ); - assert_eq!( - ExchangeMgr::_get(&mut mgr.exchanges, 1, 3, Role::Responder, false).is_ok(), - true - ); - } - - fn get_clone_data(peer_sess_id: u16, local_sess_id: u16) -> CloneData { - CloneData::new( - 12341234, - 43211234, - peer_sess_id, - local_sess_id, - Address::default(), - SessionMode::Pase, - ) - } - - fn fill_sessions(mgr: &mut ExchangeMgr, count: usize) { - let mut local_sess_id = 1; - let mut peer_sess_id = 100; - for _ in 1..count { - let clone_data = get_clone_data(peer_sess_id, local_sess_id); - match mgr.add_session(&clone_data) { - Ok(s) => assert_eq!(peer_sess_id, s.get_peer_sess_id()), - Err(e) => { - if e.code() == ErrorCode::NoSpace { - break; - } else { - panic!("Could not create sessions"); - } - } - } - local_sess_id += 1; - peer_sess_id += 1; - } - } - - #[cfg(feature = "std")] - #[test] - /// We purposefuly overflow the sessions - /// and when the overflow happens, we confirm that - /// - The sessions are evicted in LRU - /// - The exchanges associated with those sessions are evicted too - fn test_sess_evict() { - use crate::transport::packet::{Packet, MAX_TX_BUF_SIZE}; - use crate::transport::session::MAX_SESSIONS; - - let mut mgr = ExchangeMgr::new(crate::utils::epoch::sys_epoch, dummy_rand); - - fill_sessions(&mut mgr, MAX_SESSIONS + 1); - // Sessions are now full from local session id 1 to 16 - - // Create exchanges for sessions 2 (i.e. session index 1) and 3 (session index 2) - // Exchange IDs are 20 and 30 respectively - let _ = ExchangeMgr::_get(&mut mgr.exchanges, 1, 20, Role::Responder, true).unwrap(); - let _ = ExchangeMgr::_get(&mut mgr.exchanges, 2, 30, Role::Responder, true).unwrap(); - - // Confirm that session ids 1 to MAX_SESSIONS exists - for i in 1..(MAX_SESSIONS + 1) { - assert_eq!(mgr.sess_mgr.get_with_id(i as u16).is_none(), false); - } - // Confirm that the exchanges are around - assert_eq!(mgr.get_with_id(20).is_none(), false); - assert_eq!(mgr.get_with_id(30).is_none(), false); - let mut old_local_sess_id = 1; - let mut new_local_sess_id = 100; - let mut new_peer_sess_id = 200; - - for i in 1..(MAX_SESSIONS + 1) { - // Now purposefully overflow the sessions by adding another session - let result = mgr.add_session(&get_clone_data(new_peer_sess_id, new_local_sess_id)); - assert!(matches!( - result.map_err(|e| e.code()), - Err(ErrorCode::NoSpace) - )); - - let mut buf = [0; MAX_TX_BUF_SIZE]; - let tx = &mut Packet::new_tx(&mut buf); - let evicted = mgr.evict_session(tx).unwrap(); - assert!(evicted); - - let session = mgr - .add_session(&get_clone_data(new_peer_sess_id, new_local_sess_id)) - .unwrap(); - assert_eq!(session.get_peer_sess_id(), new_peer_sess_id); - - // This should have evicted session with local sess_id - assert_eq!(mgr.sess_mgr.get_with_id(old_local_sess_id).is_none(), true); - - new_local_sess_id += 1; - new_peer_sess_id += 1; - old_local_sess_id += 1; - - match i { - 1 => { - // Both exchanges should exist - assert_eq!(mgr.get_with_id(20).is_none(), false); - assert_eq!(mgr.get_with_id(30).is_none(), false); - } - 2 => { - // Exchange 20 would have been evicted - assert_eq!(mgr.get_with_id(20).is_none(), true); - assert_eq!(mgr.get_with_id(30).is_none(), false); - } - 3 => { - // Exchange 20 and 30 would have been evicted - assert_eq!(mgr.get_with_id(20).is_none(), true); - assert_eq!(mgr.get_with_id(30).is_none(), true); - } - _ => {} - } - } - // println!("Session mgr {}", mgr.sess_mgr); +impl<'a> Drop for Exchange<'a> { + fn drop(&mut self) { + let _ = self.with_ctx_mut(|_self, ctx| { + ctx.state = ExchangeState::Closed; + _self.transport.send_notification.signal(()); + + Ok(()) + }); } } diff --git a/matter/src/transport/mod.rs b/matter/src/transport/mod.rs index a219f16..6c5601e 100644 --- a/matter/src/transport/mod.rs +++ b/matter/src/transport/mod.rs @@ -23,7 +23,7 @@ pub mod network; pub mod packet; pub mod pipe; pub mod plain_hdr; -pub mod proto_ctx; pub mod proto_hdr; +pub mod runner; pub mod session; pub mod udp; diff --git a/matter/src/transport/proto_ctx.rs b/matter/src/transport/proto_ctx.rs deleted file mode 100644 index b7374ec..0000000 --- a/matter/src/transport/proto_ctx.rs +++ /dev/null @@ -1,41 +0,0 @@ -/* - * - * Copyright (c) 2020-2022 Project CHIP Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -use crate::error::Error; - -use super::exchange::ExchangeCtx; -use super::packet::Packet; - -/// This is the context in which a receive packet is being processed -pub struct ProtoCtx<'a, 'b> { - /// This is the exchange context, that includes the exchange and the session - pub exch_ctx: ExchangeCtx<'a>, - /// This is the received buffer for this transaction - pub rx: &'a Packet<'b>, - /// This is the transmit buffer for this transaction - pub tx: &'a mut Packet<'b>, -} - -impl<'a, 'b> ProtoCtx<'a, 'b> { - pub fn new(exch_ctx: ExchangeCtx<'a>, rx: &'a Packet<'b>, tx: &'a mut Packet<'b>) -> Self { - Self { exch_ctx, rx, tx } - } - - pub fn send(&mut self) -> Result { - self.exch_ctx.exch.send(self.tx, &mut self.exch_ctx.sess) - } -} diff --git a/matter/src/transport/runner.rs b/matter/src/transport/runner.rs new file mode 100644 index 0000000..f94e819 --- /dev/null +++ b/matter/src/transport/runner.rs @@ -0,0 +1,392 @@ +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +use core::{mem::MaybeUninit, pin::pin}; + +use crate::{ + alloc, + data_model::{core::DataModel, objects::DataModelHandler}, + interaction_model::core::PROTO_ID_INTERACTION_MODEL, + transport::network::{Address, IpAddr, Ipv6Addr, SocketAddr}, + CommissioningData, Matter, +}; +use embassy_futures::select::{select, select3, select_slice, Either}; +use embassy_sync::{blocking_mutex::raw::NoopRawMutex, channel::Channel}; +use log::{error, info, warn}; + +use crate::{ + error::Error, + secure_channel::{common::PROTO_ID_SECURE_CHANNEL, core::SecureChannel}, + transport::packet::{Packet, MAX_RX_BUF_SIZE}, + utils::select::EitherUnwrap, +}; + +use super::{ + core::Transport, + exchange::{ExchangeCtr, Notification, MAX_EXCHANGES}, + packet::{MAX_RX_STATUS_BUF_SIZE, MAX_TX_BUF_SIZE}, + pipe::{Chunk, Pipe}, + udp::UdpListener, +}; + +pub type TxBuf = MaybeUninit<[u8; MAX_TX_BUF_SIZE]>; +pub type RxBuf = MaybeUninit<[u8; MAX_RX_BUF_SIZE]>; +type SxBuf = MaybeUninit<[u8; MAX_RX_STATUS_BUF_SIZE]>; + +struct PacketPools { + tx: [TxBuf; MAX_EXCHANGES], + rx: [RxBuf; MAX_EXCHANGES], + sx: [SxBuf; MAX_EXCHANGES], +} + +impl PacketPools { + const TX_ELEM: TxBuf = MaybeUninit::uninit(); + const RX_ELEM: RxBuf = MaybeUninit::uninit(); + const SX_ELEM: SxBuf = MaybeUninit::uninit(); + + const TX_INIT: [TxBuf; MAX_EXCHANGES] = [Self::TX_ELEM; MAX_EXCHANGES]; + const RX_INIT: [RxBuf; MAX_EXCHANGES] = [Self::RX_ELEM; MAX_EXCHANGES]; + const SX_INIT: [SxBuf; MAX_EXCHANGES] = [Self::SX_ELEM; MAX_EXCHANGES]; + + #[inline(always)] + pub const fn new() -> Self { + Self { + tx: Self::TX_INIT, + rx: Self::RX_INIT, + sx: Self::SX_INIT, + } + } +} + +/// This struct implements an executor-agnostic option to run the Matter transport stack end-to-end. +/// +/// Since it is not possible to use executor tasks spawning in an executor-agnostic way (yet), +/// the async loops are arranged as one giant future. Therefore, the cost is a slightly slower execution +/// due to the generated future being relatively big and deeply nested. +/// +/// Users are free to implement their own async execution loop, by utilizing the `Transport` +/// struct directly with their async executor of choice. +pub struct TransportRunner<'a> { + transport: Transport<'a>, + pools: PacketPools, +} + +impl<'a> TransportRunner<'a> { + #[inline(always)] + pub fn new(matter: &'a Matter<'a>) -> Self { + Self::wrap(Transport::new(matter)) + } + + #[inline(always)] + pub const fn wrap(transport: Transport<'a>) -> Self { + Self { + transport, + pools: PacketPools::new(), + } + } + + pub fn transport(&self) -> &Transport { + &self.transport + } + + pub async fn run_udp( + &mut self, + tx_buf: &mut TxBuf, + rx_buf: &mut RxBuf, + dev_comm: CommissioningData, + handler: &H, + ) -> Result<(), Error> + where + H: DataModelHandler, + { + let udp = UdpListener::new(SocketAddr::new( + IpAddr::V6(Ipv6Addr::UNSPECIFIED), + self.transport.matter().port, + )) + .await?; + + let tx_pipe = Pipe::new(unsafe { tx_buf.assume_init_mut() }); + let rx_pipe = Pipe::new(unsafe { rx_buf.assume_init_mut() }); + + let tx_pipe = &tx_pipe; + let rx_pipe = &rx_pipe; + let udp = &udp; + + let mut tx = pin!(async move { + loop { + { + let mut data = tx_pipe.data.lock().await; + + if let Some(chunk) = data.chunk { + udp.send(chunk.addr.unwrap_udp(), &data.buf[chunk.start..chunk.end]) + .await?; + data.chunk = None; + tx_pipe.data_consumed_notification.signal(()); + } + } + + tx_pipe.data_supplied_notification.wait().await; + } + }); + + let mut rx = pin!(async move { + loop { + { + let mut data = rx_pipe.data.lock().await; + + if data.chunk.is_none() { + let (len, addr) = udp.recv(data.buf).await?; + + data.chunk = Some(Chunk { + start: 0, + end: len, + addr: Address::Udp(addr), + }); + rx_pipe.data_supplied_notification.signal(()); + } + } + + rx_pipe.data_consumed_notification.wait().await; + } + }); + + let mut run = pin!(async move { self.run(tx_pipe, rx_pipe, dev_comm, handler).await }); + + select3(&mut tx, &mut rx, &mut run).await.unwrap() + } + + pub async fn run( + &mut self, + tx_pipe: &Pipe<'_>, + rx_pipe: &Pipe<'_>, + dev_comm: CommissioningData, + handler: &H, + ) -> Result<(), Error> + where + H: DataModelHandler, + { + info!("Running Matter transport"); + + let buf = unsafe { self.pools.rx[0].assume_init_mut() }; + + if self.transport.matter().start_comissioning(dev_comm, buf)? { + info!("Comissioning started"); + } + + let construction_notification = Notification::new(); + + let mut rx = pin!(Self::handle_rx( + &self.transport, + &mut self.pools, + rx_pipe, + &construction_notification, + handler + )); + let mut tx = pin!(Self::handle_tx(&self.transport, tx_pipe)); + + select(&mut rx, &mut tx).await.unwrap() + } + + async fn handle_rx( + transport: &Transport<'_>, + pools: &mut PacketPools, + rx_pipe: &Pipe<'_>, + construction_notification: &Notification, + handler: &H, + ) -> Result<(), Error> + where + H: DataModelHandler, + { + info!("Creating queue for {} exchanges", 1); + + let channel = Channel::::new(); + + info!("Creating {} handlers", MAX_EXCHANGES); + let mut handlers = heapless::Vec::<_, MAX_EXCHANGES>::new(); + + info!("Handlers size: {}", core::mem::size_of_val(&handlers)); + + let pools = &mut *pools as *mut _; + + for index in 0..MAX_EXCHANGES { + let channel = &channel; + let handler_id = index; + + handlers + .push(async move { + loop { + let exchange_ctr: ExchangeCtr<'_> = channel.recv().await; + + info!( + "Handler {}: Got exchange {:?}", + handler_id, + exchange_ctr.id() + ); + + let result = Self::handle_exchange( + transport, + pools, + handler_id, + exchange_ctr, + handler, + ) + .await; + + if let Err(err) = result { + warn!( + "Handler {}: Exchange closed because of error: {:?}", + handler_id, err + ); + } else { + info!("Handler {}: Exchange completed", handler_id); + } + } + }) + .map_err(|_| ()) + .unwrap(); + } + + let mut rx = pin!(async { + loop { + info!("Transport: waiting for incoming packets"); + + { + let mut data = rx_pipe.data.lock().await; + + if let Some(chunk) = data.chunk { + let mut rx = alloc!(Packet::new_rx(&mut data.buf[chunk.start..chunk.end])); + rx.peer = chunk.addr; + + if let Some(exchange_ctr) = + transport.process_rx(construction_notification, &mut rx)? + { + let exchange_id = exchange_ctr.id().clone(); + + info!("Transport: got new exchange: {:?}", exchange_id); + + channel.send(exchange_ctr).await; + info!("Transport: exchange sent"); + + transport + .wait_construction(construction_notification, &rx, &exchange_id) + .await?; + + info!("Transport: exchange started"); + } + + data.chunk = None; + rx_pipe.data_consumed_notification.signal(()); + } + } + + rx_pipe.data_supplied_notification.wait().await + } + + #[allow(unreachable_code)] + Ok::<_, Error>(()) + }); + + let result = select(&mut rx, select_slice(&mut handlers)).await; + + if let Either::First(result) = result { + if let Err(e) = &result { + error!("Exitting RX loop due to an error: {:?}", e); + } + + result?; + } + + Ok(()) + } + + async fn handle_tx(transport: &Transport<'_>, tx_pipe: &Pipe<'_>) -> Result<(), Error> { + loop { + loop { + { + let mut data = tx_pipe.data.lock().await; + + if data.chunk.is_none() { + let mut tx = alloc!(Packet::new_tx(data.buf)); + + if transport.pull_tx(&mut tx).await? { + data.chunk = Some(Chunk { + start: tx.get_writebuf()?.get_start(), + end: tx.get_writebuf()?.get_tail(), + addr: tx.peer, + }); + tx_pipe.data_supplied_notification.signal(()); + } else { + break; + } + } + } + + tx_pipe.data_consumed_notification.wait().await; + } + + transport.wait_tx().await?; + } + } + + #[cfg_attr(feature = "nightly", allow(clippy::await_holding_refcell_ref))] // Fine because of the async mutex + async fn handle_exchange( + transport: &Transport<'_>, + pools: *mut PacketPools, + handler_id: usize, + exchange_ctr: ExchangeCtr<'_>, + handler: &H, + ) -> Result<(), Error> + where + H: DataModelHandler, + { + let pools = unsafe { pools.as_mut() }.unwrap(); + + let tx_buf = unsafe { pools.tx[handler_id].assume_init_mut() }; + let rx_buf = unsafe { pools.rx[handler_id].assume_init_mut() }; + let rx_status_buf = unsafe { pools.sx[handler_id].assume_init_mut() }; + + let mut rx = alloc!(Packet::new_rx(rx_buf.as_mut())); + let mut tx = alloc!(Packet::new_tx(tx_buf.as_mut())); + + let mut exchange = alloc!(exchange_ctr.get(&mut rx).await?); + + match rx.get_proto_id() { + PROTO_ID_SECURE_CHANNEL => { + let sc = SecureChannel::new(transport.matter()); + + sc.handle(&mut exchange, &mut rx, &mut tx).await?; + + transport.notify_changed(); + } + PROTO_ID_INTERACTION_MODEL => { + let dm = DataModel::new(handler); + + let mut rx_status = alloc!(Packet::new_rx(rx_status_buf)); + + dm.handle(&mut exchange, &mut rx, &mut tx, &mut rx_status) + .await?; + + transport.notify_changed(); + } + other => { + error!("Unknown Proto-ID: {}", other); + } + } + + Ok(()) + } +} diff --git a/matter/tests/common/echo_cluster.rs b/matter/tests/common/echo_cluster.rs index 5e43e18..dd83f0a 100644 --- a/matter/tests/common/echo_cluster.rs +++ b/matter/tests/common/echo_cluster.rs @@ -15,10 +15,9 @@ * limitations under the License. */ -use std::{ - convert::TryInto, - sync::{Arc, Mutex, Once}, -}; +use core::cell::Cell; +use core::convert::TryInto; +use std::sync::{Arc, Mutex, Once}; use matter::{ attribute_enum, command_enum, @@ -28,11 +27,9 @@ use matter::{ Quality, ATTRIBUTE_LIST, FEATURE_MAP, }, error::{Error, ErrorCode}, - interaction_model::{ - core::Transaction, - messages::ib::{attr_list_write, ListOperation}, - }, + interaction_model::messages::ib::{attr_list_write, ListOperation}, tlv::{TLVElement, TagType}, + transport::exchange::Exchange, utils::rand::Rand, }; use num_derive::FromPrimitive; @@ -132,10 +129,10 @@ pub const WRITE_LIST_MAX: usize = 5; pub struct EchoCluster { pub data_ver: Dataver, pub multiplier: u8, - pub att1: u16, - pub att2: u16, - pub att_write: u16, - pub att_custom: u32, + pub att1: Cell, + pub att2: Cell, + pub att_write: Cell, + pub att_custom: Cell, } impl EchoCluster { @@ -143,10 +140,10 @@ impl EchoCluster { Self { data_ver: Dataver::new(rand), multiplier, - att1: 0x1234, - att2: 0x5678, - att_write: ATTR_WRITE_DEFAULT_VALUE, - att_custom: ATTR_CUSTOM_VALUE, + att1: Cell::new(0x1234), + att2: Cell::new(0x5678), + att_write: Cell::new(ATTR_WRITE_DEFAULT_VALUE), + att_custom: Cell::new(ATTR_CUSTOM_VALUE), } } @@ -179,14 +176,14 @@ impl EchoCluster { } } - pub fn write(&mut self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { + pub fn write(&self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { let data = data.with_dataver(self.data_ver.get())?; match attr.attr_id.try_into()? { - Attributes::Att1(codec) => self.att1 = codec.decode(data)?, - Attributes::Att2(codec) => self.att2 = codec.decode(data)?, - Attributes::AttWrite(codec) => self.att_write = codec.decode(data)?, - Attributes::AttCustom(codec) => self.att_custom = codec.decode(data)?, + Attributes::Att1(codec) => self.att1.set(codec.decode(data)?), + Attributes::Att2(codec) => self.att2.set(codec.decode(data)?), + Attributes::AttWrite(codec) => self.att_write.set(codec.decode(data)?), + Attributes::AttCustom(codec) => self.att_custom.set(codec.decode(data)?), Attributes::AttWriteList(_) => { attr_list_write(attr, data, |op, data| self.write_attr_list(&op, data))? } @@ -198,8 +195,8 @@ impl EchoCluster { } pub fn invoke( - &mut self, - _transaction: &mut Transaction, + &self, + _exchange: &Exchange, cmd: &CmdDetails, data: &TLVElement, encoder: CmdDataEncoder, @@ -222,7 +219,7 @@ impl EchoCluster { } } - fn write_attr_list(&mut self, op: &ListOperation, data: &TLVElement) -> Result<(), Error> { + fn write_attr_list(&self, op: &ListOperation, data: &TLVElement) -> Result<(), Error> { let tc_handle = TestChecker::get().unwrap(); let mut tc = tc_handle.lock().unwrap(); match op { @@ -272,18 +269,18 @@ impl Handler for EchoCluster { EchoCluster::read(self, attr, encoder) } - fn write(&mut self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { + fn write(&self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { EchoCluster::write(self, attr, data) } fn invoke( - &mut self, - transaction: &mut Transaction, + &self, + exchange: &Exchange, cmd: &CmdDetails, data: &TLVElement, encoder: CmdDataEncoder, ) -> Result<(), Error> { - EchoCluster::invoke(self, transaction, cmd, data, encoder) + EchoCluster::invoke(self, exchange, cmd, data, encoder) } } diff --git a/matter/tests/common/handlers.rs b/matter/tests/common/handlers.rs index 7235b8a..97de89a 100644 --- a/matter/tests/common/handlers.rs +++ b/matter/tests/common/handlers.rs @@ -1,8 +1,6 @@ -use core::time; -use std::thread; - use log::{info, warn}; use matter::{ + error::ErrorCode, interaction_model::{ core::{IMStatusCode, OpCode}, messages::{ @@ -14,17 +12,12 @@ use matter::{ }, }, tlv::{self, FromTLV, TLVArray, ToTLV}, - transport::{ - exchange::{self, Exchange}, - session::NocCatIds, - }, - Matter, }; use super::{ attributes::assert_attr_report, commands::{assert_inv_response, ExpectedInvResp}, - im_engine::{ImEngine, ImInput, IM_ENGINE_PEER_ID}, + im_engine::{ImEngine, ImEngineHandler, ImInput, ImOutput}, }; pub enum WriteResponse<'a> { @@ -38,72 +31,71 @@ pub enum TimedInvResponse<'a> { } impl<'a> ImEngine<'a> { + pub fn read_reqs(input: &[AttrPath], expected: &[AttrResp]) { + let im = ImEngine::new_default(); + + im.add_default_acl(); + im.handle_read_reqs(&im.handler(), input, expected); + } + // Helper for handling Read Req sequences for this file pub fn handle_read_reqs( - &mut self, - peer_node_id: u64, + &self, + handler: &ImEngineHandler, input: &[AttrPath], expected: &[AttrResp], ) { - let mut out_buf = [0u8; 400]; - let received = self.gen_read_reqs_output(peer_node_id, input, None, &mut out_buf); + let mut out = heapless::Vec::<_, 1>::new(); + let received = self.gen_read_reqs_output(handler, input, None, &mut out); assert_attr_report(&received, expected) } - pub fn new_with_read_reqs( - matter: &'a Matter<'a>, + pub fn gen_read_reqs_output<'c, const N: usize>( + &self, + handler: &ImEngineHandler, input: &[AttrPath], - expected: &[AttrResp], - ) -> Self { - let mut im = Self::new(matter); - - let mut out_buf = [0u8; 400]; - let received = im.gen_read_reqs_output(IM_ENGINE_PEER_ID, input, None, &mut out_buf); - assert_attr_report(&received, expected); - - im - } - - pub fn gen_read_reqs_output<'b>( - &mut self, - peer_node_id: u64, - input: &[AttrPath], - dataver_filters: Option>, - out_buf: &'b mut [u8], - ) -> ReportDataMsg<'b> { + dataver_filters: Option>, + out: &'c mut heapless::Vec, + ) -> ReportDataMsg<'c> { let mut read_req = ReadReq::new(true).set_attr_requests(input); read_req.dataver_filters = dataver_filters; - let mut input = ImInput::new(OpCode::ReadRequest, &read_req); - input.set_peer_node_id(peer_node_id); + let input = ImInput::new(OpCode::ReadRequest, &read_req); - let (_, out_buf) = self.process(&input, out_buf); + self.process(handler, &[&input], out).unwrap(); - tlv::print_tlv_list(out_buf); - let root = tlv::get_root_node_struct(out_buf).unwrap(); + for o in &*out { + tlv::print_tlv_list(&o.data); + } + + let root = tlv::get_root_node_struct(&out[0].data).unwrap(); ReportDataMsg::from_tlv(&root).unwrap() } + pub fn write_reqs(input: &[AttrData], expected: &[AttrStatus]) { + let im = ImEngine::new_default(); + + im.add_default_acl(); + im.handle_write_reqs(&im.handler(), input, expected); + } + pub fn handle_write_reqs( - &mut self, - peer_node_id: u64, - peer_cat_ids: Option<&NocCatIds>, + &self, + handler: &ImEngineHandler, input: &[AttrData], expected: &[AttrStatus], ) { - let mut out_buf = [0u8; 400]; let write_req = WriteReq::new(false, input); - let mut input = ImInput::new(OpCode::WriteRequest, &write_req); - input.set_peer_node_id(peer_node_id); - if let Some(cat_ids) = peer_cat_ids { - input.set_cat_ids(cat_ids); + let input = ImInput::new(OpCode::WriteRequest, &write_req); + let mut out = heapless::Vec::<_, 1>::new(); + self.process(handler, &[&input], &mut out).unwrap(); + + for o in &out { + tlv::print_tlv_list(&o.data); } - let (_, out_buf) = self.process(&input, &mut out_buf); - - tlv::print_tlv_list(out_buf); - let root = tlv::get_root_node_struct(out_buf).unwrap(); + let root = tlv::get_root_node_struct(&out[0].data).unwrap(); let mut index = 0; let response_iter = root @@ -124,194 +116,184 @@ impl<'a> ImEngine<'a> { assert_eq!(index, expected.len()); } - pub fn new_with_write_reqs( - matter: &'a Matter<'a>, - input: &[AttrData], - expected: &[AttrStatus], - ) -> Self { - let mut im = Self::new(matter); + pub fn commands(input: &[CmdData], expected: &[ExpectedInvResp]) { + let im = ImEngine::new_default(); - im.handle_write_reqs(IM_ENGINE_PEER_ID, None, input, expected); - - im + im.add_default_acl(); + im.handle_commands(&im.handler(), input, expected) } // Helper for handling Invoke Command sequences pub fn handle_commands( - &mut self, - peer_node_id: u64, + &self, + handler: &ImEngineHandler, input: &[CmdData], expected: &[ExpectedInvResp], ) { - let mut out_buf = [0u8; 400]; let req = InvReq { suppress_response: Some(false), timed_request: Some(false), inv_requests: Some(TLVArray::Slice(input)), }; - let mut input = ImInput::new(OpCode::InvokeRequest, &req); - input.set_peer_node_id(peer_node_id); + let input = ImInput::new(OpCode::InvokeRequest, &req); - let (_, out_buf) = self.process(&input, &mut out_buf); - tlv::print_tlv_list(out_buf); - let root = tlv::get_root_node_struct(out_buf).unwrap(); + let mut out = heapless::Vec::<_, 1>::new(); + self.process(handler, &[&input], &mut out).unwrap(); + + for o in &out { + tlv::print_tlv_list(&o.data); + } + + let root = tlv::get_root_node_struct(&out[0].data).unwrap(); let resp = msg::InvResp::from_tlv(&root).unwrap(); assert_inv_response(&resp, expected) } - pub fn new_with_commands( - matter: &'a Matter<'a>, - input: &[CmdData], - expected: &[ExpectedInvResp], - ) -> Self { - let mut im = ImEngine::new(matter); - - im.handle_commands(IM_ENGINE_PEER_ID, input, expected); - - im - } - - fn handle_timed_reqs<'b>( - &mut self, + fn gen_timed_reqs_output( + &self, + handler: &ImEngineHandler, opcode: OpCode, request: &dyn ToTLV, timeout: u16, delay: u16, - output: &'b mut [u8], - ) -> (u8, &'b [u8]) { - // Use the same exchange for all parts of the transaction - self.exch = Some(Exchange::new(1, 0, exchange::Role::Responder)); + out: &mut heapless::Vec, + ) { + let mut inp = heapless::Vec::<_, 2>::new(); + + let timed_req = TimedReq { timeout }; + let im_input = ImInput::new_delayed(OpCode::TimedRequest, &timed_req, Some(delay)); if timeout != 0 { // Send Timed Req - let mut tmp_buf = [0u8; 400]; - let timed_req = TimedReq { timeout }; - let im_input = ImInput::new(OpCode::TimedRequest, &timed_req); - let (_, out_buf) = self.process(&im_input, &mut tmp_buf); - tlv::print_tlv_list(out_buf); + inp.push(&im_input).map_err(|_| ErrorCode::NoSpace).unwrap(); } else { warn!("Skipping timed request"); } - // Process any delays - let delay = time::Duration::from_millis(delay.into()); - thread::sleep(delay); - // Send Write Req let input = ImInput::new(opcode, request); - let (resp_opcode, output) = self.process(&input, output); - (resp_opcode, output) + inp.push(&input).map_err(|_| ErrorCode::NoSpace).unwrap(); + + self.process(handler, &inp, out).unwrap(); + + drop(inp); + + for o in out { + tlv::print_tlv_list(&o.data); + } } - // Helper for handling Write Attribute sequences - pub fn handle_timed_write_reqs( - &mut self, + pub fn timed_write_reqs( input: &[AttrData], expected: &WriteResponse, timeout: u16, delay: u16, ) { - let mut out_buf = [0u8; 400]; + let im = ImEngine::new_default(); + + im.add_default_acl(); + im.handle_timed_write_reqs(&im.handler(), input, expected, timeout, delay); + } + + // Helper for handling Write Attribute sequences + pub fn handle_timed_write_reqs( + &self, + handler: &ImEngineHandler, + input: &[AttrData], + expected: &WriteResponse, + timeout: u16, + delay: u16, + ) { + let mut out = heapless::Vec::<_, 2>::new(); let write_req = WriteReq::new(false, input); - let (resp_opcode, out_buf) = self.handle_timed_reqs( + self.gen_timed_reqs_output( + handler, OpCode::WriteRequest, &write_req, timeout, delay, - &mut out_buf, + &mut out, ); - tlv::print_tlv_list(out_buf); - let root = tlv::get_root_node_struct(out_buf).unwrap(); + + let out = &out[out.len() - 1]; + let root = tlv::get_root_node_struct(&out.data).unwrap(); match expected { WriteResponse::TransactionSuccess(t) => { - assert_eq!( - num::FromPrimitive::from_u8(resp_opcode), - Some(OpCode::WriteResponse) - ); + assert_eq!(out.action, OpCode::WriteResponse); let resp = WriteResp::from_tlv(&root).unwrap(); assert_eq!(resp.write_responses, t); } WriteResponse::TransactionError => { - assert_eq!( - num::FromPrimitive::from_u8(resp_opcode), - Some(OpCode::StatusResponse) - ); + assert_eq!(out.action, OpCode::StatusResponse); let status_resp = StatusResp::from_tlv(&root).unwrap(); assert_eq!(status_resp.status, IMStatusCode::Timeout); } } } - pub fn new_with_timed_write_reqs( - matter: &'a Matter<'a>, - input: &[AttrData], - expected: &WriteResponse, - timeout: u16, - delay: u16, - ) -> Self { - let mut im = ImEngine::new(matter); - - im.handle_timed_write_reqs(input, expected, timeout, delay); - - im - } - - // Helper for handling Invoke Command sequences - pub fn handle_timed_commands( - &mut self, + pub fn timed_commands( input: &[CmdData], expected: &TimedInvResponse, timeout: u16, delay: u16, set_timed_request: bool, ) { - let mut out_buf = [0u8; 400]; + let im = ImEngine::new_default(); + + im.add_default_acl(); + im.handle_timed_commands( + &im.handler(), + input, + expected, + timeout, + delay, + set_timed_request, + ); + } + + // Helper for handling Invoke Command sequences + pub fn handle_timed_commands( + &self, + handler: &ImEngineHandler, + input: &[CmdData], + expected: &TimedInvResponse, + timeout: u16, + delay: u16, + set_timed_request: bool, + ) { + let mut out = heapless::Vec::<_, 2>::new(); let req = InvReq { suppress_response: Some(false), timed_request: Some(set_timed_request), inv_requests: Some(TLVArray::Slice(input)), }; - let (resp_opcode, out_buf) = - self.handle_timed_reqs(OpCode::InvokeRequest, &req, timeout, delay, &mut out_buf); - tlv::print_tlv_list(out_buf); - let root = tlv::get_root_node_struct(out_buf).unwrap(); + self.gen_timed_reqs_output( + handler, + OpCode::InvokeRequest, + &req, + timeout, + delay, + &mut out, + ); + + let out = &out[out.len() - 1]; + let root = tlv::get_root_node_struct(&out.data).unwrap(); match expected { TimedInvResponse::TransactionSuccess(t) => { - assert_eq!( - num::FromPrimitive::from_u8(resp_opcode), - Some(OpCode::InvokeResponse) - ); + assert_eq!(out.action, OpCode::InvokeResponse); let resp = msg::InvResp::from_tlv(&root).unwrap(); assert_inv_response(&resp, t) } TimedInvResponse::TransactionError(e) => { - assert_eq!( - num::FromPrimitive::from_u8(resp_opcode), - Some(OpCode::StatusResponse) - ); + assert_eq!(out.action, OpCode::StatusResponse); let status_resp = StatusResp::from_tlv(&root).unwrap(); assert_eq!(status_resp.status, *e); } } } - - pub fn new_with_timed_commands( - matter: &'a Matter<'a>, - input: &[CmdData], - expected: &TimedInvResponse, - timeout: u16, - delay: u16, - set_timed_request: bool, - ) -> Self { - let mut im = ImEngine::new(matter); - - im.handle_timed_commands(input, expected, timeout, delay, set_timed_request); - - im - } } diff --git a/matter/tests/common/im_engine.rs b/matter/tests/common/im_engine.rs index 166b7fc..8efb2c9 100644 --- a/matter/tests/common/im_engine.rs +++ b/matter/tests/common/im_engine.rs @@ -17,14 +17,19 @@ use crate::common::echo_cluster; use core::borrow::Borrow; +use core::future::pending; +use core::time::Duration; +use embassy_futures::select::select3; use matter::{ acl::{AclEntry, AuthMode}, data_model::{ cluster_basic_information::{self, BasicInfoConfig}, cluster_on_off::{self, OnOffCluster}, - core::DataModel, device_types::{DEV_TYPE_ON_OFF_LIGHT, DEV_TYPE_ROOT_NODE}, - objects::{Endpoint, Node, Privilege}, + objects::{ + AttrData, AttrDataEncoder, AttrDetails, Endpoint, Handler, HandlerCompat, Metadata, + Node, NonBlockingHandler, Privilege, + }, root_endpoint::{self, RootEndpointHandler}, sdm::{ admin_commissioning, @@ -36,21 +41,24 @@ use matter::{ descriptor::{self, DescriptorCluster}, }, }, - error::Error, + error::{Error, ErrorCode}, handler_chain_type, - interaction_model::core::{InteractionModel, OpCode}, - mdns::Mdns, + interaction_model::core::{OpCode, PROTO_ID_INTERACTION_MODEL}, + mdns::DummyMdns, + secure_channel::{self, common::PROTO_ID_SECURE_CHANNEL, spake2p::VerifierData}, tlv::{TLVWriter, TagType, ToTLV}, - transport::packet::Packet, transport::{ - exchange::{self, Exchange, ExchangeCtx}, - network::{Address, IpAddr, Ipv4Addr, SocketAddr}, - packet::MAX_RX_BUF_SIZE, - proto_ctx::ProtoCtx, - session::{CaseDetails, CloneData, NocCatIds, SessionMgr, SessionMode}, + exchange::Notification, + packet::{Packet, MAX_RX_BUF_SIZE, MAX_TX_BUF_SIZE}, + pipe::Pipe, + runner::TransportRunner, }, - utils::{rand::dummy_rand, writebuf::WriteBuf}, - Matter, + transport::{ + network::Address, + session::{CaseDetails, CloneData, NocCatIds, SessionMode}, + }, + utils::select::EitherUnwrap, + CommissioningData, Matter, MATTER_PORT, }; use super::echo_cluster::EchoCluster; @@ -74,183 +82,321 @@ impl DevAttDataFetcher for DummyDevAtt { } pub const IM_ENGINE_PEER_ID: u64 = 445566; +pub const IM_ENGINE_REMOTE_PEER_ID: u64 = 123456; + +const NODE: Node<'static> = Node { + id: 0, + endpoints: &[ + Endpoint { + id: 0, + clusters: &[ + descriptor::CLUSTER, + cluster_basic_information::CLUSTER, + general_commissioning::CLUSTER, + nw_commissioning::CLUSTER, + admin_commissioning::CLUSTER, + noc::CLUSTER, + access_control::CLUSTER, + echo_cluster::CLUSTER, + ], + device_type: DEV_TYPE_ROOT_NODE, + }, + Endpoint { + id: 1, + clusters: &[ + descriptor::CLUSTER, + cluster_on_off::CLUSTER, + echo_cluster::CLUSTER, + ], + device_type: DEV_TYPE_ON_OFF_LIGHT, + }, + ], +}; pub struct ImInput<'a> { action: OpCode, data: &'a dyn ToTLV, - peer_id: u64, - cat_ids: NocCatIds, + delay: Option, } impl<'a> ImInput<'a> { pub fn new(action: OpCode, data: &'a dyn ToTLV) -> Self { + Self::new_delayed(action, data, None) + } + + pub fn new_delayed(action: OpCode, data: &'a dyn ToTLV, delay: Option) -> Self { Self { action, data, - peer_id: IM_ENGINE_PEER_ID, - cat_ids: Default::default(), + delay, } } - - pub fn set_peer_node_id(&mut self, peer: u64) { - self.peer_id = peer; - } - - pub fn set_cat_ids(&mut self, cat_ids: &NocCatIds) { - self.cat_ids = *cat_ids; - } } -pub type DmHandler<'a> = handler_chain_type!(OnOffCluster, EchoCluster, DescriptorCluster<'a>, EchoCluster | RootEndpointHandler<'a>); - -pub fn matter(mdns: &mut dyn Mdns) -> Matter<'_> { - #[cfg(feature = "std")] - use matter::utils::epoch::sys_epoch as epoch; - - #[cfg(not(feature = "std"))] - use matter::utils::epoch::dummy_epoch as epoch; - - Matter::new(&BASIC_INFO, &DummyDevAtt, mdns, epoch, dummy_rand, 5540) +pub struct ImOutput { + pub action: OpCode, + pub data: heapless::Vec, } -/// An Interaction Model Engine to facilitate easy testing -pub struct ImEngine<'a> { - pub matter: &'a Matter<'a>, - pub im: InteractionModel>>, - // By default, a new exchange is created for every run, if you wish to instead using a specific - // exchange, set this variable. This is helpful in situations where you have to run multiple - // actions in the same transaction (exchange) - pub exch: Option, +pub struct ImEngineHandler<'a> { + handler: handler_chain_type!(OnOffCluster, EchoCluster, DescriptorCluster<'static>, EchoCluster | RootEndpointHandler<'a>), } -impl<'a> ImEngine<'a> { - /// Create the interaction model engine +impl<'a> ImEngineHandler<'a> { pub fn new(matter: &'a Matter<'a>) -> Self { - let mut default_acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); - // Only allow the standard peer node id of the IM Engine - default_acl.add_subject(IM_ENGINE_PEER_ID).unwrap(); - matter.acl_mgr.borrow_mut().add(default_acl).unwrap(); + let handler = root_endpoint::handler(0, matter) + .chain(0, echo_cluster::ID, EchoCluster::new(2, *matter.borrow())) + .chain(1, descriptor::ID, DescriptorCluster::new(*matter.borrow())) + .chain(1, echo_cluster::ID, EchoCluster::new(3, *matter.borrow())) + .chain(1, cluster_on_off::ID, OnOffCluster::new(*matter.borrow())); - let dm = DataModel::new( - matter.borrow(), - &Node { - id: 0, - endpoints: &[ - Endpoint { - id: 0, - clusters: &[ - descriptor::CLUSTER, - cluster_basic_information::CLUSTER, - general_commissioning::CLUSTER, - nw_commissioning::CLUSTER, - admin_commissioning::CLUSTER, - noc::CLUSTER, - access_control::CLUSTER, - echo_cluster::CLUSTER, - ], - device_type: DEV_TYPE_ROOT_NODE, - }, - Endpoint { - id: 1, - clusters: &[ - descriptor::CLUSTER, - cluster_on_off::CLUSTER, - echo_cluster::CLUSTER, - ], - device_type: DEV_TYPE_ON_OFF_LIGHT, - }, - ], - }, - root_endpoint::handler(0, matter) - .chain(0, echo_cluster::ID, EchoCluster::new(2, *matter.borrow())) - .chain(1, descriptor::ID, DescriptorCluster::new(*matter.borrow())) - .chain(1, echo_cluster::ID, EchoCluster::new(3, *matter.borrow())) - .chain(1, cluster_on_off::ID, OnOffCluster::new(*matter.borrow())), - ); - - Self { - matter, - im: InteractionModel(dm), - exch: None, - } + Self { handler } } pub fn echo_cluster(&self, endpoint: u16) -> &EchoCluster { match endpoint { - 0 => &self.im.0.handler.next.next.next.handler, - 1 => &self.im.0.handler.next.handler, + 0 => &self.handler.next.next.next.handler, + 1 => &self.handler.next.handler, _ => panic!(), } } +} - /// Run a transaction through the interaction model engine - pub fn process<'b>(&mut self, input: &ImInput, data_out: &'b mut [u8]) -> (u8, &'b [u8]) { - let mut new_exch = Exchange::new(1, 0, exchange::Role::Responder); - // Choose whether to use a new exchange, or use the one from the ImEngine configuration - let exch = self.exch.as_mut().unwrap_or(&mut new_exch); +impl<'a> Handler for ImEngineHandler<'a> { + fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> { + self.handler.read(attr, encoder) + } - let mut sess_mgr = SessionMgr::new(*self.matter.borrow(), *self.matter.borrow()); + fn write(&self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { + self.handler.write(attr, data) + } - let clone_data = CloneData::new( - 123456, - input.peer_id, - 10, - 30, - Address::Udp(SocketAddr::new( - IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), - 5542, - )), - SessionMode::Case(CaseDetails::new(1, &input.cat_ids)), - ); - let sess_idx = sess_mgr.clone_session(&clone_data).unwrap(); - let sess = sess_mgr.get_session_handle(sess_idx); - let exch_ctx = ExchangeCtx { - exch, - sess, - epoch: *self.matter.borrow(), - }; - let mut rx_buf = [0; MAX_RX_BUF_SIZE]; - let mut tx_buf = [0; 1440]; // For the long read tests to run unchanged - let mut rx = Packet::new_rx(&mut rx_buf); - let mut tx = Packet::new_tx(&mut tx_buf); - // Create fake rx packet - rx.set_proto_id(0x01); - rx.set_proto_opcode(input.action as u8); - rx.peer = Address::default(); - - { - let mut buf = [0u8; 400]; - let mut wb = WriteBuf::new(&mut buf); - let mut tw = TLVWriter::new(&mut wb); - - input.data.to_tlv(&mut tw, TagType::Anonymous).unwrap(); - - let input_data = wb.as_slice(); - let in_data_len = input_data.len(); - let rx_buf = rx.as_mut_slice(); - rx_buf[..in_data_len].copy_from_slice(input_data); - rx.get_parsebuf().unwrap().set_len(in_data_len); - } - - let mut ctx = ProtoCtx::new(exch_ctx, &rx, &mut tx); - self.im.handle(&mut ctx).unwrap(); - let out_data_len = ctx.tx.as_slice().len(); - data_out[..out_data_len].copy_from_slice(ctx.tx.as_slice()); - let response = ctx.tx.get_proto_raw_opcode(); - (response, &data_out[..out_data_len]) + fn invoke( + &self, + exchange: &matter::transport::exchange::Exchange, + cmd: &matter::data_model::objects::CmdDetails, + data: &matter::tlv::TLVElement, + encoder: matter::data_model::objects::CmdDataEncoder, + ) -> Result<(), Error> { + self.handler.invoke(exchange, cmd, data, encoder) } } -// TODO - Remove? -// // Create an Interaction Model, Data Model and run a rx/tx transaction through it -// pub fn im_engine<'a>( -// matter: &'a Matter, -// action: OpCode, -// data: &dyn ToTLV, -// data_out: &'a mut [u8], -// ) -> (DmHandler<'a>, u8, &'a mut [u8]) { -// let mut engine = ImEngine::new(matter); -// let input = ImInput::new(action, data); -// let (response, output) = engine.process(&input, data_out); -// (engine.dm.handler, response, output) -// } +impl<'a> NonBlockingHandler for ImEngineHandler<'a> {} + +impl<'a> Metadata for ImEngineHandler<'a> { + type MetadataGuard<'g> = Node<'g> where Self: 'g; + + fn lock(&self) -> Self::MetadataGuard<'_> { + NODE + } +} + +static mut DNS: DummyMdns = DummyMdns; + +/// An Interaction Model Engine to facilitate easy testing +pub struct ImEngine<'a> { + pub matter: Matter<'a>, + cat_ids: NocCatIds, +} + +impl<'a> ImEngine<'a> { + pub fn new_default() -> Self { + Self::new(Default::default()) + } + + /// Create the interaction model engine + pub fn new(cat_ids: NocCatIds) -> Self { + #[cfg(feature = "std")] + use matter::utils::epoch::sys_epoch as epoch; + + #[cfg(not(feature = "std"))] + use matter::utils::epoch::dummy_epoch as epoch; + + #[cfg(feature = "std")] + use matter::utils::rand::sys_rand as rand; + + #[cfg(not(feature = "std"))] + use matter::utils::rand::dummy_rand as rand; + + let matter = Matter::new( + &BASIC_INFO, + &DummyDevAtt, + unsafe { &mut DNS }, + epoch, + rand, + MATTER_PORT, + ); + + Self { matter, cat_ids } + } + + pub fn add_default_acl(&self) { + // Only allow the standard peer node id of the IM Engine + let mut default_acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); + default_acl.add_subject(IM_ENGINE_PEER_ID).unwrap(); + self.matter.acl_mgr.borrow_mut().add(default_acl).unwrap(); + } + + pub fn handler(&self) -> ImEngineHandler<'_> { + ImEngineHandler::new(&self.matter) + } + + pub fn process( + &self, + handler: &ImEngineHandler, + input: &[&ImInput], + out: &mut heapless::Vec, + ) -> Result<(), Error> { + let mut runner = TransportRunner::new(&self.matter); + + let clone_data = CloneData::new( + IM_ENGINE_REMOTE_PEER_ID, + IM_ENGINE_PEER_ID, + 1, + 1, + Address::default(), + SessionMode::Case(CaseDetails::new(1, &self.cat_ids)), + ); + + let sess_idx = runner + .transport() + .session_mgr + .borrow_mut() + .clone_session(&clone_data) + .unwrap(); + + let mut tx_pipe_buf = [0; MAX_RX_BUF_SIZE]; + let mut rx_pipe_buf = [0; MAX_TX_BUF_SIZE]; + + let mut tx_buf = [0; MAX_RX_BUF_SIZE]; + let mut rx_buf = [0; MAX_TX_BUF_SIZE]; + + let tx_pipe = Pipe::new(&mut tx_buf); + let rx_pipe = Pipe::new(&mut rx_buf); + + let tx_pipe = &tx_pipe; + let rx_pipe = &rx_pipe; + let tx_pipe_buf = &mut tx_pipe_buf; + let rx_pipe_buf = &mut rx_pipe_buf; + + let handler = &handler; + let runner = &mut runner; + + let mut msg_ctr = runner + .transport() + .session_mgr + .borrow_mut() + .mut_by_index(sess_idx) + .unwrap() + .get_msg_ctr(); + + let resp_notif = Notification::new(); + let resp_notif = &resp_notif; + + embassy_futures::block_on(async move { + select3( + runner.run( + tx_pipe, + rx_pipe, + CommissioningData { + // TODO: Hard-coded for now + verifier: VerifierData::new_with_pw(123456, *self.matter.borrow()), + discriminator: 250, + }, + &HandlerCompat(handler), + ), + async move { + let mut acknowledge = false; + for ip in input { + Self::send(ip, tx_pipe_buf, rx_pipe, msg_ctr, acknowledge).await?; + resp_notif.wait().await; + + if let Some(delay) = ip.delay { + if delay > 0 { + #[cfg(feature = "std")] + std::thread::sleep(Duration::from_millis(delay as _)); + } + } + + msg_ctr += 2; + acknowledge = true; + } + + pending::<()>().await; + + Ok(()) + }, + async move { + out.clear(); + + while out.len() < input.len() { + let (len, _) = tx_pipe.recv(rx_pipe_buf).await; + + let mut rx = Packet::new_rx(&mut rx_pipe_buf[..len]); + + rx.plain_hdr_decode()?; + rx.proto_decode(IM_ENGINE_REMOTE_PEER_ID, Some(&[0u8; 16]))?; + + if rx.get_proto_id() != PROTO_ID_SECURE_CHANNEL + || rx.get_proto_opcode::()? + != secure_channel::common::OpCode::MRPStandAloneAck + { + out.push(ImOutput { + action: rx.get_proto_opcode()?, + data: heapless::Vec::from_slice(rx.as_slice()) + .map_err(|_| ErrorCode::NoSpace)?, + }) + .map_err(|_| ErrorCode::NoSpace)?; + + resp_notif.signal(()); + } + } + + Ok(()) + }, + ) + .await + .unwrap() + })?; + + Ok(()) + } + + async fn send( + input: &ImInput<'_>, + tx_buf: &mut [u8], + rx_pipe: &Pipe<'_>, + msg_ctr: u32, + acknowledge: bool, + ) -> Result<(), Error> { + let mut tx = Packet::new_tx(tx_buf); + + tx.set_proto_id(PROTO_ID_INTERACTION_MODEL); + tx.set_proto_opcode(input.action as u8); + + let mut tw = TLVWriter::new(tx.get_writebuf()?); + + input.data.to_tlv(&mut tw, TagType::Anonymous)?; + + tx.plain.ctr = msg_ctr + 1; + tx.plain.sess_id = 1; + tx.proto.set_initiator(); + + if acknowledge { + tx.proto.set_ack(msg_ctr - 1); + } + + tx.proto_encode( + Address::default(), + Some(IM_ENGINE_REMOTE_PEER_ID), + IM_ENGINE_PEER_ID, + false, + Some(&[0u8; 16]), + )?; + + rx_pipe.send(Address::default(), tx.as_slice()).await; + + Ok(()) + } +} diff --git a/matter/tests/data_model/acl_and_dataver.rs b/matter/tests/data_model/acl_and_dataver.rs index ebb831c..853e2ca 100644 --- a/matter/tests/data_model/acl_and_dataver.rs +++ b/matter/tests/data_model/acl_and_dataver.rs @@ -26,7 +26,6 @@ use matter::{ messages::ib::{AttrData, AttrPath, AttrResp, AttrStatus, ClusterPath, DataVersionFilter}, messages::GenericPath, }, - mdns::DummyMdns, tlv::{ElementType, TLVArray, TLVElement, TLVWriter, TagType}, }; @@ -35,7 +34,7 @@ use crate::{ common::{ attributes::*, echo_cluster::{self, ATTR_WRITE_DEFAULT_VALUE}, - im_engine::{matter, ImEngine}, + im_engine::{ImEngine, IM_ENGINE_PEER_ID}, init_env_logger, }, }; @@ -62,30 +61,28 @@ fn wc_read_attribute() { Some(echo_cluster::AttributesDiscriminants::Att1 as u32), ); - let peer = 98765; - let mut mdns = DummyMdns {}; - let matter = matter(&mut mdns); - let mut im = ImEngine::new(&matter); + let im = ImEngine::new_default(); + let handler = im.handler(); // Test1: Empty Response as no ACL matches let input = &[AttrPath::new(&wc_att1)]; let expected = &[]; - im.handle_read_reqs(peer, input, expected); + im.handle_read_reqs(&handler, input, expected); // Add ACL to allow our peer to only access endpoint 0 let mut acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); - acl.add_subject(peer).unwrap(); + acl.add_subject(IM_ENGINE_PEER_ID).unwrap(); acl.add_target(Target::new(Some(0), None, None)).unwrap(); im.matter.acl_mgr.borrow_mut().add(acl).unwrap(); // Test2: Only Single response as only single endpoint is allowed let input = &[AttrPath::new(&wc_att1)]; let expected = &[attr_data_path!(ep0_att1, ElementType::U16(0x1234))]; - im.handle_read_reqs(peer, input, expected); + im.handle_read_reqs(&handler, input, expected); // Add ACL to allow our peer to also access endpoint 1 let mut acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); - acl.add_subject(peer).unwrap(); + acl.add_subject(IM_ENGINE_PEER_ID).unwrap(); acl.add_target(Target::new(Some(1), None, None)).unwrap(); im.matter.acl_mgr.borrow_mut().add(acl).unwrap(); @@ -95,7 +92,7 @@ fn wc_read_attribute() { attr_data_path!(ep0_att1, ElementType::U16(0x1234)), attr_data_path!(ep1_att1, ElementType::U16(0x1234)), ]; - im.handle_read_reqs(peer, input, expected); + im.handle_read_reqs(&handler, input, expected); } #[test] @@ -115,25 +112,23 @@ fn exact_read_attribute() { Some(echo_cluster::AttributesDiscriminants::Att1 as u32), ); - let peer = 98765; - let mut mdns = DummyMdns {}; - let matter = matter(&mut mdns); - let mut im = ImEngine::new(&matter); + let im = ImEngine::new_default(); + let handler = im.handler(); // Test1: Unsupported Access error as no ACL matches let input = &[AttrPath::new(&wc_att1)]; let expected = &[attr_status!(&ep0_att1, IMStatusCode::UnsupportedAccess)]; - im.handle_read_reqs(peer, input, expected); + im.handle_read_reqs(&handler, input, expected); // Add ACL to allow our peer to access any endpoint let mut acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); - acl.add_subject(peer).unwrap(); + acl.add_subject(IM_ENGINE_PEER_ID).unwrap(); im.matter.acl_mgr.borrow_mut().add(acl).unwrap(); // Test2: Only Single response as only single endpoint is allowed let input = &[AttrPath::new(&wc_att1)]; let expected = &[attr_data_path!(ep0_att1, ElementType::U16(0x1234))]; - im.handle_read_reqs(peer, input, expected); + im.handle_read_reqs(&handler, input, expected); } #[test] @@ -177,52 +172,54 @@ fn wc_write_attribute() { EncodeValue::Closure(&attr_data1), )]; - let peer = 98765; - let mut mdns = DummyMdns {}; - let matter = matter(&mut mdns); - let mut im = ImEngine::new(&matter); + let im = ImEngine::new_default(); + let handler = im.handler(); // Test 1: Wildcard write to an attribute without permission should return // no error - im.handle_write_reqs(peer, None, input0, &[]); - assert_eq!(ATTR_WRITE_DEFAULT_VALUE, im.echo_cluster(0).att_write); + im.handle_write_reqs(&handler, input0, &[]); + assert_eq!( + ATTR_WRITE_DEFAULT_VALUE, + handler.echo_cluster(0).att_write.get() + ); // Add ACL to allow our peer to access one endpoint let mut acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); - acl.add_subject(peer).unwrap(); + acl.add_subject(IM_ENGINE_PEER_ID).unwrap(); acl.add_target(Target::new(Some(0), None, None)).unwrap(); im.matter.acl_mgr.borrow_mut().add(acl).unwrap(); // Test 2: Wildcard write to attributes will only return attributes // where the writes were successful im.handle_write_reqs( - peer, - None, + &handler, input0, &[AttrStatus::new(&ep0_att, IMStatusCode::Success, 0)], ); - assert_eq!(val0, im.echo_cluster(0).att_write); - assert_eq!(ATTR_WRITE_DEFAULT_VALUE, im.echo_cluster(1).att_write); + assert_eq!(val0, handler.echo_cluster(0).att_write.get()); + assert_eq!( + ATTR_WRITE_DEFAULT_VALUE, + handler.echo_cluster(1).att_write.get() + ); // Add ACL to allow our peer to access another endpoint let mut acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); - acl.add_subject(peer).unwrap(); + acl.add_subject(IM_ENGINE_PEER_ID).unwrap(); acl.add_target(Target::new(Some(1), None, None)).unwrap(); im.matter.acl_mgr.borrow_mut().add(acl).unwrap(); // Test 3: Wildcard write to attributes will return multiple attributes // where the writes were successful im.handle_write_reqs( - peer, - None, + &handler, input1, &[ AttrStatus::new(&ep0_att, IMStatusCode::Success, 0), AttrStatus::new(&ep1_att, IMStatusCode::Success, 0), ], ); - assert_eq!(val1, im.echo_cluster(0).att_write); - assert_eq!(val1, im.echo_cluster(1).att_write); + assert_eq!(val1, handler.echo_cluster(0).att_write.get()); + assert_eq!(val1, handler.echo_cluster(1).att_write.get()); } #[test] @@ -253,25 +250,26 @@ fn exact_write_attribute() { )]; let expected_success = &[AttrStatus::new(&ep0_att, IMStatusCode::Success, 0)]; - let peer = 98765; - let mut mdns = DummyMdns {}; - let matter = matter(&mut mdns); - let mut im = ImEngine::new(&matter); + let im = ImEngine::new_default(); + let handler = im.handler(); // Test 1: Exact write to an attribute without permission should return // Unsupported Access Error - im.handle_write_reqs(peer, None, input, expected_fail); - assert_eq!(ATTR_WRITE_DEFAULT_VALUE, im.echo_cluster(0).att_write); + im.handle_write_reqs(&handler, input, expected_fail); + assert_eq!( + ATTR_WRITE_DEFAULT_VALUE, + handler.echo_cluster(0).att_write.get() + ); // Add ACL to allow our peer to access any endpoint let mut acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); - acl.add_subject(peer).unwrap(); + acl.add_subject(IM_ENGINE_PEER_ID).unwrap(); im.matter.acl_mgr.borrow_mut().add(acl).unwrap(); // Test 1: Exact write to an attribute with permission should grant // access - im.handle_write_reqs(peer, None, input, expected_success); - assert_eq!(val0, im.echo_cluster(0).att_write); + im.handle_write_reqs(&handler, input, expected_success); + assert_eq!(val0, handler.echo_cluster(0).att_write.get()); } #[test] @@ -303,19 +301,20 @@ fn exact_write_attribute_noc_cat() { )]; let expected_success = &[AttrStatus::new(&ep0_att, IMStatusCode::Success, 0)]; - let peer = 98765; /* CAT in NOC is 1 more, in version, than that in ACL */ let noc_cat = gen_noc_cat(0xABCD, 2); let cat_in_acl = gen_noc_cat(0xABCD, 1); let cat_ids = [noc_cat, 0, 0]; - let mut mdns = DummyMdns; - let matter = matter(&mut mdns); - let mut im = ImEngine::new(&matter); + let im = ImEngine::new(cat_ids); + let handler = im.handler(); // Test 1: Exact write to an attribute without permission should return // Unsupported Access Error - im.handle_write_reqs(peer, Some(&cat_ids), input, expected_fail); - assert_eq!(ATTR_WRITE_DEFAULT_VALUE, im.echo_cluster(0).att_write); + im.handle_write_reqs(&handler, input, expected_fail); + assert_eq!( + ATTR_WRITE_DEFAULT_VALUE, + handler.echo_cluster(0).att_write.get() + ); // Add ACL to allow our peer to access any endpoint let mut acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); @@ -324,8 +323,8 @@ fn exact_write_attribute_noc_cat() { // Test 1: Exact write to an attribute with permission should grant // access - im.handle_write_reqs(peer, Some(&cat_ids), input, expected_success); - assert_eq!(val0, im.echo_cluster(0).att_write); + im.handle_write_reqs(&handler, input, expected_success); + assert_eq!(val0, handler.echo_cluster(0).att_write.get()); } #[test] @@ -347,21 +346,18 @@ fn insufficient_perms_write() { EncodeValue::Closure(&attr_data0), )]; - let peer = 98765; - let mut mdns = DummyMdns {}; - let matter = matter(&mut mdns); - let mut im = ImEngine::new(&matter); + let im = ImEngine::new_default(); + let handler = im.handler(); // Add ACL to allow our peer with only OPERATE permission let mut acl = AclEntry::new(1, Privilege::OPERATE, AuthMode::Case); - acl.add_subject(peer).unwrap(); + acl.add_subject(IM_ENGINE_PEER_ID).unwrap(); acl.add_target(Target::new(Some(0), None, None)).unwrap(); im.matter.acl_mgr.borrow_mut().add(acl).unwrap(); // Test: Not enough permission should return error im.handle_write_reqs( - peer, - None, + &handler, input0, &[AttrStatus::new( &ep0_att, @@ -369,7 +365,10 @@ fn insufficient_perms_write() { 0, )], ); - assert_eq!(ATTR_WRITE_DEFAULT_VALUE, im.echo_cluster(0).att_write); + assert_eq!( + ATTR_WRITE_DEFAULT_VALUE, + handler.echo_cluster(0).att_write.get() + ); } #[test] @@ -381,10 +380,9 @@ fn insufficient_perms_write() { /// - Write Attr to Echo Cluster again (successful this time) fn write_with_runtime_acl_add() { init_env_logger(); - let peer = 98765; - let mut mdns = DummyMdns {}; - let matter = matter(&mut mdns); - let mut im = ImEngine::new(&matter); + + let im = ImEngine::new_default(); + let handler = im.handler(); let val0 = 10; let attr_data0 = |tag, t: &mut TLVWriter| { @@ -403,7 +401,7 @@ fn write_with_runtime_acl_add() { // Create ACL to allow our peer ADMIN on everything let mut allow_acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); - allow_acl.add_subject(peer).unwrap(); + allow_acl.add_subject(IM_ENGINE_PEER_ID).unwrap(); let acl_att = GenericPath::new( Some(0), @@ -418,7 +416,7 @@ fn write_with_runtime_acl_add() { // Create ACL that only allows write to the ACL Cluster let mut basic_acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); - basic_acl.add_subject(peer).unwrap(); + basic_acl.add_subject(IM_ENGINE_PEER_ID).unwrap(); basic_acl .add_target(Target::new(Some(0), Some(access_control::ID), None)) .unwrap(); @@ -426,8 +424,7 @@ fn write_with_runtime_acl_add() { // Test: deny write (with error), then ACL is added, then allow write im.handle_write_reqs( - peer, - None, + &handler, // write to echo-cluster attribute, write to acl attribute, write to echo-cluster attribute &[input0.clone(), acl_input, input0], &[ @@ -436,7 +433,7 @@ fn write_with_runtime_acl_add() { AttrStatus::new(&ep0_att, IMStatusCode::Success, 0), ], ); - assert_eq!(val0, im.echo_cluster(0).att_write); + assert_eq!(val0, handler.echo_cluster(0).att_write.get()); } #[test] @@ -448,10 +445,9 @@ fn test_read_data_ver() { // - wildcard endpoint, att1 // - 2 responses are expected init_env_logger(); - let peer = 98765; - let mut mdns = DummyMdns {}; - let matter = matter(&mut mdns); - let mut im = ImEngine::new(&matter); + + let im = ImEngine::new_default(); + let handler = im.handler(); // Add ACL to allow our peer with only OPERATE permission let acl = AclEntry::new(1, Privilege::OPERATE, AuthMode::Case); @@ -482,10 +478,11 @@ fn test_read_data_ver() { ElementType::U16(0x1234) ), ]; - let mut out_buf = [0u8; 400]; + + let mut out = heapless::Vec::new(); // Test 1: Simple read to retrieve the current Data Version of Cluster at Endpoint 0 - let received = im.gen_read_reqs_output(peer, input, None, &mut out_buf); + let received = im.gen_read_reqs_output::<1>(&handler, input, None, &mut out); assert_attr_report(&received, expected); let data_ver_cluster_at_0 = received @@ -507,11 +504,12 @@ fn test_read_data_ver() { }]; // Test 2: Add Dataversion filter for cluster at endpoint 0 only single entry should be retrieved - let received = im.gen_read_reqs_output( - peer, + let mut out = heapless::Vec::new(); + let received = im.gen_read_reqs_output::<1>( + &handler, input, Some(TLVArray::Slice(&dataver_filter)), - &mut out_buf, + &mut out, ); let expected_only_one = &[attr_data_path!( GenericPath::new( @@ -532,10 +530,10 @@ fn test_read_data_ver() { ); let input = &[AttrPath::new(&ep0_att1)]; let received = im.gen_read_reqs_output( - peer, + &handler, input, Some(TLVArray::Slice(&dataver_filter)), - &mut out_buf, + &mut out, ); let expected_error = &[]; @@ -551,10 +549,9 @@ fn test_write_data_ver() { // - wildcard endpoint, att1 // - 2 responses are expected init_env_logger(); - let peer = 98765; - let mut mdns = DummyMdns {}; - let matter = matter(&mut mdns); - let mut im = ImEngine::new(&matter); + + let im = ImEngine::new_default(); + let handler = im.handler(); // Add ACL to allow our peer with only OPERATE permission let acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); @@ -576,7 +573,7 @@ fn test_write_data_ver() { let attr_data0 = EncodeValue::Value(&val0); let attr_data1 = EncodeValue::Value(&val1); - let initial_data_ver = im.echo_cluster(0).data_ver.get(); + let initial_data_ver = handler.echo_cluster(0).data_ver.get(); // Test 1: Write with correct dataversion should succeed let input_correct_dataver = &[AttrData::new( @@ -585,12 +582,11 @@ fn test_write_data_ver() { attr_data0, )]; im.handle_write_reqs( - peer, - None, + &handler, input_correct_dataver, &[AttrStatus::new(&ep0_attwrite, IMStatusCode::Success, 0)], ); - assert_eq!(val0, im.echo_cluster(0).att_write); + assert_eq!(val0, handler.echo_cluster(0).att_write.get()); // Test 2: Write with incorrect dataversion should fail // Now the data version would have incremented due to the previous write @@ -600,8 +596,7 @@ fn test_write_data_ver() { attr_data1.clone(), )]; im.handle_write_reqs( - peer, - None, + &handler, input_correct_dataver, &[AttrStatus::new( &ep0_attwrite, @@ -609,12 +604,12 @@ fn test_write_data_ver() { 0, )], ); - assert_eq!(val0, im.echo_cluster(0).att_write); + assert_eq!(val0, handler.echo_cluster(0).att_write.get()); // Test 3: Wildcard write with incorrect dataversion should ignore that cluster // In this case, while the data version is correct for endpoint 0, the endpoint 1's // data version would not match - let new_data_ver = im.echo_cluster(0).data_ver.get(); + let new_data_ver = handler.echo_cluster(0).data_ver.get(); let input_correct_dataver = &[AttrData::new( Some(new_data_ver), @@ -622,12 +617,11 @@ fn test_write_data_ver() { attr_data1, )]; im.handle_write_reqs( - peer, - None, + &handler, input_correct_dataver, &[AttrStatus::new(&ep0_attwrite, IMStatusCode::Success, 0)], ); - assert_eq!(val1, im.echo_cluster(0).att_write); + assert_eq!(val1, handler.echo_cluster(0).att_write.get()); assert_eq!(initial_data_ver + 1, new_data_ver); } diff --git a/matter/tests/data_model/attribute_lists.rs b/matter/tests/data_model/attribute_lists.rs index 636c9c0..12d4a5d 100644 --- a/matter/tests/data_model/attribute_lists.rs +++ b/matter/tests/data_model/attribute_lists.rs @@ -22,13 +22,12 @@ use matter::{ messages::ib::{AttrData, AttrPath, AttrStatus}, messages::GenericPath, }, - mdns::DummyMdns, tlv::Nullable, }; use crate::common::{ echo_cluster::{self, TestChecker}, - im_engine::{matter, ImEngine}, + im_engine::ImEngine, init_env_logger, }; @@ -65,8 +64,8 @@ fn attr_list_ops() { EncodeValue::Value(&val0), )]; let expected = &[AttrStatus::new(&att_data, IMStatusCode::Success, 0)]; - ImEngine::new_with_write_reqs(&matter(&mut DummyMdns), input, expected); + ImEngine::write_reqs(input, expected); { let tc = tc_handle.lock().unwrap(); assert_eq!([Some(val0), None, None, None, None], tc.write_list); @@ -79,8 +78,8 @@ fn attr_list_ops() { EncodeValue::Value(&val1), )]; let expected = &[AttrStatus::new(&att_data, IMStatusCode::Success, 0)]; - ImEngine::new_with_write_reqs(&matter(&mut DummyMdns), input, expected); + ImEngine::write_reqs(input, expected); { let tc = tc_handle.lock().unwrap(); assert_eq!([Some(val0), Some(val1), None, None, None], tc.write_list); @@ -94,8 +93,8 @@ fn attr_list_ops() { EncodeValue::Value(&val0), )]; let expected = &[AttrStatus::new(&att_data, IMStatusCode::Success, 0)]; - ImEngine::new_with_write_reqs(&matter(&mut DummyMdns), input, expected); + ImEngine::write_reqs(input, expected); { let tc = tc_handle.lock().unwrap(); assert_eq!([Some(val0), Some(val0), None, None, None], tc.write_list); @@ -105,8 +104,8 @@ fn attr_list_ops() { att_path.list_index = Some(Nullable::NotNull(0)); let input = &[AttrData::new(None, att_path.clone(), delete_item)]; let expected = &[AttrStatus::new(&att_data, IMStatusCode::Success, 0)]; - ImEngine::new_with_write_reqs(&matter(&mut DummyMdns), input, expected); + ImEngine::write_reqs(input, expected); { let tc = tc_handle.lock().unwrap(); assert_eq!([None, Some(val0), None, None, None], tc.write_list); @@ -121,8 +120,8 @@ fn attr_list_ops() { EncodeValue::Value(&overwrite_val), )]; let expected = &[AttrStatus::new(&att_data, IMStatusCode::Success, 0)]; - ImEngine::new_with_write_reqs(&matter(&mut DummyMdns), input, expected); + ImEngine::write_reqs(input, expected); { let tc = tc_handle.lock().unwrap(); assert_eq!([Some(20), Some(21), None, None, None], tc.write_list); @@ -132,8 +131,8 @@ fn attr_list_ops() { att_path.list_index = None; let input = &[AttrData::new(None, att_path, delete_all)]; let expected = &[AttrStatus::new(&att_data, IMStatusCode::Success, 0)]; - ImEngine::new_with_write_reqs(&matter(&mut DummyMdns), input, expected); + ImEngine::write_reqs(input, expected); { let tc = tc_handle.lock().unwrap(); assert_eq!([None, None, None, None, None], tc.write_list); diff --git a/matter/tests/data_model/attributes.rs b/matter/tests/data_model/attributes.rs index 7d18526..87bd96d 100644 --- a/matter/tests/data_model/attributes.rs +++ b/matter/tests/data_model/attributes.rs @@ -25,18 +25,12 @@ use matter::{ messages::ib::{AttrData, AttrPath, AttrResp, AttrStatus}, messages::GenericPath, }, - mdns::DummyMdns, tlv::{ElementType, TLVElement, TLVWriter, TagType}, }; use crate::{ attr_data, attr_data_path, attr_status, - common::{ - attributes::*, - echo_cluster, - im_engine::{matter, ImEngine}, - init_env_logger, - }, + common::{attributes::*, echo_cluster, im_engine::ImEngine, init_env_logger}, }; #[test] @@ -75,7 +69,7 @@ fn test_read_success() { ElementType::U32(echo_cluster::ATTR_CUSTOM_VALUE) ), ]; - ImEngine::new_with_read_reqs(&matter(&mut DummyMdns), input, expected); + ImEngine::read_reqs(input, expected); } #[test] @@ -122,7 +116,7 @@ fn test_read_unsupported_fields() { attr_status!(&invalid_cluster, IMStatusCode::UnsupportedCluster), attr_status!(&invalid_attribute, IMStatusCode::UnsupportedAttribute), ]; - ImEngine::new_with_read_reqs(&matter(&mut DummyMdns), input, expected); + ImEngine::read_reqs(input, expected); } #[test] @@ -153,7 +147,7 @@ fn test_read_wc_endpoint_all_have_clusters() { ElementType::U16(0x1234) ), ]; - ImEngine::new_with_read_reqs(&matter(&mut DummyMdns), input, expected); + ImEngine::read_reqs(input, expected); } #[test] @@ -178,7 +172,7 @@ fn test_read_wc_endpoint_only_1_has_cluster() { ), ElementType::False )]; - ImEngine::new_with_read_reqs(&matter(&mut DummyMdns), input, expected); + ImEngine::read_reqs(input, expected); } #[test] @@ -285,7 +279,7 @@ fn test_read_wc_endpoint_wc_attribute() { ElementType::U32(echo_cluster::ATTR_CUSTOM_VALUE) ), ]; - ImEngine::new_with_read_reqs(&matter(&mut DummyMdns), input, expected); + ImEngine::read_reqs(input, expected); } #[test] @@ -331,11 +325,14 @@ fn test_write_success() { AttrStatus::new(&ep1_att, IMStatusCode::Success, 0), ]; - let mut mdns = DummyMdns; - let matter = matter(&mut mdns); - let im = ImEngine::new_with_write_reqs(&matter, input, expected); - assert_eq!(val0, im.echo_cluster(0).att_write); - assert_eq!(val1, im.echo_cluster(1).att_write); + let im = ImEngine::new_default(); + let handler = im.handler(); + + im.add_default_acl(); + im.handle_write_reqs(&handler, input, expected); + + assert_eq!(val0, handler.echo_cluster(0).att_write.get()); + assert_eq!(val1, handler.echo_cluster(1).att_write.get()); } #[test] @@ -375,10 +372,13 @@ fn test_write_wc_endpoint() { AttrStatus::new(&ep1_att, IMStatusCode::Success, 0), ]; - let mut mdns = DummyMdns; - let matter = matter(&mut mdns); - let im = ImEngine::new_with_write_reqs(&matter, input, expected); - assert_eq!(val0, im.echo_cluster(0).att_write); + let im = ImEngine::new_default(); + let handler = im.handler(); + + im.add_default_acl(); + im.handle_write_reqs(&handler, input, expected); + + assert_eq!(val0, handler.echo_cluster(0).att_write.get()); } #[test] @@ -467,11 +467,14 @@ fn test_write_unsupported_fields() { AttrStatus::new(&wc_cluster, IMStatusCode::UnsupportedCluster, 0), AttrStatus::new(&wc_attribute, IMStatusCode::UnsupportedAttribute, 0), ]; - let mut mdns = DummyMdns; - let matter = matter(&mut mdns); - let im = ImEngine::new_with_write_reqs(&matter, input, expected); + let im = ImEngine::new_default(); + let handler = im.handler(); + + im.add_default_acl(); + im.handle_write_reqs(&handler, input, expected); + assert_eq!( echo_cluster::ATTR_WRITE_DEFAULT_VALUE, - im.echo_cluster(0).att_write + handler.echo_cluster(0).att_write.get() ); } diff --git a/matter/tests/data_model/commands.rs b/matter/tests/data_model/commands.rs index 0d9c0c3..ee91771 100644 --- a/matter/tests/data_model/commands.rs +++ b/matter/tests/data_model/commands.rs @@ -17,12 +17,7 @@ use crate::{ cmd_data, - common::{ - commands::*, - echo_cluster, - im_engine::{matter, ImEngine}, - init_env_logger, - }, + common::{commands::*, echo_cluster, im_engine::ImEngine, init_env_logger}, echo_req, echo_resp, }; @@ -32,7 +27,6 @@ use matter::{ core::IMStatusCode, messages::ib::{CmdData, CmdPath, CmdStatus}, }, - mdns::DummyMdns, }; #[test] @@ -44,7 +38,7 @@ fn test_invoke_cmds_success() { let input = &[echo_req!(0, 5), echo_req!(1, 10)]; let expected = &[echo_resp!(0, 10), echo_resp!(1, 30)]; - ImEngine::new_with_commands(&matter(&mut DummyMdns), input, expected); + ImEngine::commands(input, expected); } #[test] @@ -99,7 +93,7 @@ fn test_invoke_cmds_unsupported_fields() { 0, )), ]; - ImEngine::new_with_commands(&matter(&mut DummyMdns), input, expected); + ImEngine::commands(input, expected); } #[test] @@ -114,7 +108,7 @@ fn test_invoke_cmd_wc_endpoint_all_have_clusters() { ); let input = &[cmd_data!(path, 5)]; let expected = &[echo_resp!(0, 10), echo_resp!(1, 15)]; - ImEngine::new_with_commands(&matter(&mut DummyMdns), input, expected); + ImEngine::commands(input, expected); } #[test] @@ -139,5 +133,5 @@ fn test_invoke_cmd_wc_endpoint_only_1_has_cluster() { IMStatusCode::Success, 0, ))]; - ImEngine::new_with_commands(&matter(&mut DummyMdns), input, expected); + ImEngine::commands(input, expected); } diff --git a/matter/tests/data_model/long_reads.rs b/matter/tests/data_model/long_reads.rs index 21c2559..e8382e0 100644 --- a/matter/tests/data_model/long_reads.rs +++ b/matter/tests/data_model/long_reads.rs @@ -30,13 +30,7 @@ use matter::{ }, messages::{msg::SubscribeReq, GenericPath}, }, - mdns::DummyMdns, - tlv::{self, ElementType, FromTLV, TLVElement, TagType, ToTLV}, - transport::{ - exchange::{self, Exchange}, - packet::MAX_RX_BUF_SIZE, - }, - Matter, + tlv::{self, ElementType, FromTLV, TLVElement, TagType}, }; use crate::{ @@ -44,35 +38,11 @@ use crate::{ common::{ attributes::*, echo_cluster as echo, - im_engine::{matter, ImEngine, ImInput}, + im_engine::{ImEngine, ImInput}, init_env_logger, }, }; -pub struct LongRead<'a> { - im_engine: ImEngine<'a>, -} - -impl<'a> LongRead<'a> { - pub fn new(matter: &'a Matter<'a>) -> Self { - let mut im_engine = ImEngine::new(matter); - // Use the same exchange for all parts of the transaction - im_engine.exch = Some(Exchange::new(1, 0, exchange::Role::Responder)); - Self { im_engine } - } - - pub fn process<'p>( - &mut self, - action: OpCode, - data: &dyn ToTLV, - data_out: &'p mut [u8], - ) -> (u8, &'p [u8]) { - let input = ImInput::new(action, data); - let (response, output) = self.im_engine.process(&input, data_out); - (response, output) - } -} - fn wildcard_read_resp(part: u8) -> Vec> { // For brevity, we only check the AttrPath, not the actual 'data' let dont_care = ElementType::U8(0); @@ -215,6 +185,9 @@ fn wildcard_read_resp(part: u8) -> Vec> { acl::AttributesDiscriminants::Extension, dont_care.clone() ), + ]; + + let part2 = vec![ attr_data!( 0, 31, @@ -266,9 +239,6 @@ fn wildcard_read_resp(part: u8) -> Vec> { descriptor::Attributes::DeviceTypeList, dont_care.clone() ), - ]; - - let part2 = vec![ attr_data!(1, 29, descriptor::Attributes::ServerList, dont_care.clone()), attr_data!(1, 29, descriptor::Attributes::PartsList, dont_care.clone()), attr_data!(1, 29, descriptor::Attributes::ClientList, dont_care.clone()), @@ -318,74 +288,103 @@ fn wildcard_read_resp(part: u8) -> Vec> { fn test_long_read_success() { // Read the entire attribute database, which requires 2 reads to complete init_env_logger(); - let mut mdns = DummyMdns; - let matter = matter(&mut mdns); - let mut lr = LongRead::new(&matter); - let mut output = [0_u8; MAX_RX_BUF_SIZE + 100]; + + let mut out = heapless::Vec::<_, 3>::new(); + let im = ImEngine::new_default(); + let handler = im.handler(); + + im.add_default_acl(); let wc_path = GenericPath::new(None, None, None); let read_all = [AttrPath::new(&wc_path)]; let read_req = ReadReq::new(true).set_attr_requests(&read_all); let expected_part1 = wildcard_read_resp(1); - let (out_code, out_data) = lr.process(OpCode::ReadRequest, &read_req, &mut output); - let root = tlv::get_root_node_struct(out_data).unwrap(); - let report_data = ReportDataMsg::from_tlv(&root).unwrap(); - assert_attr_report_skip_data(&report_data, &expected_part1); - assert_eq!(report_data.more_chunks, Some(true)); - assert_eq!(out_code, OpCode::ReportData as u8); - // Ask for the next read by sending a status report let status_report = StatusResp { status: IMStatusCode::Success, }; let expected_part2 = wildcard_read_resp(2); - let (out_code, out_data) = lr.process(OpCode::StatusResponse, &status_report, &mut output); - let root = tlv::get_root_node_struct(out_data).unwrap(); + + im.process( + &handler, + &[ + &ImInput::new(OpCode::ReadRequest, &read_req), + &ImInput::new(OpCode::StatusResponse, &status_report), + ], + &mut out, + ) + .unwrap(); + + assert_eq!(out.len(), 2); + + assert_eq!(out[0].action, OpCode::ReportData); + + let root = tlv::get_root_node_struct(&out[0].data).unwrap(); + let report_data = ReportDataMsg::from_tlv(&root).unwrap(); + assert_attr_report_skip_data(&report_data, &expected_part1); + assert_eq!(report_data.more_chunks, Some(true)); + + assert_eq!(out[1].action, OpCode::ReportData); + + let root = tlv::get_root_node_struct(&out[1].data).unwrap(); let report_data = ReportDataMsg::from_tlv(&root).unwrap(); assert_attr_report_skip_data(&report_data, &expected_part2); assert_eq!(report_data.more_chunks, None); - assert_eq!(out_code, OpCode::ReportData as u8); } #[test] fn test_long_read_subscription_success() { // Subscribe to the entire attribute database, which requires 2 reads to complete init_env_logger(); - let mut mdns = DummyMdns; - let matter = matter(&mut mdns); - let mut lr = LongRead::new(&matter); - let mut output = [0_u8; MAX_RX_BUF_SIZE + 100]; + + let mut out = heapless::Vec::<_, 3>::new(); + let im = ImEngine::new_default(); + let handler = im.handler(); + + im.add_default_acl(); let wc_path = GenericPath::new(None, None, None); let read_all = [AttrPath::new(&wc_path)]; let subs_req = SubscribeReq::new(true, 1, 20).set_attr_requests(&read_all); let expected_part1 = wildcard_read_resp(1); - let (out_code, out_data) = lr.process(OpCode::SubscribeRequest, &subs_req, &mut output); - let root = tlv::get_root_node_struct(out_data).unwrap(); - let report_data = ReportDataMsg::from_tlv(&root).unwrap(); - assert_attr_report_skip_data(&report_data, &expected_part1); - assert_eq!(report_data.more_chunks, Some(true)); - assert_eq!(out_code, OpCode::ReportData as u8); - // Ask for the next read by sending a status report let status_report = StatusResp { status: IMStatusCode::Success, }; let expected_part2 = wildcard_read_resp(2); - let (out_code, out_data) = lr.process(OpCode::StatusResponse, &status_report, &mut output); - let root = tlv::get_root_node_struct(out_data).unwrap(); + + im.process( + &handler, + &[ + &ImInput::new(OpCode::SubscribeRequest, &subs_req), + &ImInput::new(OpCode::StatusResponse, &status_report), + &ImInput::new(OpCode::StatusResponse, &status_report), + ], + &mut out, + ) + .unwrap(); + + assert_eq!(out.len(), 3); + + assert_eq!(out[0].action, OpCode::ReportData); + + let root = tlv::get_root_node_struct(&out[0].data).unwrap(); + let report_data = ReportDataMsg::from_tlv(&root).unwrap(); + assert_attr_report_skip_data(&report_data, &expected_part1); + assert_eq!(report_data.more_chunks, Some(true)); + + assert_eq!(out[1].action, OpCode::ReportData); + + let root = tlv::get_root_node_struct(&out[1].data).unwrap(); let report_data = ReportDataMsg::from_tlv(&root).unwrap(); assert_attr_report_skip_data(&report_data, &expected_part2); assert_eq!(report_data.more_chunks, None); - assert_eq!(out_code, OpCode::ReportData as u8); - // Finally confirm subscription - let (out_code, out_data) = lr.process(OpCode::StatusResponse, &status_report, &mut output); - tlv::print_tlv_list(out_data); - let root = tlv::get_root_node_struct(out_data).unwrap(); + assert_eq!(out[2].action, OpCode::SubscribeResponse); + + let root = tlv::get_root_node_struct(&out[2].data).unwrap(); let subs_resp = SubscribeResp::from_tlv(&root).unwrap(); - assert_eq!(out_code, OpCode::SubscribeResponse as u8); assert_eq!(subs_resp.subs_id, 1); } diff --git a/matter/tests/data_model/timed_requests.rs b/matter/tests/data_model/timed_requests.rs index 3f44190..e4eb960 100644 --- a/matter/tests/data_model/timed_requests.rs +++ b/matter/tests/data_model/timed_requests.rs @@ -22,7 +22,6 @@ use matter::{ messages::ib::{AttrData, AttrPath, AttrStatus}, messages::{ib::CmdData, ib::CmdPath, GenericPath}, }, - mdns::DummyMdns, tlv::TLVWriter, }; @@ -31,7 +30,7 @@ use crate::{ commands::*, echo_cluster, handlers::{TimedInvResponse, WriteResponse}, - im_engine::{matter, ImEngine}, + im_engine::ImEngine, init_env_logger, }, echo_req, echo_resp, @@ -75,25 +74,20 @@ fn test_timed_write_fail_and_success() { ]; // Test with incorrect handling - ImEngine::new_with_timed_write_reqs( - &matter(&mut DummyMdns), - input, - &WriteResponse::TransactionError, - 400, - 500, - ); + ImEngine::timed_write_reqs(input, &WriteResponse::TransactionError, 100, 500); // Test with correct handling - let mut mdns = DummyMdns; - let matter = matter(&mut mdns); - let im = ImEngine::new_with_timed_write_reqs( - &matter, + let im = ImEngine::new_default(); + let handler = im.handler(); + im.add_default_acl(); + im.handle_timed_write_reqs( + &handler, input, &WriteResponse::TransactionSuccess(expected), 400, 0, ); - assert_eq!(val0, im.echo_cluster(0).att_write); + assert_eq!(val0, handler.echo_cluster(0).att_write.get()); } #[test] @@ -103,8 +97,7 @@ fn test_timed_cmd_success() { let input = &[echo_req!(0, 5), echo_req!(1, 10)]; let expected = &[echo_resp!(0, 10), echo_resp!(1, 30)]; - ImEngine::new_with_timed_commands( - &matter(&mut DummyMdns), + ImEngine::timed_commands( input, &TimedInvResponse::TransactionSuccess(expected), 400, @@ -119,11 +112,10 @@ fn test_timed_cmd_timeout() { init_env_logger(); let input = &[echo_req!(0, 5), echo_req!(1, 10)]; - ImEngine::new_with_timed_commands( - &matter(&mut DummyMdns), + ImEngine::timed_commands( input, &TimedInvResponse::TransactionError(IMStatusCode::Timeout), - 400, + 100, 500, true, ); @@ -135,8 +127,7 @@ fn test_timed_cmd_timedout_mismatch() { init_env_logger(); let input = &[echo_req!(0, 5), echo_req!(1, 10)]; - ImEngine::new_with_timed_commands( - &matter(&mut DummyMdns), + ImEngine::timed_commands( input, &TimedInvResponse::TransactionError(IMStatusCode::TimedRequestMisMatch), 400, @@ -145,8 +136,7 @@ fn test_timed_cmd_timedout_mismatch() { ); let input = &[echo_req!(0, 5), echo_req!(1, 10)]; - ImEngine::new_with_timed_commands( - &matter(&mut DummyMdns), + ImEngine::timed_commands( input, &TimedInvResponse::TransactionError(IMStatusCode::TimedRequestMisMatch), 0, diff --git a/matter/tests/interaction_model.rs b/matter/tests/interaction_model.rs deleted file mode 100644 index 9642ab2..0000000 --- a/matter/tests/interaction_model.rs +++ /dev/null @@ -1,152 +0,0 @@ -/* - * - * Copyright (c) 2020-2022 Project CHIP Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -use matter::data_model::core::DataHandler; -use matter::error::Error; -use matter::interaction_model::core::Interaction; -use matter::interaction_model::core::InteractionModel; -use matter::interaction_model::core::OpCode; -use matter::interaction_model::core::Transaction; -use matter::transport::exchange::Exchange; -use matter::transport::exchange::ExchangeCtx; -use matter::transport::network::Address; -use matter::transport::network::IpAddr; -use matter::transport::network::Ipv4Addr; -use matter::transport::network::SocketAddr; -use matter::transport::packet::Packet; -use matter::transport::packet::MAX_RX_BUF_SIZE; -use matter::transport::packet::MAX_TX_BUF_SIZE; -use matter::transport::proto_ctx::ProtoCtx; -use matter::transport::session::SessionMgr; -use matter::utils::epoch::dummy_epoch; -use matter::utils::rand::dummy_rand; - -struct Node { - pub endpoint: u16, - pub cluster: u32, - pub command: u16, - pub variable: u8, -} - -struct DataModel { - node: Node, -} - -impl DataModel { - pub fn new(node: Node) -> Self { - DataModel { node } - } -} - -impl DataHandler for DataModel { - fn handle( - &mut self, - interaction: Interaction, - _tx: &mut Packet, - _transaction: &mut Transaction, - ) -> Result { - if let Interaction::Invoke(req) = interaction { - if let Some(inv_requests) = &req.inv_requests { - for i in inv_requests.iter() { - let data = if let Some(data) = i.data.unwrap_tlv() { - data - } else { - continue; - }; - let cmd_path_ib = i.path; - let common_data = &mut self.node; - common_data.endpoint = cmd_path_ib.path.endpoint.unwrap_or(1); - common_data.cluster = cmd_path_ib.path.cluster.unwrap_or(0); - common_data.command = cmd_path_ib.path.leaf.unwrap_or(0) as u16; - data.confirm_struct().unwrap(); - common_data.variable = data.find_tag(0).unwrap().u8().unwrap(); - } - } - } - - Ok(false) - } -} - -fn handle_data(action: OpCode, data_in: &[u8], data_out: &mut [u8]) -> (DataModel, usize) { - let data_model = DataModel::new(Node { - endpoint: 0, - cluster: 0, - command: 0, - variable: 0, - }); - let mut interaction_model = InteractionModel(data_model); - let mut exch: Exchange = Default::default(); - let mut sess_mgr = SessionMgr::new(dummy_epoch, dummy_rand); - let sess_idx = sess_mgr - .get_or_add( - 0, - Address::Udp(SocketAddr::new( - IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), - 5542, - )), - None, - false, - ) - .unwrap(); - let sess = sess_mgr.get_session_handle(sess_idx); - let exch_ctx = ExchangeCtx { - exch: &mut exch, - sess, - epoch: dummy_epoch, - }; - let mut rx_buf = [0; MAX_RX_BUF_SIZE]; - let mut tx_buf = [0; MAX_TX_BUF_SIZE]; - let mut rx = Packet::new_rx(&mut rx_buf); - let mut tx = Packet::new_tx(&mut tx_buf); - // Create fake rx packet - rx.set_proto_id(0x01); - rx.set_proto_opcode(action as u8); - rx.peer = Address::default(); - let in_data_len = data_in.len(); - let rx_buf = rx.as_mut_slice(); - rx_buf[..in_data_len].copy_from_slice(data_in); - - let mut ctx = ProtoCtx::new(exch_ctx, &rx, &mut tx); - - interaction_model.handle(&mut ctx).unwrap(); - - let out_len = ctx.tx.as_mut_slice().len(); - data_out[..out_len].copy_from_slice(ctx.tx.as_mut_slice()); - (interaction_model.0, out_len) -} - -#[test] -fn test_valid_invoke_cmd() -> Result<(), Error> { - // An invoke command for endpoint 0, cluster 49, command 12 and a u8 variable value of 0x05 - - let b = [ - 0x15, 0x28, 0x00, 0x28, 0x01, 0x36, 0x02, 0x15, 0x37, 0x00, 0x25, 0x00, 0x00, 0x00, 0x26, - 0x01, 0x31, 0x00, 0x00, 0x00, 0x26, 0x02, 0x0c, 0x00, 0x00, 0x00, 0x18, 0x35, 0x01, 0x24, - 0x00, 0x05, 0x18, 0x18, 0x18, 0x18, - ]; - - let mut out_buf: [u8; 20] = [0; 20]; - - let (data_model, _) = handle_data(OpCode::InvokeRequest, &b, &mut out_buf); - let data = &data_model.node; - assert_eq!(data.endpoint, 0); - assert_eq!(data.cluster, 49); - assert_eq!(data.command, 12); - assert_eq!(data.variable, 5); - Ok(()) -} diff --git a/sdkconfig.defaults b/sdkconfig.defaults new file mode 100644 index 0000000..6ccea50 --- /dev/null +++ b/sdkconfig.defaults @@ -0,0 +1,6 @@ +# Workaround for https://github.com/espressif/esp-idf/issues/7631 +CONFIG_MBEDTLS_CERTIFICATE_BUNDLE=n +CONFIG_MBEDTLS_CERTIFICATE_BUNDLE_DEFAULT_FULL=n + +# Examples often require a larger than the default stack size for the main thread. +CONFIG_ESP_MAIN_TASK_STACK_SIZE=10000