diff --git a/matter/Cargo.toml b/matter/Cargo.toml index 2769010..fc38d45 100644 --- a/matter/Cargo.toml +++ b/matter/Cargo.toml @@ -15,7 +15,7 @@ name = "matter" path = "src/lib.rs" [features] -default = ["std", "crypto_mbedtls"] +default = ["std", "crypto_mbedtls", "nightly"] std = [] nightly = [] crypto_openssl = ["openssl", "foreign-types", "hmac", "sha2"] diff --git a/matter/src/acl.rs b/matter/src/acl.rs index 8f965e1..2bfd0d6 100644 --- a/matter/src/acl.rs +++ b/matter/src/acl.rs @@ -179,14 +179,14 @@ impl<'a> Accessor<'a> { let _ = subject.add_catid(i); } } - Accessor::new(c.fab_idx, subject, AuthMode::Case, &acl_mgr) + Accessor::new(c.fab_idx, subject, AuthMode::Case, acl_mgr) } SessionMode::Pase => { - Accessor::new(0, AccessorSubjects::new(1), AuthMode::Pase, &acl_mgr) + Accessor::new(0, AccessorSubjects::new(1), AuthMode::Pase, acl_mgr) } SessionMode::PlainText => { - Accessor::new(0, AccessorSubjects::new(1), AuthMode::Invalid, &acl_mgr) + Accessor::new(0, AccessorSubjects::new(1), AuthMode::Invalid, acl_mgr) } } } @@ -514,7 +514,7 @@ impl AclMgr { let mut wb = WriteBuf::new(&mut buf); let mut tw = TLVWriter::new(&mut wb); self.entries.to_tlv(&mut tw, TagType::Anonymous)?; - psm.set_kv_slice(ACL_KV_ENTRY, wb.into_slice())?; + psm.set_kv_slice(ACL_KV_ENTRY, wb.as_slice())?; self.changed = false; } @@ -546,7 +546,7 @@ impl AclMgr { let mut wb = WriteBuf::new(&mut buf); let mut tw = TLVWriter::new(&mut wb); self.entries.to_tlv(&mut tw, TagType::Anonymous)?; - psm.set_kv_slice(ACL_KV_ENTRY, wb.into_slice()).await?; + psm.set_kv_slice(ACL_KV_ENTRY, wb.as_slice()).await?; self.changed = false; } @@ -561,10 +561,7 @@ impl AclMgr { { let mut buf = [0u8; ACL_KV_MAX_SIZE]; let acl_tlvs = psm.get_kv_slice(ACL_KV_ENTRY, &mut buf).await?; - let root = TLVList::new(&acl_tlvs) - .iter() - .next() - .ok_or(Error::Invalid)?; + let root = TLVList::new(acl_tlvs).iter().next().ok_or(Error::Invalid)?; self.entries = AclEntries::from_tlv(&root)?; self.changed = false; diff --git a/matter/src/cert/mod.rs b/matter/src/cert/mod.rs index 664125b..b928329 100644 --- a/matter/src/cert/mod.rs +++ b/matter/src/cert/mod.rs @@ -597,7 +597,7 @@ impl Cert { let mut wb = WriteBuf::new(buf); let mut tw = TLVWriter::new(&mut wb); self.to_tlv(&mut tw, TagType::Anonymous)?; - Ok(wb.into_slice().len()) + Ok(wb.as_slice().len()) } pub fn as_asn1(&self, buf: &mut [u8]) -> Result { @@ -823,7 +823,7 @@ mod tests { let mut wb = WriteBuf::new(&mut buf); let mut tw = TLVWriter::new(&mut wb); cert.to_tlv(&mut tw, TagType::Anonymous).unwrap(); - assert_eq!(*input, wb.into_slice()); + assert_eq!(*input, wb.as_slice()); } } diff --git a/matter/src/data_model/cluster_basic_information.rs b/matter/src/data_model/cluster_basic_information.rs index 71c0722..dcc7028 100644 --- a/matter/src/data_model/cluster_basic_information.rs +++ b/matter/src/data_model/cluster_basic_information.rs @@ -19,11 +19,11 @@ use core::convert::TryInto; use super::objects::*; use crate::{attribute_enum, error::Error, utils::rand::Rand}; -use strum::{EnumDiscriminants, FromRepr}; +use strum::FromRepr; pub const ID: u32 = 0x0028; -#[derive(Clone, Copy, Debug, FromRepr, EnumDiscriminants)] +#[derive(Clone, Copy, Debug, FromRepr)] #[repr(u16)] pub enum Attributes { DMRevision(AttrType) = 0, @@ -37,6 +37,16 @@ pub enum Attributes { attribute_enum!(Attributes); +pub enum AttributesDiscriminants { + DMRevision = 0, + VendorId = 2, + ProductId = 4, + HwVer = 7, + SwVer = 9, + SwVerString = 0xa, + SerialNo = 0x0f, +} + #[derive(Default)] pub struct BasicInfoConfig<'a> { pub vid: u16, diff --git a/matter/src/data_model/core.rs b/matter/src/data_model/core.rs index 0ca8379..b53f955 100644 --- a/matter/src/data_model/core.rs +++ b/matter/src/data_model/core.rs @@ -15,17 +15,26 @@ * limitations under the License. */ -use core::cell::RefCell; +use core::{ + cell::RefCell, + sync::atomic::{AtomicU32, Ordering}, +}; use super::objects::*; use crate::{ acl::{Accessor, AclMgr}, error::*, - interaction_model::core::{Interaction, Transaction}, - tlv::TLVWriter, + interaction_model::{ + core::{Interaction, Transaction}, + messages::msg::SubscribeResp, + }, + tlv::{TLVWriter, TagType, ToTLV}, transport::packet::Packet, }; +// TODO: For now... +static SUBS_ID: AtomicU32 = AtomicU32::new(1); + pub struct DataModel<'a, T> { pub acl_mgr: &'a RefCell, pub node: &'a Node<'a>, @@ -43,7 +52,7 @@ impl<'a, T> DataModel<'a, T> { pub fn handle( &mut self, - interaction: &Interaction, + interaction: Interaction, tx: &mut Packet, transaction: &mut Transaction, ) -> Result @@ -55,44 +64,89 @@ impl<'a, T> DataModel<'a, T> { match interaction { Interaction::Read(req) => { - for item in self.node.read(req, &accessor) { - AttrDataEncoder::handle_read(item, &self.handler, &mut tw)?; + let mut resume_path = None; + + for item in self.node.read(&req, &accessor) { + if let Some(path) = AttrDataEncoder::handle_read(item, &self.handler, &mut tw)? + { + resume_path = Some(path); + break; + } } + + req.complete(tx, transaction, resume_path) } Interaction::Write(req) => { - for item in self.node.write(req, &accessor) { + for item in self.node.write(&req, &accessor) { AttrDataEncoder::handle_write(item, &mut self.handler, &mut tw)?; } + + req.complete(tx, transaction) } Interaction::Invoke(req) => { - for item in self.node.invoke(req, &accessor) { + for item in self.node.invoke(&req, &accessor) { CmdDataEncoder::handle(item, &mut self.handler, transaction, &mut tw)?; } + + req.complete(tx, transaction) } Interaction::Subscribe(req) => { - for item in self.node.subscribing_read(req, &accessor) { - AttrDataEncoder::handle_read(item, &self.handler, &mut tw)?; - } - } - Interaction::Status(_resp) => { - todo!() - // for item in self.node.subscribing_read(req, &accessor) { - // AttrDataEncoder::handle_read(item, &self.handler, &mut tw)?; - // } - } - Interaction::Timed(_) => (), - } + let mut resume_path = None; - interaction.complete_tx(tx, transaction) + for item in self.node.subscribing_read(&req, &accessor) { + if let Some(path) = AttrDataEncoder::handle_read(item, &self.handler, &mut tw)? + { + resume_path = Some(path); + break; + } + } + + req.complete(tx, transaction, resume_path) + } + Interaction::Timed(_) => Ok(false), + Interaction::ResumeRead(req) => { + let mut resume_path = None; + + for item in self.node.resume_read(&req, &accessor) { + if let Some(path) = AttrDataEncoder::handle_read(item, &self.handler, &mut tw)? + { + resume_path = Some(path); + break; + } + } + + req.complete(tx, transaction, resume_path) + } + Interaction::ResumeSubscribe(req) => { + let mut resume_path = None; + + if req.resume_path.is_some() { + for item in self.node.resume_subscribing_read(&req, &accessor) { + if let Some(path) = + AttrDataEncoder::handle_read(item, &self.handler, &mut tw)? + { + resume_path = Some(path); + break; + } + } + } else { + // TODO + let resp = SubscribeResp::new(SUBS_ID.fetch_add(1, Ordering::SeqCst), 40); + resp.to_tlv(&mut tw, TagType::Anonymous)?; + } + + req.complete(tx, transaction, resume_path) + } + } } #[cfg(feature = "nightly")] pub async fn handle_async<'p>( &mut self, - interaction: &Interaction<'_>, + interaction: Interaction<'_>, tx: &'p mut Packet<'_>, transaction: &mut Transaction<'_, '_>, - ) -> Result, Error> + ) -> Result where T: super::objects::asynch::AsyncHandler, { @@ -101,32 +155,91 @@ impl<'a, T> DataModel<'a, T> { match interaction { Interaction::Read(req) => { - for item in self.node.read(req, &accessor) { - AttrDataEncoder::handle_read_async(item, &self.handler, &mut tw).await?; + let mut resume_path = None; + + for item in self.node.read(&req, &accessor) { + if let Some(path) = + AttrDataEncoder::handle_read_async(item, &self.handler, &mut tw).await? + { + resume_path = Some(path); + break; + } } + + req.complete(tx, transaction, resume_path) } Interaction::Write(req) => { - for item in self.node.write(req, &accessor) { + for item in self.node.write(&req, &accessor) { AttrDataEncoder::handle_write_async(item, &mut self.handler, &mut tw).await?; } + + req.complete(tx, transaction) } Interaction::Invoke(req) => { - for item in self.node.invoke(req, &accessor) { + for item in self.node.invoke(&req, &accessor) { CmdDataEncoder::handle_async(item, &mut self.handler, transaction, &mut tw) .await?; } - } - Interaction::Timed(_) => (), - } - interaction.complete_tx(tx, transaction) + req.complete(tx, transaction) + } + Interaction::Subscribe(req) => { + let mut resume_path = None; + + for item in self.node.subscribing_read(&req, &accessor) { + if let Some(path) = + AttrDataEncoder::handle_read_async(item, &self.handler, &mut tw).await? + { + resume_path = Some(path); + break; + } + } + + req.complete(tx, transaction, resume_path) + } + Interaction::Timed(_) => Ok(false), + Interaction::ResumeRead(req) => { + let mut resume_path = None; + + for item in self.node.resume_read(&req, &accessor) { + if let Some(path) = + AttrDataEncoder::handle_read_async(item, &self.handler, &mut tw).await? + { + resume_path = Some(path); + break; + } + } + + req.complete(tx, transaction, resume_path) + } + Interaction::ResumeSubscribe(req) => { + let mut resume_path = None; + + if req.resume_path.is_some() { + for item in self.node.resume_subscribing_read(&req, &accessor) { + if let Some(path) = + AttrDataEncoder::handle_read_async(item, &self.handler, &mut tw).await? + { + resume_path = Some(path); + break; + } + } + } else { + // TODO + let resp = SubscribeResp::new(SUBS_ID.fetch_add(1, Ordering::SeqCst), 40); + resp.to_tlv(&mut tw, TagType::Anonymous)?; + } + + req.complete(tx, transaction, resume_path) + } + } } } pub trait DataHandler { fn handle( &mut self, - interaction: &Interaction, + interaction: Interaction, tx: &mut Packet, transaction: &mut Transaction, ) -> Result; @@ -138,7 +251,7 @@ where { fn handle( &mut self, - interaction: &Interaction, + interaction: Interaction, tx: &mut Packet, transaction: &mut Transaction, ) -> Result { @@ -152,7 +265,7 @@ where { fn handle( &mut self, - interaction: &Interaction, + interaction: Interaction, tx: &mut Packet, transaction: &mut Transaction, ) -> Result { @@ -172,24 +285,24 @@ pub mod asynch { use super::DataModel; pub trait AsyncDataHandler { - async fn handle<'p>( + async fn handle( &mut self, - interaction: &Interaction, - tx: &'p mut Packet, + interaction: Interaction<'_>, + tx: &mut Packet, transaction: &mut Transaction, - ) -> Result, Error>; + ) -> Result; } impl AsyncDataHandler for &mut T where T: AsyncDataHandler, { - async fn handle<'p>( + async fn handle( &mut self, - interaction: &Interaction<'_>, - tx: &'p mut Packet<'_>, + interaction: Interaction<'_>, + tx: &mut Packet<'_>, transaction: &mut Transaction<'_, '_>, - ) -> Result, Error> { + ) -> Result { (**self).handle(interaction, tx, transaction).await } } @@ -198,12 +311,12 @@ pub mod asynch { where T: AsyncHandler, { - async fn handle<'p>( + async fn handle( &mut self, - interaction: &Interaction<'_>, - tx: &'p mut Packet<'_>, + interaction: Interaction<'_>, + tx: &mut Packet<'_>, transaction: &mut Transaction<'_, '_>, - ) -> Result, Error> { + ) -> Result { DataModel::handle_async(self, interaction, tx, transaction).await } } diff --git a/matter/src/data_model/objects/cluster.rs b/matter/src/data_model/objects/cluster.rs index 90c6835..3818f93 100644 --- a/matter/src/data_model/objects/cluster.rs +++ b/matter/src/data_model/objects/cluster.rs @@ -64,6 +64,7 @@ pub const ATTRIBUTE_LIST: Attribute = Attribute::new( // TODO: What if we instead of creating this, we just pass the AttrData/AttrPath to the read/write // methods? /// The Attribute Details structure records the details about the attribute under consideration. +#[derive(Debug)] pub struct AttrDetails<'a> { pub node: &'a Node<'a>, /// The actual endpoint ID @@ -129,6 +130,7 @@ impl<'a> AttrDetails<'a> { } } +#[derive(Debug)] pub struct CmdDetails<'a> { pub node: &'a Node<'a>, pub endpoint_id: EndptId, @@ -208,49 +210,23 @@ impl<'a> Cluster<'a> { } } - pub(crate) fn match_attributes<'m>( - &'m self, - accessor: &'m Accessor<'m>, - ep: EndptId, + pub fn match_attributes( + &self, attr: Option, - write: bool, - ) -> impl Iterator + 'm { + ) -> impl Iterator + '_ { self.attributes .iter() .filter(move |attribute| attr.map(|attr| attr == attribute.id).unwrap_or(true)) - .filter(move |attribute| { - let mut access_req = AccessReq::new( - accessor, - GenericPath::new(Some(ep), Some(self.id), Some(attribute.id as _)), - if write { Access::WRITE } else { Access::READ }, - ); - self.check_attr_access(&mut access_req, attribute.access) - .is_ok() - }) - .map(|attribute| attribute.id) } - pub fn match_commands<'m>( - &'m self, - accessor: &'m Accessor<'m>, - ep: EndptId, - cmd: Option, - ) -> impl Iterator + 'm { + pub fn match_commands(&self, cmd: Option) -> impl Iterator + '_ { self.commands .iter() .filter(move |id| cmd.map(|cmd| **id == cmd).unwrap_or(true)) - .filter(move |id| { - let mut access_req = AccessReq::new( - accessor, - GenericPath::new(Some(ep), Some(self.id), Some(**id as _)), - Access::WRITE, - ); - self.check_cmd_access(&mut access_req).is_ok() - }) .copied() } - pub(crate) fn check_attribute( + pub fn check_attribute( &self, accessor: &Accessor, ep: EndptId, @@ -263,16 +239,15 @@ impl<'a> Cluster<'a> { .find(|attribute| attribute.id == attr) .ok_or(IMStatusCode::UnsupportedAttribute)?; - let mut access_req = AccessReq::new( + Self::check_attr_access( accessor, GenericPath::new(Some(ep), Some(self.id), Some(attr as _)), - if write { Access::WRITE } else { Access::READ }, - ); - - self.check_attr_access(&mut access_req, attribute.access) + write, + attribute.access, + ) } - pub(crate) fn check_command( + pub fn check_command( &self, accessor: &Accessor, ep: EndptId, @@ -283,20 +258,24 @@ impl<'a> Cluster<'a> { .find(|id| **id == cmd) .ok_or(IMStatusCode::UnsupportedCommand)?; - let mut access_req = AccessReq::new( + Self::check_cmd_access( accessor, - GenericPath::new(Some(ep), Some(self.id), Some(cmd as _)), - Access::WRITE, - ); - - self.check_cmd_access(&mut access_req) + GenericPath::new(Some(ep), Some(self.id), Some(cmd)), + ) } - fn check_attr_access( - &self, - access_req: &mut AccessReq, + pub(crate) fn check_attr_access( + accessor: &Accessor, + path: GenericPath, + write: bool, target_perms: Access, ) -> Result<(), IMStatusCode> { + let mut access_req = AccessReq::new( + accessor, + path, + if write { Access::WRITE } else { Access::READ }, + ); + if !target_perms.contains(access_req.operation()) { Err(if matches!(access_req.operation(), Access::WRITE) { IMStatusCode::UnsupportedWrite @@ -313,7 +292,12 @@ impl<'a> Cluster<'a> { } } - fn check_cmd_access(&self, access_req: &mut AccessReq) -> Result<(), IMStatusCode> { + pub(crate) fn check_cmd_access( + accessor: &Accessor, + path: GenericPath, + ) -> Result<(), IMStatusCode> { + let mut access_req = AccessReq::new(accessor, path, Access::WRITE); + access_req.set_target_perms( Access::WRITE .union(Access::NEED_OPERATE) diff --git a/matter/src/data_model/objects/encoder.rs b/matter/src/data_model/objects/encoder.rs index 39d2ba6..b4066e6 100644 --- a/matter/src/data_model/objects/encoder.rs +++ b/matter/src/data_model/objects/encoder.rs @@ -23,6 +23,7 @@ use crate::interaction_model::core::{IMStatusCode, Transaction}; use crate::interaction_model::messages::ib::{ AttrPath, AttrResp, AttrStatus, CmdDataTag, CmdPath, CmdStatus, InvResp, InvRespTag, }; +use crate::interaction_model::messages::GenericPath; use crate::tlv::UtfStr; use crate::{ error::Error, @@ -127,13 +128,14 @@ impl<'a, 'b, 'c> AttrDataEncoder<'a, 'b, 'c> { item: Result, handler: &T, tw: &mut TLVWriter, - ) -> Result<(), Error> { + ) -> Result, Error> { let status = match item { Ok(attr) => { let encoder = AttrDataEncoder::new(&attr, tw); match handler.read(&attr, encoder) { Ok(()) => None, + Err(Error::NoSpace) => return Ok(Some(attr.path().to_gp())), Err(error) => attr.status(error.into())?, } } @@ -144,7 +146,7 @@ impl<'a, 'b, 'c> AttrDataEncoder<'a, 'b, 'c> { AttrResp::Status(status).to_tlv(tw, TagType::Anonymous)?; } - Ok(()) + Ok(None) } pub fn handle_write( @@ -172,13 +174,14 @@ impl<'a, 'b, 'c> AttrDataEncoder<'a, 'b, 'c> { item: Result, AttrStatus>, handler: &T, tw: &mut TLVWriter<'_, '_>, - ) -> Result<(), Error> { + ) -> Result, Error> { let status = match item { Ok(attr) => { let encoder = AttrDataEncoder::new(&attr, tw); match handler.read(&attr, encoder).await { Ok(()) => None, + Err(Error::NoSpace) => return Ok(Some(attr.path().to_gp())), Err(error) => attr.status(error.into())?, } } @@ -189,7 +192,7 @@ impl<'a, 'b, 'c> AttrDataEncoder<'a, 'b, 'c> { AttrResp::Status(status).to_tlv(tw, TagType::Anonymous)?; } - Ok(()) + Ok(None) } #[cfg(feature = "nightly")] diff --git a/matter/src/data_model/objects/endpoint.rs b/matter/src/data_model/objects/endpoint.rs index d0a4fdd..05878ed 100644 --- a/matter/src/data_model/objects/endpoint.rs +++ b/matter/src/data_model/objects/endpoint.rs @@ -19,7 +19,7 @@ use crate::{acl::Accessor, interaction_model::core::IMStatusCode}; use core::fmt; -use super::{AttrId, Cluster, ClusterId, CmdId, DeviceType, EndptId}; +use super::{AttrId, Attribute, Cluster, ClusterId, CmdId, DeviceType, EndptId}; #[derive(Debug, Clone)] pub struct Endpoint<'a> { @@ -29,34 +29,28 @@ pub struct Endpoint<'a> { } impl<'a> Endpoint<'a> { - pub(crate) fn match_attributes<'m>( - &'m self, - accessor: &'m Accessor<'m>, + pub fn match_attributes( + &self, cl: Option, attr: Option, - write: bool, - ) -> impl Iterator + 'm { + ) -> impl Iterator + '_ { self.match_clusters(cl).flat_map(move |cluster| { cluster - .match_attributes(accessor, self.id, attr, write) - .map(move |attr| (cluster.id, attr)) + .match_attributes(attr) + .map(move |attr| (cluster, attr)) }) } - pub(crate) fn match_commands<'m>( - &'m self, - accessor: &'m Accessor<'m>, + pub fn match_commands( + &self, cl: Option, cmd: Option, - ) -> impl Iterator + 'm { - self.match_clusters(cl).flat_map(move |cluster| { - cluster - .match_commands(accessor, self.id, cmd) - .map(move |cmd| (cluster.id, cmd)) - }) + ) -> impl Iterator + '_ { + self.match_clusters(cl) + .flat_map(move |cluster| cluster.match_commands(cmd).map(move |cmd| (cluster, cmd))) } - pub(crate) fn check_attribute( + pub fn check_attribute( &self, accessor: &Accessor, cl: ClusterId, @@ -67,7 +61,7 @@ impl<'a> Endpoint<'a> { .and_then(|cluster| cluster.check_attribute(accessor, self.id, attr, write)) } - pub(crate) fn check_command( + pub fn check_command( &self, accessor: &Accessor, cl: ClusterId, @@ -77,13 +71,13 @@ impl<'a> Endpoint<'a> { .and_then(|cluster| cluster.check_command(accessor, self.id, cmd)) } - fn match_clusters(&self, cl: Option) -> impl Iterator + '_ { + pub fn match_clusters(&self, cl: Option) -> impl Iterator + '_ { self.clusters .iter() .filter(move |cluster| cl.map(|id| id == cluster.id).unwrap_or(true)) } - fn check_cluster(&self, cl: ClusterId) -> Result<&Cluster, IMStatusCode> { + pub fn check_cluster(&self, cl: ClusterId) -> Result<&Cluster, IMStatusCode> { self.clusters .iter() .find(|cluster| cluster.id == cl) diff --git a/matter/src/data_model/objects/handler.rs b/matter/src/data_model/objects/handler.rs index 052d690..7758427 100644 --- a/matter/src/data_model/objects/handler.rs +++ b/matter/src/data_model/objects/handler.rs @@ -186,9 +186,16 @@ macro_rules! handler_chain_type { ($h:ty) => { $crate::data_model::objects::ChainedHandler<$h, $crate::data_model::objects::EmptyHandler> }; - ($h1:ty, $($rest:ty),+) => { + ($h1:ty $(, $rest:ty)+) => { $crate::data_model::objects::ChainedHandler<$h1, handler_chain_type!($($rest),+)> }; + + ($h:ty | $f:ty) => { + $crate::data_model::objects::ChainedHandler<$h, $f> + }; + ($h1:ty $(, $rest:ty)+ | $f:ty) => { + $crate::data_model::objects::ChainedHandler<$h1, handler_chain_type!($($rest),+ | $f)> + }; } #[cfg(feature = "nightly")] diff --git a/matter/src/data_model/objects/node.rs b/matter/src/data_model/objects/node.rs index 4ec1765..3ee3af2 100644 --- a/matter/src/data_model/objects/node.rs +++ b/matter/src/data_model/objects/node.rs @@ -19,7 +19,7 @@ use crate::{ acl::Accessor, data_model::objects::Endpoint, interaction_model::{ - core::IMStatusCode, + core::{IMStatusCode, ResumeReadReq, ResumeSubscribeReq}, messages::{ ib::{AttrPath, AttrStatus, CmdStatus, DataVersionFilter}, msg::{InvReq, ReadReq, SubscribeReq, WriteReq}, @@ -27,16 +27,16 @@ use crate::{ }, }, // TODO: This layer shouldn't really depend on the TLV layer, should create an abstraction layer - tlv::{TLVArray, TLVElement}, + tlv::{TLVArray, TLVArrayIter, TLVElement}, }; use core::{ fmt, iter::{once, Once}, }; -use super::{AttrDetails, AttrId, ClusterId, CmdDetails, CmdId, EndptId}; +use super::{AttrDetails, AttrId, Attribute, Cluster, ClusterId, CmdDetails, CmdId, EndptId}; -enum WildcardIter { +pub enum WildcardIter { None, Single(Once), Wildcard(T), @@ -57,6 +57,41 @@ where } } +pub trait Iterable { + type Item; + + type Iterator<'a>: Iterator + where + Self: 'a; + + fn iter(&self) -> Self::Iterator<'_>; +} + +impl<'a> Iterable for Option<&'a TLVArray<'a, DataVersionFilter>> { + type Item = DataVersionFilter; + + type Iterator<'i> = WildcardIter, DataVersionFilter> where Self: 'i; + + fn iter(&self) -> Self::Iterator<'_> { + if let Some(filters) = self { + WildcardIter::Wildcard(filters.iter()) + } else { + WildcardIter::None + } + } +} + +impl<'a> Iterable for &'a [DataVersionFilter] { + type Item = DataVersionFilter; + + type Iterator<'i> = core::iter::Copied> where Self: 'i; + + fn iter(&self) -> Self::Iterator<'_> { + let slice: &[DataVersionFilter] = self; + slice.iter().copied() + } +} + #[derive(Debug, Clone)] pub struct Node<'a> { pub id: u16, @@ -73,10 +108,30 @@ impl<'a> Node<'a> { 's: 'm, { self.read_attr_requests( - req.attr_requests.as_ref(), + req.attr_requests + .iter() + .flat_map(|attr_requests| attr_requests.iter()), req.dataver_filters.as_ref(), req.fabric_filtered, accessor, + None, + ) + } + + pub fn resume_read<'s, 'm>( + &'s self, + req: &'m ResumeReadReq, + accessor: &'m Accessor<'m>, + ) -> impl Iterator> + 'm + where + 's: 'm, + { + self.read_attr_requests( + req.paths.iter().copied(), + req.filters.as_slice(), + req.fabric_filtered, + accessor, + Some(req.resume_path), ) } @@ -89,60 +144,115 @@ impl<'a> Node<'a> { 's: 'm, { self.read_attr_requests( - req.attr_requests.as_ref(), + req.attr_requests + .iter() + .flat_map(|attr_requests| attr_requests.iter()), req.dataver_filters.as_ref(), req.fabric_filtered, accessor, + None, ) } - fn read_attr_requests<'s, 'm>( + pub fn resume_subscribing_read<'s, 'm>( &'s self, - attr_requests: Option<&'m TLVArray>, - dataver_filters: Option<&'m TLVArray>, - fabric_filtered: bool, + req: &'m ResumeSubscribeReq, accessor: &'m Accessor<'m>, ) -> impl Iterator> + 'm where 's: 'm, { - if let Some(attr_requests) = attr_requests.as_ref() { - WildcardIter::Wildcard(attr_requests.iter().flat_map( - move |path| match self.expand_attr(accessor, path.to_gp(), false) { - Ok(iter) => { - let wildcard = matches!(iter, WildcardIter::Wildcard(_)); + self.read_attr_requests( + req.paths.iter().copied(), + req.filters.as_slice(), + req.fabric_filtered, + accessor, + Some(req.resume_path.unwrap()), + ) + } - WildcardIter::Wildcard(iter.map(move |(ep, cl, attr)| { - let dataver_filter = dataver_filters - .as_ref() - .iter() - .flat_map(|array| array.iter()) - .find_map(|filter| { - (filter.path.endpoint == ep && filter.path.cluster == cl) - .then_some(filter.data_ver) - }); + fn read_attr_requests<'s, 'm, P, D>( + &'s self, + attr_requests: P, + dataver_filters: D, + fabric_filtered: bool, + accessor: &'m Accessor<'m>, + from: Option, + ) -> impl Iterator> + 'm + where + 's: 'm, + P: Iterator + 'm, + D: Iterable + Clone + 'm, + { + attr_requests.flat_map(move |path| { + if path.to_gp().is_wildcard() { + let dataver_filters = dataver_filters.clone(); + let from = from; - Ok(AttrDetails { - node: self, - endpoint_id: ep, - cluster_id: cl, - attr_id: attr, - list_index: path.list_index, - fab_idx: accessor.fab_idx, - fab_filter: fabric_filtered, - dataver: dataver_filter, - wildcard, - }) - })) + let iter = self + .match_attributes(path.endpoint, path.cluster, path.attr) + .skip_while(move |(ep, cl, attr)| { + !Self::matches(from.as_ref(), ep.id, cl.id, attr.id as _) + }) + .filter(move |(ep, cl, attr)| { + Cluster::check_attr_access( + accessor, + GenericPath::new(Some(ep.id), Some(cl.id), Some(attr.id as _)), + false, + attr.access, + ) + .is_ok() + }) + .map(move |(ep, cl, attr)| { + let dataver = dataver_filters.iter().find_map(|filter| { + (filter.path.endpoint == ep.id && filter.path.cluster == cl.id) + .then_some(filter.data_ver) + }); + + Ok(AttrDetails { + node: self, + endpoint_id: ep.id, + cluster_id: cl.id, + attr_id: attr.id, + list_index: path.list_index, + fab_idx: accessor.fab_idx, + fab_filter: fabric_filtered, + dataver, + wildcard: true, + }) + }); + + WildcardIter::Wildcard(iter) + } else { + let ep = path.endpoint.unwrap(); + let cl = path.cluster.unwrap(); + let attr = path.attr.unwrap(); + + let result = match self.check_attribute(accessor, ep, cl, attr, false) { + Ok(()) => { + let dataver = dataver_filters.iter().find_map(|filter| { + (filter.path.endpoint == ep && filter.path.cluster == cl) + .then_some(filter.data_ver) + }); + + Ok(AttrDetails { + node: self, + endpoint_id: ep, + cluster_id: cl, + attr_id: attr, + list_index: path.list_index, + fab_idx: accessor.fab_idx, + fab_filter: fabric_filtered, + dataver, + wildcard: false, + }) } - Err(err) => { - WildcardIter::Single(once(Err(AttrStatus::new(&path.to_gp(), err, 0)))) - } - }, - )) - } else { - WildcardIter::None - } + Err(err) => Err(AttrStatus::new(&path.to_gp(), err, 0)), + }; + + WildcardIter::Single(once(result)) + } + }) } pub fn write<'m>( @@ -163,34 +273,64 @@ impl<'a> Node<'a> { IMStatusCode::UnsupportedAttribute, 0, )))) - } else { - match self.expand_attr(accessor, attr_data.path.to_gp(), true) { - Ok(iter) => { - let wildcard = matches!(iter, WildcardIter::Wildcard(_)); + } else if attr_data.path.to_gp().is_wildcard() { + let iter = self + .match_attributes( + attr_data.path.endpoint, + attr_data.path.cluster, + attr_data.path.attr, + ) + .filter(move |(ep, cl, attr)| { + Cluster::check_attr_access( + accessor, + GenericPath::new(Some(ep.id), Some(cl.id), Some(attr.id as _)), + true, + attr.access, + ) + .is_ok() + }) + .map(move |(ep, cl, attr)| { + Ok(( + AttrDetails { + node: self, + endpoint_id: ep.id, + cluster_id: cl.id, + attr_id: attr.id, + list_index: attr_data.path.list_index, + fab_idx: accessor.fab_idx, + fab_filter: false, + dataver: attr_data.data_ver, + wildcard: true, + }, + attr_data.data.unwrap_tlv().unwrap(), + )) + }); - WildcardIter::Wildcard(iter.map(move |(ep, cl, attr)| { - Ok(( - AttrDetails { - node: self, - endpoint_id: ep, - cluster_id: cl, - attr_id: attr, - list_index: attr_data.path.list_index, - fab_idx: accessor.fab_idx, - fab_filter: false, - dataver: attr_data.data_ver, - wildcard, - }, - attr_data.data.unwrap_tlv().unwrap(), - )) - })) - } - Err(err) => WildcardIter::Single(once(Err(AttrStatus::new( - &attr_data.path.to_gp(), - err, - 0, - )))), - } + WildcardIter::Wildcard(iter) + } else { + let ep = attr_data.path.endpoint.unwrap(); + let cl = attr_data.path.cluster.unwrap(); + let attr = attr_data.path.attr.unwrap(); + + let result = match self.check_attribute(accessor, ep, cl, attr, true) { + Ok(()) => Ok(( + AttrDetails { + node: self, + endpoint_id: ep, + cluster_id: cl, + attr_id: attr, + list_index: attr_data.path.list_index, + fab_idx: accessor.fab_idx, + fab_filter: false, + dataver: attr_data.data_ver, + wildcard: false, + }, + attr_data.data.unwrap_tlv().unwrap(), + )), + Err(err) => Err(AttrStatus::new(&attr_data.path.to_gp(), err, 0)), + }; + + WildcardIter::Single(once(result)) } }) } @@ -200,136 +340,99 @@ impl<'a> Node<'a> { req: &'m InvReq, accessor: &'m Accessor<'m>, ) -> impl Iterator), CmdStatus>> + 'm { - if let Some(inv_requests) = req.inv_requests.as_ref() { - WildcardIter::Wildcard(inv_requests.iter().flat_map(move |cmd_data| { - match self.expand_cmd(accessor, cmd_data.path.path) { - Ok(iter) => { - let wildcard = matches!(iter, WildcardIter::Wildcard(_)); - - WildcardIter::Wildcard(iter.map(move |(ep, cl, cmd)| { + req.inv_requests + .iter() + .flat_map(|inv_requests| inv_requests.iter()) + .flat_map(move |cmd_data| { + if cmd_data.path.path.is_wildcard() { + let iter = self + .match_commands( + cmd_data.path.path.endpoint, + cmd_data.path.path.cluster, + cmd_data.path.path.leaf.map(|leaf| leaf as _), + ) + .filter(move |(ep, cl, cmd)| { + Cluster::check_cmd_access( + accessor, + GenericPath::new(Some(ep.id), Some(cl.id), Some(*cmd)), + ) + .is_ok() + }) + .map(move |(ep, cl, cmd)| { Ok(( CmdDetails { node: self, - endpoint_id: ep, - cluster_id: cl, + endpoint_id: ep.id, + cluster_id: cl.id, cmd_id: cmd, - wildcard, + wildcard: true, }, cmd_data.data.unwrap_tlv().unwrap(), )) - })) - } - Err(err) => { - WildcardIter::Single(once(Err(CmdStatus::new(cmd_data.path, err, 0)))) - } + }); + + WildcardIter::Wildcard(iter) + } else { + let ep = cmd_data.path.path.endpoint.unwrap(); + let cl = cmd_data.path.path.cluster.unwrap(); + let cmd = cmd_data.path.path.leaf.unwrap(); + + let result = match self.check_command(accessor, ep, cl, cmd) { + Ok(()) => Ok(( + CmdDetails { + node: self, + endpoint_id: cmd_data.path.path.endpoint.unwrap(), + cluster_id: cmd_data.path.path.cluster.unwrap(), + cmd_id: cmd_data.path.path.leaf.unwrap(), + wildcard: false, + }, + cmd_data.data.unwrap_tlv().unwrap(), + )), + Err(err) => Err(CmdStatus::new(cmd_data.path, err, 0)), + }; + + WildcardIter::Single(once(result)) } - })) + }) + } + + fn matches(path: Option<&GenericPath>, ep: EndptId, cl: ClusterId, leaf: u32) -> bool { + if let Some(path) = path { + path.endpoint.map(|id| id == ep).unwrap_or(true) + && path.cluster.map(|id| id == cl).unwrap_or(true) + && path.leaf.map(|id| id == leaf).unwrap_or(true) } else { - WildcardIter::None + true } } - fn expand_attr<'m>( - &'m self, - accessor: &'m Accessor<'m>, - path: GenericPath, - write: bool, - ) -> Result< - WildcardIter< - impl Iterator + 'm, - (EndptId, ClusterId, AttrId), - >, - IMStatusCode, - > { - if path.is_wildcard() { - Ok(WildcardIter::Wildcard(self.match_attributes( - accessor, - path.endpoint, - path.cluster, - path.leaf.map(|leaf| leaf as u16), - write, - ))) - } else { - self.check_attribute( - accessor, - path.endpoint.unwrap(), - path.cluster.unwrap(), - path.leaf.unwrap() as _, - write, - )?; - - Ok(WildcardIter::Single(once(( - path.endpoint.unwrap(), - path.cluster.unwrap(), - path.leaf.unwrap() as _, - )))) - } - } - - fn expand_cmd<'m>( - &'m self, - accessor: &'m Accessor<'m>, - path: GenericPath, - ) -> Result< - WildcardIter< - impl Iterator + 'm, - (EndptId, ClusterId, CmdId), - >, - IMStatusCode, - > { - if path.is_wildcard() { - Ok(WildcardIter::Wildcard(self.match_commands( - accessor, - path.endpoint, - path.cluster, - path.leaf, - ))) - } else { - self.check_command( - accessor, - path.endpoint.unwrap(), - path.cluster.unwrap(), - path.leaf.unwrap(), - )?; - - Ok(WildcardIter::Single(once(( - path.endpoint.unwrap(), - path.cluster.unwrap(), - path.leaf.unwrap(), - )))) - } - } - - fn match_attributes<'m>( - &'m self, - accessor: &'m Accessor<'m>, + pub fn match_attributes( + &self, ep: Option, cl: Option, attr: Option, - write: bool, - ) -> impl Iterator + 'm { + ) -> impl Iterator + '_ { self.match_endpoints(ep).flat_map(move |endpoint| { endpoint - .match_attributes(accessor, cl, attr, write) - .map(move |(cl, attr)| (endpoint.id, cl, attr)) + .match_attributes(cl, attr) + .map(move |(cl, attr)| (endpoint, cl, attr)) }) } - fn match_commands<'m>( - &'m self, - accessor: &'m Accessor<'m>, + pub fn match_commands( + &self, ep: Option, cl: Option, cmd: Option, - ) -> impl Iterator + 'm { + ) -> impl Iterator + '_ { self.match_endpoints(ep).flat_map(move |endpoint| { endpoint - .match_commands(accessor, cl, cmd) - .map(move |(cl, cmd)| (endpoint.id, cl, cmd)) + .match_commands(cl, cmd) + .map(move |(cl, cmd)| (endpoint, cl, cmd)) }) } - fn check_attribute( + pub fn check_attribute( &self, accessor: &Accessor, ep: EndptId, @@ -341,7 +444,7 @@ impl<'a> Node<'a> { .and_then(|endpoint| endpoint.check_attribute(accessor, cl, attr, write)) } - fn check_command( + pub fn check_command( &self, accessor: &Accessor, ep: EndptId, @@ -352,13 +455,13 @@ impl<'a> Node<'a> { .and_then(|endpoint| endpoint.check_command(accessor, cl, cmd)) } - fn match_endpoints(&self, ep: Option) -> impl Iterator + '_ { + pub fn match_endpoints(&self, ep: Option) -> impl Iterator + '_ { self.endpoints .iter() .filter(move |endpoint| ep.map(|id| id == endpoint.id).unwrap_or(true)) } - fn check_endpoint(&self, ep: EndptId) -> Result<&Endpoint, IMStatusCode> { + pub fn check_endpoint(&self, ep: EndptId) -> Result<&Endpoint, IMStatusCode> { self.endpoints .iter() .find(|endpoint| endpoint.id == ep) diff --git a/matter/src/data_model/root_endpoint.rs b/matter/src/data_model/root_endpoint.rs index 44131b9..ebcbb14 100644 --- a/matter/src/data_model/root_endpoint.rs +++ b/matter/src/data_model/root_endpoint.rs @@ -21,19 +21,24 @@ use super::{ noc::{self, NocCluster}, nw_commissioning::{self, NwCommCluster}, }, - system_model::access_control::{self, AccessControlCluster}, + system_model::{ + access_control::{self, AccessControlCluster}, + descriptor::{self, DescriptorCluster}, + }, }; pub type RootEndpointHandler<'a> = handler_chain_type!( - AccessControlCluster<'a>, - NocCluster<'a>, - AdminCommCluster<'a>, - NwCommCluster, + DescriptorCluster, + BasicInfoCluster<'a>, GenCommCluster, - BasicInfoCluster<'a> + NwCommCluster, + AdminCommCluster<'a>, + NocCluster<'a>, + AccessControlCluster<'a> ); -pub const CLUSTERS: [Cluster<'static>; 6] = [ +pub const CLUSTERS: [Cluster<'static>; 7] = [ + descriptor::CLUSTER, cluster_basic_information::CLUSTER, general_commissioning::CLUSTER, nw_commissioning::CLUSTER, @@ -77,32 +82,29 @@ pub fn wrap<'a>( EmptyHandler .chain( endpoint_id, - cluster_basic_information::CLUSTER.id, - BasicInfoCluster::new(basic_info, rand), + access_control::ID, + AccessControlCluster::new(acl, rand), ) .chain( endpoint_id, - general_commissioning::CLUSTER.id, - GenCommCluster::new(rand), - ) - .chain( - endpoint_id, - nw_commissioning::CLUSTER.id, - NwCommCluster::new(rand), - ) - .chain( - endpoint_id, - admin_commissioning::CLUSTER.id, - AdminCommCluster::new(pase, mdns_mgr, rand), - ) - .chain( - endpoint_id, - noc::CLUSTER.id, + noc::ID, NocCluster::new(dev_att, fabric, acl, failsafe, mdns_mgr, epoch, rand), ) .chain( endpoint_id, - access_control::CLUSTER.id, - AccessControlCluster::new(acl, rand), + admin_commissioning::ID, + AdminCommCluster::new(pase, mdns_mgr, rand), ) + .chain(endpoint_id, nw_commissioning::ID, NwCommCluster::new(rand)) + .chain( + endpoint_id, + general_commissioning::ID, + GenCommCluster::new(rand), + ) + .chain( + endpoint_id, + cluster_basic_information::ID, + BasicInfoCluster::new(basic_info, rand), + ) + .chain(endpoint_id, descriptor::ID, DescriptorCluster::new(rand)) } diff --git a/matter/src/data_model/sdm/general_commissioning.rs b/matter/src/data_model/sdm/general_commissioning.rs index aea37c7..0c007d1 100644 --- a/matter/src/data_model/sdm/general_commissioning.rs +++ b/matter/src/data_model/sdm/general_commissioning.rs @@ -215,7 +215,7 @@ impl GenCommCluster { encoder .with_command(RespCommands::ArmFailsafeResp as _)? - .set(&cmd_data) + .set(cmd_data) } fn handle_command_setregulatoryconfig( @@ -238,7 +238,7 @@ impl GenCommCluster { encoder .with_command(RespCommands::SetRegulatoryConfigResp as _)? - .set(&cmd_data) + .set(cmd_data) } fn handle_command_commissioningcomplete( @@ -272,7 +272,7 @@ impl GenCommCluster { encoder .with_command(RespCommands::CommissioningCompleteResp as _)? - .set(&cmd_data) + .set(cmd_data) } } diff --git a/matter/src/data_model/sdm/noc.rs b/matter/src/data_model/sdm/noc.rs index 0258f3a..acaea50 100644 --- a/matter/src/data_model/sdm/noc.rs +++ b/matter/src/data_model/sdm/noc.rs @@ -398,7 +398,7 @@ impl<'a> NocCluster<'a> { encoder .with_command(RespCommands::NOCResp as _)? - .set(&cmd_data) + .set(cmd_data) } fn handle_command_updatefablabel( @@ -527,7 +527,7 @@ impl<'a> NocCluster<'a> { encoder .with_command(RespCommands::CertChainResp as _)? - .set(&cmd_data) + .set(cmd_data) } fn handle_command_csrrequest( diff --git a/matter/src/data_model/sdm/nw_commissioning.rs b/matter/src/data_model/sdm/nw_commissioning.rs index 7afff7a..47ffe6e 100644 --- a/matter/src/data_model/sdm/nw_commissioning.rs +++ b/matter/src/data_model/sdm/nw_commissioning.rs @@ -52,8 +52,16 @@ impl NwCommCluster { } impl Handler for NwCommCluster { - fn read(&self, _attr: &AttrDetails, _encoder: AttrDataEncoder) -> Result<(), Error> { - Err(Error::AttributeNotFound) + fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> { + if let Some(writer) = encoder.with_dataver(self.data_ver.get())? { + if attr.is_system() { + CLUSTER.read(attr.attr_id, writer) + } else { + Err(Error::AttributeNotFound) + } + } else { + Ok(()) + } } } diff --git a/matter/src/data_model/system_model/access_control.rs b/matter/src/data_model/system_model/access_control.rs index 3980a43..ffba5e6 100644 --- a/matter/src/data_model/system_model/access_control.rs +++ b/matter/src/data_model/system_model/access_control.rs @@ -20,7 +20,7 @@ use core::convert::TryInto; use strum::{EnumDiscriminants, FromRepr}; -use crate::acl::{AclEntry, AclMgr}; +use crate::acl::{self, AclEntry, AclMgr}; use crate::data_model::objects::*; use crate::interaction_model::messages::ib::{attr_list_write, ListOperation}; use crate::tlv::{FromTLV, TLVElement, TagType, ToTLV}; @@ -116,9 +116,14 @@ impl<'a> AccessControlCluster<'a> { writer.complete() } - _ => { - error!("Attribute not yet supported: this shouldn't happen"); - Err(Error::AttributeNotFound) + Attributes::SubjectsPerEntry(codec) => { + codec.encode(writer, acl::SUBJECTS_PER_ENTRY as u16) + } + Attributes::TargetsPerEntry(codec) => { + codec.encode(writer, acl::TARGETS_PER_ENTRY as u16) + } + Attributes::EntriesPerFabric(codec) => { + codec.encode(writer, acl::ENTRIES_PER_FABRIC as u16) } } } @@ -365,7 +370,7 @@ mod tests { writebuf.as_slice() ); } - writebuf.reset(0); + writebuf.reset(); // Test 2, only single entry is read in the response with fabric filtering and fabric idx 1 { @@ -400,7 +405,7 @@ mod tests { writebuf.as_slice() ); } - writebuf.reset(0); + writebuf.reset(); // Test 3, only single entry is read in the response with fabric filtering and fabric idx 2 { diff --git a/matter/src/fabric.rs b/matter/src/fabric.rs index 42c55fd..b7e2425 100644 --- a/matter/src/fabric.rs +++ b/matter/src/fabric.rs @@ -495,7 +495,11 @@ impl FabricMgr { } #[cfg(feature = "nightly")] - pub async fn load_async(&mut self, mut psm: T, mdns_mgr: &mut MdnsMgr) -> Result<(), Error> + pub async fn load_async( + &mut self, + mut psm: T, + mdns_mgr: &mut MdnsMgr<'_>, + ) -> Result<(), Error> where T: crate::persist::asynch::AsyncPsm, { diff --git a/matter/src/interaction_model/core.rs b/matter/src/interaction_model/core.rs index 935740e..162d64c 100644 --- a/matter/src/interaction_model/core.rs +++ b/matter/src/interaction_model/core.rs @@ -21,14 +21,23 @@ use crate::{ data_model::core::DataHandler, error::*, tlv::{get_root_node_struct, print_tlv_list, FromTLV, TLVElement, TLVWriter, TagType, ToTLV}, - transport::{exchange::ExchangeCtx, packet::Packet, proto_ctx::ProtoCtx, session::Session}, + transport::{ + exchange::{Exchange, ExchangeCtx}, + packet::Packet, + proto_ctx::ProtoCtx, + session::Session, + }, }; use colored::Colorize; use log::{error, info}; use num; use num_derive::FromPrimitive; -use super::messages::msg::{self, InvReq, ReadReq, StatusResp, SubscribeReq, TimedReq, WriteReq}; +use super::messages::{ + ib::{AttrPath, DataVersionFilter}, + msg::{self, InvReq, ReadReq, StatusResp, SubscribeReq, TimedReq, WriteReq}, + GenericPath, +}; #[macro_export] macro_rules! cmd_enter { @@ -132,6 +141,14 @@ impl<'a, 'b> Transaction<'a, 'b> { } } + pub fn exch(&self) -> &Exchange { + self.ctx.exch + } + + pub fn exch_mut(&mut self) -> &mut Exchange { + self.ctx.exch + } + pub fn session(&self) -> &Session { self.ctx.sess.session() } @@ -182,17 +199,25 @@ impl<'a, 'b> Transaction<'a, 'b> { /* Interaction Model ID as per the Matter Spec */ const PROTO_ID_INTERACTION_MODEL: usize = 0x01; +const MAX_RESUME_PATHS: usize = 128; +const MAX_RESUME_DATAVER_FILTERS: usize = 128; + +// This is the amount of space we reserve for other things to be attached towards +// the end of long reads. +const LONG_READS_TLV_RESERVE_SIZE: usize = 24; + pub enum Interaction<'a> { Read(ReadReq<'a>), Write(WriteReq<'a>), Invoke(InvReq<'a>), Subscribe(SubscribeReq<'a>), - Status(StatusResp), Timed(TimedReq), + ResumeRead(ResumeReadReq), + ResumeSubscribe(ResumeSubscribeReq), } impl<'a> Interaction<'a> { - pub fn new(rx: &'a Packet) -> Result { + fn new(rx: &'a Packet, transaction: &mut Transaction) -> Result, Error> { let opcode: OpCode = num::FromPrimitive::from_u8(rx.get_proto_opcode()).ok_or(Error::Invalid)?; @@ -202,243 +227,67 @@ impl<'a> Interaction<'a> { print_tlv_list(rx_data); match opcode { - OpCode::ReadRequest => Ok(Self::Read(ReadReq::from_tlv(&get_root_node_struct( + OpCode::ReadRequest => Ok(Some(Self::Read(ReadReq::from_tlv(&get_root_node_struct( rx_data, - )?)?)), - OpCode::WriteRequest => Ok(Self::Write(WriteReq::from_tlv(&get_root_node_struct( - rx_data, - )?)?)), - OpCode::InvokeRequest => Ok(Self::Invoke(InvReq::from_tlv(&get_root_node_struct( - rx_data, - )?)?)), - OpCode::SubscribeRequest => Ok(Self::Subscribe(SubscribeReq::from_tlv( + )?)?))), + OpCode::WriteRequest => Ok(Some(Self::Write(WriteReq::from_tlv( &get_root_node_struct(rx_data)?, - )?)), - OpCode::StatusResponse => Ok(Self::Status(StatusResp::from_tlv( + )?))), + OpCode::InvokeRequest => Ok(Some(Self::Invoke(InvReq::from_tlv( &get_root_node_struct(rx_data)?, - )?)), - OpCode::TimedRequest => Ok(Self::Timed(TimedReq::from_tlv(&get_root_node_struct( - rx_data, - )?)?)), + )?))), + OpCode::SubscribeRequest => Ok(Some(Self::Subscribe(SubscribeReq::from_tlv( + &get_root_node_struct(rx_data)?, + )?))), + OpCode::StatusResponse => { + let resp = StatusResp::from_tlv(&get_root_node_struct(rx_data)?)?; + + if resp.status == IMStatusCode::Success { + if let Some(req) = transaction.exch_mut().take_suspended_read_req() { + Ok(Some(Self::ResumeRead(req))) + } else if let Some(req) = transaction.exch_mut().take_suspended_subscribe_req() + { + Ok(Some(Self::ResumeSubscribe(req))) + } else { + Ok(None) + } + } else { + Ok(None) + } + } + OpCode::TimedRequest => Ok(Some(Self::Timed(TimedReq::from_tlv( + &get_root_node_struct(rx_data)?, + )?))), _ => { - error!("Opcode Not Handled: {:?}", opcode); + error!("Opcode not handled: {:?}", opcode); Err(Error::InvalidOpcode) } } } - pub fn initiate_tx( - &self, + pub fn initiate( + rx: &'a Packet, tx: &mut Packet, transaction: &mut Transaction, - ) -> Result { - let reply = match self { - Self::Read(request) => { - tx.set_proto_id(PROTO_ID_INTERACTION_MODEL as u16); - tx.set_proto_opcode(OpCode::ReportData as u8); - - let mut tw = TLVWriter::new(tx.get_writebuf()?); - - tw.start_struct(TagType::Anonymous)?; - - if request.attr_requests.is_some() { - tw.start_array(TagType::Context(msg::ReportDataTag::AttributeReports as u8))?; - } - - false - } - Self::Write(_) => { - if transaction.has_timed_out() { - Self::create_status_response(tx, IMStatusCode::Timeout)?; - - transaction.complete(); - transaction.ctx.exch.close(); - - true - } else { - tx.set_proto_id(PROTO_ID_INTERACTION_MODEL as u16); - tx.set_proto_opcode(OpCode::WriteResponse as u8); - - let mut tw = TLVWriter::new(tx.get_writebuf()?); - - tw.start_struct(TagType::Anonymous)?; - tw.start_array(TagType::Context(msg::WriteRespTag::WriteResponses as u8))?; - + ) -> Result, Error> { + if let Some(interaction) = Self::new(rx, transaction)? { + let initiated = match &interaction { + Interaction::Read(req) => req.initiate(tx, transaction)?, + Interaction::Write(req) => req.initiate(tx, transaction)?, + Interaction::Invoke(req) => req.initiate(tx, transaction)?, + Interaction::Subscribe(req) => req.initiate(tx, transaction)?, + Interaction::Timed(req) => { + req.process(tx, transaction)?; false } - } - Self::Invoke(request) => { - if transaction.has_timed_out() { - Self::create_status_response(tx, IMStatusCode::Timeout)?; + Interaction::ResumeRead(req) => req.initiate(tx, transaction)?, + Interaction::ResumeSubscribe(req) => req.initiate(tx, transaction)?, + }; - transaction.complete(); - transaction.ctx.exch.close(); - - true - } else { - let timed_tx = transaction.get_timeout().map(|_| true); - let timed_request = request.timed_request.filter(|a| *a); - - // Either both should be None, or both should be Some(true) - if timed_tx != timed_request { - Self::create_status_response(tx, IMStatusCode::TimedRequestMisMatch)?; - - true - } else { - tx.set_proto_id(PROTO_ID_INTERACTION_MODEL as u16); - tx.set_proto_opcode(OpCode::InvokeResponse as u8); - - let mut tw = TLVWriter::new(tx.get_writebuf()?); - - tw.start_struct(TagType::Anonymous)?; - - // Suppress Response -> TODO: Need to revisit this for cases where we send a command back - tw.bool( - TagType::Context(msg::InvRespTag::SupressResponse as u8), - false, - )?; - - if request.inv_requests.is_some() { - tw.start_array(TagType::Context( - msg::InvRespTag::InvokeResponses as u8, - ))?; - } - - false - } - } - } - Self::Subscribe(request) => { - tx.set_proto_id(PROTO_ID_INTERACTION_MODEL as u16); - tx.set_proto_opcode(OpCode::ReportData as u8); - - let mut tw = TLVWriter::new(tx.get_writebuf()?); - - tw.start_struct(TagType::Anonymous)?; - - if request.attr_requests.is_some() { - tw.start_array(TagType::Context(msg::ReportDataTag::AttributeReports as u8))?; - } - - true - } - Self::Status(_) => { - tx.set_proto_id(PROTO_ID_INTERACTION_MODEL as u16); - tx.set_proto_opcode(OpCode::SubscribeResponse as u8); - - let mut tw = TLVWriter::new(tx.get_writebuf()?); - - tw.start_struct(TagType::Anonymous)?; - - true - } - Self::Timed(request) => { - tx.set_proto_id(PROTO_ID_INTERACTION_MODEL as u16); - tx.set_proto_opcode(OpCode::StatusResponse as u8); - - let mut tw = TLVWriter::new(tx.get_writebuf()?); - - transaction.set_timeout(request.timeout.into()); - - let status = StatusResp { - status: IMStatusCode::Success, - }; - - status.to_tlv(&mut tw, TagType::Anonymous)?; - - true - } - }; - - Ok(!reply) - } - - pub fn complete_tx( - &self, - tx: &mut Packet, - transaction: &mut Transaction, - ) -> Result { - let reply = match self { - Self::Read(request) => { - let mut tw = TLVWriter::new(tx.get_writebuf()?); - - if request.attr_requests.is_some() { - tw.end_container()?; - } - - // Suppress response always true for read interaction - tw.bool( - TagType::Context(msg::ReportDataTag::SupressResponse as u8), - true, - )?; - - tw.end_container()?; - - transaction.complete(); - - true - } - Self::Write(request) => { - let suppress = request.supress_response.unwrap_or_default(); - - let mut tw = TLVWriter::new(tx.get_writebuf()?); - - tw.end_container()?; - tw.end_container()?; - - transaction.complete(); - - if suppress { - error!("Supress response is set, is this the expected handling?"); - false - } else { - true - } - } - Self::Invoke(request) => { - let mut tw = TLVWriter::new(tx.get_writebuf()?); - - if request.inv_requests.is_some() { - tw.end_container()?; - } - - tw.end_container()?; - - true - } - Self::Subscribe(request) => { - let mut tw = TLVWriter::new(tx.get_writebuf()?); - - if request.attr_requests.is_some() { - tw.end_container()?; - } - - tw.end_container()?; - - true - } - Self::Status(_) => { - let mut tw = TLVWriter::new(tx.get_writebuf()?); - - tw.end_container()?; - - true - } - Self::Timed(_) => false, - }; - - if reply { - info!("Sending response"); - print_tlv_list(tx.as_slice()); + Ok(initiated.then_some(interaction)) + } else { + Ok(None) } - - if transaction.is_terminate() { - transaction.ctx.exch.terminate(); - } else if transaction.is_complete() { - transaction.ctx.exch.close(); - } - - Ok(true) } fn create_status_response(tx: &mut Packet, status: IMStatusCode) -> Result<(), Error> { @@ -452,6 +301,414 @@ impl<'a> Interaction<'a> { } } +impl<'a> ReadReq<'a> { + fn suspend(self, resume_path: GenericPath) -> ResumeReadReq { + ResumeReadReq { + paths: self + .attr_requests + .iter() + .flat_map(|attr_requests| attr_requests.iter()) + .collect(), + filters: self + .dataver_filters + .iter() + .flat_map(|filters| filters.iter()) + .collect(), + fabric_filtered: self.fabric_filtered, + resume_path, + } + } + + fn initiate(&self, tx: &mut Packet, _transaction: &mut Transaction) -> Result { + tx.set_proto_id(PROTO_ID_INTERACTION_MODEL as u16); + tx.set_proto_opcode(OpCode::ReportData as u8); + + let mut tw = Self::reserve_long_read_space(tx)?; + + tw.start_struct(TagType::Anonymous)?; + + if self.attr_requests.is_some() { + tw.start_array(TagType::Context(msg::ReportDataTag::AttributeReports as u8))?; + } + + Ok(true) + } + + pub fn complete( + self, + tx: &mut Packet, + transaction: &mut Transaction, + resume_path: Option, + ) -> Result { + let mut tw = Self::restore_long_read_space(tx)?; + + if self.attr_requests.is_some() { + tw.end_container()?; + } + + let more_chunks = if let Some(resume_path) = resume_path { + tw.bool( + TagType::Context(msg::ReportDataTag::MoreChunkedMsgs as u8), + true, + )?; + + transaction + .exch_mut() + .set_suspended_read_req(self.suspend(resume_path)); + true + } else { + false + }; + + tw.bool( + TagType::Context(msg::ReportDataTag::SupressResponse as u8), + !more_chunks, + )?; + + tw.end_container()?; + + if !more_chunks { + transaction.complete(); + } + + Ok(true) + } + + fn reserve_long_read_space<'p, 'b>(tx: &'p mut Packet<'b>) -> Result, Error> { + let wb = tx.get_writebuf()?; + wb.shrink(LONG_READS_TLV_RESERVE_SIZE)?; + + Ok(TLVWriter::new(wb)) + } + + fn restore_long_read_space<'p, 'b>(tx: &'p mut Packet<'b>) -> Result, Error> { + let wb = tx.get_writebuf()?; + wb.expand(LONG_READS_TLV_RESERVE_SIZE)?; + + Ok(TLVWriter::new(wb)) + } +} + +impl<'a> WriteReq<'a> { + fn initiate(&self, tx: &mut Packet, transaction: &mut Transaction) -> Result { + if transaction.has_timed_out() { + Interaction::create_status_response(tx, IMStatusCode::Timeout)?; + + transaction.complete(); + transaction.ctx.exch.close(); + + Ok(false) + } else { + tx.set_proto_id(PROTO_ID_INTERACTION_MODEL as u16); + tx.set_proto_opcode(OpCode::WriteResponse as u8); + + let mut tw = TLVWriter::new(tx.get_writebuf()?); + + tw.start_struct(TagType::Anonymous)?; + tw.start_array(TagType::Context(msg::WriteRespTag::WriteResponses as u8))?; + + Ok(true) + } + } + + pub fn complete(self, tx: &mut Packet, transaction: &mut Transaction) -> Result { + let suppress = self.supress_response.unwrap_or_default(); + + let mut tw = TLVWriter::new(tx.get_writebuf()?); + + tw.end_container()?; + tw.end_container()?; + + transaction.complete(); + + Ok(if suppress { + error!("Supress response is set, is this the expected handling?"); + false + } else { + true + }) + } +} + +impl<'a> InvReq<'a> { + fn initiate(&self, tx: &mut Packet, transaction: &mut Transaction) -> Result { + if transaction.has_timed_out() { + Interaction::create_status_response(tx, IMStatusCode::Timeout)?; + + transaction.complete(); + transaction.ctx.exch.close(); + + Ok(false) + } else { + let timed_tx = transaction.get_timeout().map(|_| true); + let timed_request = self.timed_request.filter(|a| *a); + + // Either both should be None, or both should be Some(true) + if timed_tx != timed_request { + Interaction::create_status_response(tx, IMStatusCode::TimedRequestMisMatch)?; + + Ok(false) + } else { + tx.set_proto_id(PROTO_ID_INTERACTION_MODEL as u16); + tx.set_proto_opcode(OpCode::InvokeResponse as u8); + + let mut tw = TLVWriter::new(tx.get_writebuf()?); + + tw.start_struct(TagType::Anonymous)?; + + // Suppress Response -> TODO: Need to revisit this for cases where we send a command back + tw.bool( + TagType::Context(msg::InvRespTag::SupressResponse as u8), + false, + )?; + + if self.inv_requests.is_some() { + tw.start_array(TagType::Context(msg::InvRespTag::InvokeResponses as u8))?; + } + + Ok(true) + } + } + } + + pub fn complete(self, tx: &mut Packet, _transaction: &mut Transaction) -> Result { + let mut tw = TLVWriter::new(tx.get_writebuf()?); + + if self.inv_requests.is_some() { + tw.end_container()?; + } + + tw.end_container()?; + + Ok(true) + } +} + +impl TimedReq { + pub fn process(&self, tx: &mut Packet, transaction: &mut Transaction) -> Result<(), Error> { + tx.set_proto_id(PROTO_ID_INTERACTION_MODEL as u16); + tx.set_proto_opcode(OpCode::StatusResponse as u8); + + let mut tw = TLVWriter::new(tx.get_writebuf()?); + + transaction.set_timeout(self.timeout.into()); + + let status = StatusResp { + status: IMStatusCode::Success, + }; + + status.to_tlv(&mut tw, TagType::Anonymous)?; + + Ok(()) + } +} + +impl<'a> SubscribeReq<'a> { + fn suspend(&self, resume_path: Option) -> ResumeSubscribeReq { + ResumeSubscribeReq { + paths: self + .attr_requests + .iter() + .flat_map(|attr_requests| attr_requests.iter()) + .collect(), + filters: self + .dataver_filters + .iter() + .flat_map(|filters| filters.iter()) + .collect(), + fabric_filtered: self.fabric_filtered, + resume_path, + keep_subs: self.keep_subs, + min_int_floor: self.min_int_floor, + max_int_ceil: self.max_int_ceil, + } + } + + fn initiate(&self, tx: &mut Packet, _transaction: &mut Transaction) -> Result { + tx.set_proto_id(PROTO_ID_INTERACTION_MODEL as u16); + tx.set_proto_opcode(OpCode::ReportData as u8); + + let mut tw = ReadReq::reserve_long_read_space(tx)?; + + tw.start_struct(TagType::Anonymous)?; + + if self.attr_requests.is_some() { + tw.start_array(TagType::Context(msg::ReportDataTag::AttributeReports as u8))?; + } + + Ok(true) + } + + pub fn complete( + self, + tx: &mut Packet, + transaction: &mut Transaction, + resume_path: Option, + ) -> Result { + let mut tw = ReadReq::restore_long_read_space(tx)?; + + if self.attr_requests.is_some() { + tw.end_container()?; + } + + if resume_path.is_some() { + tw.bool( + TagType::Context(msg::ReportDataTag::MoreChunkedMsgs as u8), + true, + )?; + } + + transaction + .exch_mut() + .set_suspended_subscribe_req(self.suspend(resume_path)); + + tw.bool( + TagType::Context(msg::ReportDataTag::SupressResponse as u8), + false, + )?; + + tw.end_container()?; + + Ok(true) + } +} + +pub struct ResumeReadReq { + pub paths: heapless::Vec, + pub filters: heapless::Vec, + pub fabric_filtered: bool, + pub resume_path: GenericPath, +} + +impl ResumeReadReq { + fn initiate(&self, tx: &mut Packet, _transaction: &mut Transaction) -> Result { + tx.set_proto_id(PROTO_ID_INTERACTION_MODEL as u16); + tx.set_proto_opcode(OpCode::ReportData as u8); + + let mut tw = ReadReq::reserve_long_read_space(tx)?; + + tw.start_struct(TagType::Anonymous)?; + + tw.start_array(TagType::Context(msg::ReportDataTag::AttributeReports as u8))?; + + Ok(true) + } + + pub fn complete( + mut self, + tx: &mut Packet, + transaction: &mut Transaction, + resume_path: Option, + ) -> Result { + let mut tw = ReadReq::restore_long_read_space(tx)?; + + tw.end_container()?; + + let continue_interaction = if let Some(resume_path) = resume_path { + tw.bool( + TagType::Context(msg::ReportDataTag::MoreChunkedMsgs as u8), + true, + )?; + + self.resume_path = resume_path; + transaction.exch_mut().set_suspended_read_req(self); + true + } else { + false + }; + + tw.bool( + TagType::Context(msg::ReportDataTag::SupressResponse as u8), + !continue_interaction, + )?; + + tw.end_container()?; + + if !continue_interaction { + transaction.complete(); + } + + Ok(true) + } +} + +pub struct ResumeSubscribeReq { + pub paths: heapless::Vec, + pub filters: heapless::Vec, + pub fabric_filtered: bool, + pub resume_path: Option, + pub keep_subs: bool, + pub min_int_floor: u16, + pub max_int_ceil: u16, +} + +impl ResumeSubscribeReq { + fn initiate(&self, tx: &mut Packet, _transaction: &mut Transaction) -> Result { + tx.set_proto_id(PROTO_ID_INTERACTION_MODEL as u16); + + if self.resume_path.is_some() { + tx.set_proto_opcode(OpCode::ReportData as u8); + + let mut tw = ReadReq::reserve_long_read_space(tx)?; + + tw.start_struct(TagType::Anonymous)?; + + tw.start_array(TagType::Context(msg::ReportDataTag::AttributeReports as u8))?; + } else { + tx.set_proto_opcode(OpCode::SubscribeResponse as u8); + + // let mut tw = TLVWriter::new(tx.get_writebuf()?); + // tw.start_struct(TagType::Anonymous)?; + } + + Ok(true) + } + + pub fn complete( + mut self, + tx: &mut Packet, + transaction: &mut Transaction, + resume_path: Option, + ) -> Result { + if self.resume_path.is_none() && resume_path.is_some() { + panic!("Cannot resume subscribe"); + } + + if self.resume_path.is_some() { + // Completing a ReportData message + let mut tw = ReadReq::restore_long_read_space(tx)?; + + tw.end_container()?; + + if resume_path.is_some() { + tw.bool( + TagType::Context(msg::ReportDataTag::MoreChunkedMsgs as u8), + true, + )?; + } + + tw.bool( + TagType::Context(msg::ReportDataTag::SupressResponse as u8), + false, + )?; + + tw.end_container()?; + + self.resume_path = resume_path; + transaction.exch_mut().set_suspended_subscribe_req(self); + } else { + // Completing a SubscribeResponse message + + // let mut tw = TLVWriter::new(tx.get_writebuf()?); + // tw.end_container()?; + + transaction.complete(); + } + + Ok(true) + } +} + pub trait InteractionHandler { fn handle<'a>(&mut self, ctx: &'a mut ProtoCtx) -> Result, Error>; } @@ -472,15 +729,14 @@ where T: DataHandler, { pub fn handle<'a>(&mut self, ctx: &'a mut ProtoCtx) -> Result, Error> { - let interaction = Interaction::new(ctx.rx)?; let mut transaction = Transaction::new(&mut ctx.exch_ctx); - let reply = if interaction.initiate_tx(ctx.tx, &mut transaction)? { - self.0.handle(&interaction, ctx.tx, &mut transaction)?; - interaction.complete_tx(ctx.tx, &mut transaction)? - } else { - true - }; + let reply = + if let Some(interaction) = Interaction::initiate(ctx.rx, ctx.tx, &mut transaction)? { + self.0.handle(interaction, ctx.tx, &mut transaction)? + } else { + true + }; Ok(reply.then_some(ctx.tx.as_slice())) } @@ -495,17 +751,14 @@ where &mut self, ctx: &'a mut ProtoCtx<'_, '_>, ) -> Result, Error> { - let interaction = Interaction::new(ctx.rx)?; let mut transaction = Transaction::new(&mut ctx.exch_ctx); - let reply = if interaction.initiate_tx(ctx.tx, &mut transaction)? { - self.0 - .handle(&interaction, ctx.tx, &mut transaction) - .await?; - interaction.complete_tx(ctx.tx, &mut transaction)? - } else { - true - }; + let reply = + if let Some(interaction) = Interaction::initiate(ctx.rx, ctx.tx, &mut transaction)? { + self.0.handle(interaction, ctx.tx, &mut transaction).await? + } else { + true + }; Ok(reply.then_some(ctx.tx.as_slice())) } diff --git a/matter/src/secure_channel/case.rs b/matter/src/secure_channel/case.rs index 80802ed..f5b9cb0 100644 --- a/matter/src/secure_channel/case.rs +++ b/matter/src/secure_channel/case.rs @@ -328,7 +328,7 @@ impl<'a> Case<'a> { tw.end_container()?; let key = KeyPair::new_from_public(initiator_noc_cert.get_pubkey())?; - key.verify_msg(write_buf.into_slice(), sign)?; + key.verify_msg(write_buf.as_slice(), sign)?; Ok(()) } @@ -508,7 +508,7 @@ impl<'a> Case<'a> { cipher_text, cipher_text.len() - TAG_LEN, )?; - Ok(write_buf.into_slice().len()) + Ok(write_buf.as_slice().len()) } fn get_sigma2_sign( @@ -531,7 +531,7 @@ impl<'a> Case<'a> { tw.str8(TagType::Context(4), peer_pub_key)?; tw.end_container()?; //println!("TBS is {:x?}", write_buf.as_borrow_slice()); - fabric.sign_msg(write_buf.into_slice(), signature) + fabric.sign_msg(write_buf.as_slice(), signature) } } diff --git a/matter/src/transport/exchange.rs b/matter/src/transport/exchange.rs index e668a8d..c28a5b2 100644 --- a/matter/src/transport/exchange.rs +++ b/matter/src/transport/exchange.rs @@ -22,6 +22,7 @@ use core::time::Duration; use log::{error, info, trace}; use crate::error::Error; +use crate::interaction_model::core::{ResumeReadReq, ResumeSubscribeReq}; use crate::secure_channel; use crate::secure_channel::case::CaseSession; use crate::utils::epoch::Epoch; @@ -68,6 +69,8 @@ enum State { pub enum DataOption { CaseSession(CaseSession), Time(Duration), + SuspendedReadReq(ResumeReadReq), + SuspendedSubscibeReq(ResumeSubscribeReq), #[default] None, } @@ -124,18 +127,14 @@ impl Exchange { self.role } - pub fn is_data_none(&self) -> bool { - matches!(self.data, DataOption::None) + pub fn clear_data(&mut self) { + self.data = DataOption::None; } pub fn set_case_session(&mut self, session: CaseSession) { self.data = DataOption::CaseSession(session); } - pub fn clear_data(&mut self) { - self.data = DataOption::None; - } - pub fn get_case_session(&mut self) -> Option<&mut CaseSession> { if let DataOption::CaseSession(session) = &mut self.data { Some(session) @@ -154,6 +153,34 @@ impl Exchange { } } + pub fn set_suspended_read_req(&mut self, req: ResumeReadReq) { + self.data = DataOption::SuspendedReadReq(req); + } + + pub fn take_suspended_read_req(&mut self) -> Option { + let old = core::mem::replace(&mut self.data, DataOption::None); + if let DataOption::SuspendedReadReq(req) = old { + Some(req) + } else { + self.data = old; + None + } + } + + pub fn set_suspended_subscribe_req(&mut self, req: ResumeSubscribeReq) { + self.data = DataOption::SuspendedSubscibeReq(req); + } + + pub fn take_suspended_subscribe_req(&mut self) -> Option { + let old = core::mem::replace(&mut self.data, DataOption::None); + if let DataOption::SuspendedSubscibeReq(req) = old { + Some(req) + } else { + self.data = old; + None + } + } + pub fn set_data_time(&mut self, expiry_ts: Option) { if let Some(t) = expiry_ts { self.data = DataOption::Time(t); @@ -430,7 +457,7 @@ mod tests { error::Error, transport::{ network::Address, - packet::Packet, + packet::{Packet, MAX_TX_BUF_SIZE}, session::{CloneData, SessionMode, MAX_SESSIONS}, }, utils::{ @@ -505,7 +532,7 @@ mod tests { /// - The sessions are evicted in LRU /// - The exchanges associated with those sessions are evicted too fn test_sess_evict() { - let mut mgr = ExchangeMgr::new(sys_epoch, dummy_rand); // TODO + let mut mgr = ExchangeMgr::new(sys_epoch, dummy_rand); fill_sessions(&mut mgr, MAX_SESSIONS + 1); // Sessions are now full from local session id 1 to 16 @@ -531,7 +558,7 @@ mod tests { let result = mgr.add_session(&get_clone_data(new_peer_sess_id, new_local_sess_id)); assert!(matches!(result, Err(Error::NoSpace))); - let mut buf = [0; 1500]; + let mut buf = [0; MAX_TX_BUF_SIZE]; let tx = &mut Packet::new_tx(&mut buf); let evicted = mgr.evict_session(tx).unwrap(); assert!(evicted); diff --git a/matter/src/transport/packet.rs b/matter/src/transport/packet.rs index e39ac1c..b2ca7aa 100644 --- a/matter/src/transport/packet.rs +++ b/matter/src/transport/packet.rs @@ -33,6 +33,7 @@ use super::{ }; pub const MAX_RX_BUF_SIZE: usize = 1583; +pub const MAX_TX_BUF_SIZE: usize = 1280 - 40/*IPV6 header size*/ - 8/*UDP header size*/; type Buffer = [u8; MAX_RX_BUF_SIZE]; // TODO: I am not very happy with this construction, need to find another way to do this diff --git a/matter/src/transport/proto_hdr.rs b/matter/src/transport/proto_hdr.rs index 96928ac..fd392bd 100644 --- a/matter/src/transport/proto_hdr.rs +++ b/matter/src/transport/proto_hdr.rs @@ -311,7 +311,7 @@ mod tests { encrypt_in_place(send_ctr, 0, &plain_hdr, &mut writebuf, &key).unwrap(); assert_eq!( - writebuf.into_slice(), + writebuf.as_slice(), [ 189, 83, 250, 121, 38, 87, 97, 17, 153, 78, 243, 20, 36, 11, 131, 142, 136, 165, 227, 107, 204, 129, 193, 153, 42, 131, 138, 254, 22, 190, 76, 244, 116, 45, 156, diff --git a/matter/src/transport/session.rs b/matter/src/transport/session.rs index d4c4985..95597e2 100644 --- a/matter/src/transport/session.rs +++ b/matter/src/transport/session.rs @@ -267,7 +267,7 @@ impl Session { let mut tmp_buf = [0_u8; proto_hdr::max_proto_hdr_len()]; let mut write_buf = WriteBuf::new(&mut tmp_buf); tx.proto.encode(&mut write_buf)?; - tx.get_writebuf()?.prepend(write_buf.into_slice())?; + tx.get_writebuf()?.prepend(write_buf.as_slice())?; // Generate plain-text header if self.mode == SessionMode::PlainText { @@ -278,7 +278,7 @@ impl Session { let mut tmp_buf = [0_u8; plain_hdr::max_plain_hdr_len()]; let mut write_buf = WriteBuf::new(&mut tmp_buf); tx.plain.encode(&mut write_buf)?; - let plain_hdr_bytes = write_buf.into_slice(); + let plain_hdr_bytes = write_buf.as_slice(); trace!("unencrypted packet: {:x?}", tx.as_mut_slice()); let ctr = tx.plain.ctr; diff --git a/matter/src/utils/writebuf.rs b/matter/src/utils/writebuf.rs index fae4481..3adafe2 100644 --- a/matter/src/utils/writebuf.rs +++ b/matter/src/utils/writebuf.rs @@ -18,44 +18,21 @@ use crate::error::*; use byteorder::{ByteOrder, LittleEndian}; -/// Shrink WriteBuf -/// -/// This Macro creates a new (child) WriteBuf which has a truncated slice end. -/// - It accepts a WriteBuf, and the size to reserve (truncate) towards the end. -/// - It returns the new (child) WriteBuf -#[macro_export] -macro_rules! wb_shrink { - ($orig_wb:ident, $reserve:ident) => {{ - let m_data = $orig_wb.empty_as_mut_slice(); - let m_wb = WriteBuf::new(m_data, m_data.len() - $reserve); - (m_wb) - }}; -} - -/// Unshrink WriteBuf -/// -/// This macro unshrinks the WriteBuf -/// - It accepts the original WriteBuf and the child WriteBuf (that was the result of wb_shrink) -/// After this call, the child WriteBuf shouldn't be used -#[macro_export] -macro_rules! wb_unshrink { - ($orig_wb:ident, $new_wb:ident) => {{ - let m_data_len = $new_wb.as_slice().len(); - $orig_wb.forward_tail_by(m_data_len); - }}; -} - #[derive(Debug)] pub struct WriteBuf<'a> { buf: &'a mut [u8], + buf_size: usize, start: usize, end: usize, } impl<'a> WriteBuf<'a> { pub fn new(buf: &'a mut [u8]) -> Self { + let buf_size = buf.len(); + Self { buf, + buf_size, start: 0, end: 0, } @@ -73,10 +50,6 @@ impl<'a> WriteBuf<'a> { self.end += new_offset } - pub fn into_slice(self) -> &'a [u8] { - &self.buf[self.start..self.end] - } - pub fn as_slice(&self) -> &[u8] { &self.buf[self.start..self.end] } @@ -86,20 +59,43 @@ impl<'a> WriteBuf<'a> { } pub fn empty_as_mut_slice(&mut self) -> &mut [u8] { - &mut self.buf[self.end..] + &mut self.buf[self.end..self.buf_size] } - pub fn reset(&mut self, reserve: usize) { - self.start = reserve; - self.end = reserve; + pub fn reset(&mut self) { + self.buf_size = self.buf.len(); + self.start = 0; + self.end = 0; } pub fn reserve(&mut self, reserve: usize) -> Result<(), Error> { - if self.end != 0 || self.start != 0 { - return Err(Error::Invalid); + if self.end != 0 || self.start != 0 || self.buf_size != self.buf.len() { + Err(Error::Invalid) + } else if reserve > self.buf_size { + Err(Error::NoSpace) + } else { + self.start = reserve; + self.end = reserve; + Ok(()) + } + } + + pub fn shrink(&mut self, with: usize) -> Result<(), Error> { + if self.end + with <= self.buf_size { + self.buf_size -= with; + Ok(()) + } else { + Err(Error::NoSpace) + } + } + + pub fn expand(&mut self, by: usize) -> Result<(), Error> { + if self.buf.len() - self.buf_size >= by { + self.buf_size += by; + Ok(()) + } else { + Err(Error::NoSpace) } - self.reset(reserve); - Ok(()) } pub fn prepend_with(&mut self, size: usize, f: F) -> Result<(), Error> @@ -125,7 +121,7 @@ impl<'a> WriteBuf<'a> { where F: FnOnce(&mut Self), { - if self.end + size <= self.buf.len() { + if self.end + size <= self.buf_size { f(self); self.end += size; return Ok(()); @@ -274,7 +270,7 @@ mod tests { buf.prepend(&new_slice).unwrap(); assert_eq!( - buf.into_slice(), + buf.as_slice(), [ 0xa, 0xb, 0xc, 1, 65, 0, 0xbe, 0xba, 0xfe, 0xca, 0xbe, 0xba, 0xfe, 0xca, 0xbe, 0xba, 0xfe, 0xca diff --git a/matter/tests/common/im_engine.rs b/matter/tests/common/im_engine.rs index 348ce74..116ad50 100644 --- a/matter/tests/common/im_engine.rs +++ b/matter/tests/common/im_engine.rs @@ -24,16 +24,20 @@ use matter::{ cluster_on_off::{self, OnOffCluster}, core::DataModel, device_types::{DEV_TYPE_ON_OFF_LIGHT, DEV_TYPE_ROOT_NODE}, - objects::{ChainedHandler, Endpoint, Node, Privilege}, + objects::{Endpoint, Node, Privilege}, root_endpoint::{self, RootEndpointHandler}, sdm::{ admin_commissioning, dev_att::{DataType, DevAttDataFetcher}, general_commissioning, noc, nw_commissioning, }, - system_model::access_control, + system_model::{ + access_control, + descriptor::{self, DescriptorCluster}, + }, }, error::Error, + handler_chain_type, interaction_model::core::{InteractionModel, OpCode}, mdns::Mdns, tlv::{TLVWriter, TagType, ToTLV}, @@ -41,6 +45,7 @@ use matter::{ transport::{ exchange::{self, Exchange, ExchangeCtx}, network::Address, + packet::MAX_RX_BUF_SIZE, proto_ctx::ProtoCtx, session::{CaseDetails, CloneData, NocCatIds, SessionMgr, SessionMode}, }, @@ -97,12 +102,9 @@ impl<'a> ImInput<'a> { } } -pub type DmHandler<'a> = ChainedHandler< - OnOffCluster, - ChainedHandler>>, ->; +pub type DmHandler<'a> = handler_chain_type!(OnOffCluster, EchoCluster, DescriptorCluster, EchoCluster | RootEndpointHandler<'a>); -pub fn matter<'a>(mdns: &'a mut dyn Mdns) -> Matter<'_> { +pub fn matter(mdns: &mut dyn Mdns) -> Matter<'_> { Matter::new(&BASIC_INFO, mdns, sys_epoch, dummy_rand) } @@ -132,6 +134,7 @@ impl<'a> ImEngine<'a> { Endpoint { id: 0, clusters: &[ + descriptor::CLUSTER, cluster_basic_information::CLUSTER, general_commissioning::CLUSTER, nw_commissioning::CLUSTER, @@ -144,13 +147,18 @@ impl<'a> ImEngine<'a> { }, Endpoint { id: 1, - clusters: &[echo_cluster::CLUSTER, cluster_on_off::CLUSTER], + clusters: &[ + descriptor::CLUSTER, + cluster_on_off::CLUSTER, + echo_cluster::CLUSTER, + ], device_type: DEV_TYPE_ON_OFF_LIGHT, }, ], }, root_endpoint::handler(0, &DummyDevAtt {}, matter) .chain(0, echo_cluster::ID, EchoCluster::new(2, *matter.borrow())) + .chain(1, descriptor::ID, DescriptorCluster::new(*matter.borrow())) .chain(1, echo_cluster::ID, EchoCluster::new(3, *matter.borrow())) .chain(1, cluster_on_off::ID, OnOffCluster::new(*matter.borrow())), ); @@ -164,7 +172,7 @@ impl<'a> ImEngine<'a> { pub fn echo_cluster(&self, endpoint: u16) -> &EchoCluster { match endpoint { - 0 => &self.im.0.handler.next.next.handler, + 0 => &self.im.0.handler.next.next.next.handler, 1 => &self.im.0.handler.next.handler, _ => panic!(), } @@ -196,8 +204,8 @@ impl<'a> ImEngine<'a> { sess, epoch: *self.matter.borrow(), }; - let mut tx_buf = [0; 1500]; - let mut rx_buf = [0; 1500]; + let mut rx_buf = [0; MAX_RX_BUF_SIZE]; + let mut tx_buf = [0; 1450]; // For the long read tests to run unchanged let mut rx = Packet::new_rx(&mut rx_buf); let mut tx = Packet::new_tx(&mut tx_buf); // Create fake rx packet diff --git a/matter/tests/data_model/long_reads.rs b/matter/tests/data_model/long_reads.rs index 9f7957a..693f1df 100644 --- a/matter/tests/data_model/long_reads.rs +++ b/matter/tests/data_model/long_reads.rs @@ -30,11 +30,13 @@ use matter::{ }, messages::{msg::SubscribeReq, GenericPath}, }, + mdns::DummyMdns, tlv::{self, ElementType, FromTLV, TLVElement, TagType, ToTLV}, transport::{ exchange::{self, Exchange}, udp::MAX_RX_BUF_SIZE, }, + Matter, }; use crate::{ @@ -42,28 +44,28 @@ use crate::{ common::{ attributes::*, echo_cluster as echo, - im_engine::{ImEngine, ImInput}, + im_engine::{matter, ImEngine, ImInput}, }, }; -pub struct LongRead { - im_engine: ImEngine, +pub struct LongRead<'a> { + im_engine: ImEngine<'a>, } -impl LongRead { - pub fn new() -> Self { - let mut im_engine = ImEngine::new(); +impl<'a> LongRead<'a> { + pub fn new(matter: &'a Matter<'a>) -> Self { + let mut im_engine = ImEngine::new(matter); // Use the same exchange for all parts of the transaction im_engine.exch = Some(Exchange::new(1, 0, exchange::Role::Responder)); Self { im_engine } } - pub fn process<'a>( + pub fn process<'p>( &mut self, action: OpCode, data: &dyn ToTLV, - data_out: &'a mut [u8], - ) -> (u8, &'a mut [u8]) { + data_out: &'p mut [u8], + ) -> (u8, &'p [u8]) { let input = ImInput::new(action, data); let (response, output) = self.im_engine.process(&input, data_out); (response, output) @@ -82,49 +84,139 @@ fn wildcard_read_resp(part: u8) -> Vec> { attr_data!(0, 29, descriptor::Attributes::ClientList, dont_care), attr_data!(0, 40, GlobalElements::FeatureMap, dont_care), attr_data!(0, 40, GlobalElements::AttributeList, dont_care), - attr_data!(0, 40, basic_info::Attributes::DMRevision, dont_care), - attr_data!(0, 40, basic_info::Attributes::VendorId, dont_care), - attr_data!(0, 40, basic_info::Attributes::ProductId, dont_care), - attr_data!(0, 40, basic_info::Attributes::HwVer, dont_care), - attr_data!(0, 40, basic_info::Attributes::SwVer, dont_care), - attr_data!(0, 40, basic_info::Attributes::SwVerString, dont_care), - attr_data!(0, 40, basic_info::Attributes::SerialNo, dont_care), + attr_data!( + 0, + 40, + basic_info::AttributesDiscriminants::DMRevision, + dont_care + ), + attr_data!( + 0, + 40, + basic_info::AttributesDiscriminants::VendorId, + dont_care + ), + attr_data!( + 0, + 40, + basic_info::AttributesDiscriminants::ProductId, + dont_care + ), + attr_data!(0, 40, basic_info::AttributesDiscriminants::HwVer, dont_care), + attr_data!(0, 40, basic_info::AttributesDiscriminants::SwVer, dont_care), + attr_data!( + 0, + 40, + basic_info::AttributesDiscriminants::SwVerString, + dont_care + ), + attr_data!( + 0, + 40, + basic_info::AttributesDiscriminants::SerialNo, + dont_care + ), attr_data!(0, 48, GlobalElements::FeatureMap, dont_care), attr_data!(0, 48, GlobalElements::AttributeList, dont_care), - attr_data!(0, 48, gen_comm::Attributes::BreadCrumb, dont_care), - attr_data!(0, 48, gen_comm::Attributes::RegConfig, dont_care), - attr_data!(0, 48, gen_comm::Attributes::LocationCapability, dont_care), attr_data!( 0, 48, - gen_comm::Attributes::BasicCommissioningInfo, + gen_comm::AttributesDiscriminants::BreadCrumb, + dont_care + ), + attr_data!( + 0, + 48, + gen_comm::AttributesDiscriminants::RegConfig, + dont_care + ), + attr_data!( + 0, + 48, + gen_comm::AttributesDiscriminants::LocationCapability, + dont_care + ), + attr_data!( + 0, + 48, + gen_comm::AttributesDiscriminants::BasicCommissioningInfo, dont_care ), attr_data!(0, 49, GlobalElements::FeatureMap, dont_care), attr_data!(0, 49, GlobalElements::AttributeList, dont_care), attr_data!(0, 60, GlobalElements::FeatureMap, dont_care), attr_data!(0, 60, GlobalElements::AttributeList, dont_care), - attr_data!(0, 60, adm_comm::Attributes::WindowStatus, dont_care), - attr_data!(0, 60, adm_comm::Attributes::AdminFabricIndex, dont_care), - attr_data!(0, 60, adm_comm::Attributes::AdminVendorId, dont_care), + attr_data!( + 0, + 60, + adm_comm::AttributesDiscriminants::WindowStatus, + dont_care + ), + attr_data!( + 0, + 60, + adm_comm::AttributesDiscriminants::AdminFabricIndex, + dont_care + ), + attr_data!( + 0, + 60, + adm_comm::AttributesDiscriminants::AdminVendorId, + dont_care + ), attr_data!(0, 62, GlobalElements::FeatureMap, dont_care), attr_data!(0, 62, GlobalElements::AttributeList, dont_care), - attr_data!(0, 62, noc::Attributes::CurrentFabricIndex, dont_care), - attr_data!(0, 62, noc::Attributes::Fabrics, dont_care), - attr_data!(0, 62, noc::Attributes::SupportedFabrics, dont_care), - attr_data!(0, 62, noc::Attributes::CommissionedFabrics, dont_care), + attr_data!( + 0, + 62, + noc::AttributesDiscriminants::CurrentFabricIndex, + dont_care + ), + attr_data!(0, 62, noc::AttributesDiscriminants::Fabrics, dont_care), + attr_data!( + 0, + 62, + noc::AttributesDiscriminants::SupportedFabrics, + dont_care + ), + attr_data!( + 0, + 62, + noc::AttributesDiscriminants::CommissionedFabrics, + dont_care + ), attr_data!(0, 31, GlobalElements::FeatureMap, dont_care), attr_data!(0, 31, GlobalElements::AttributeList, dont_care), - attr_data!(0, 31, acl::Attributes::Acl, dont_care), - attr_data!(0, 31, acl::Attributes::Extension, dont_care), - attr_data!(0, 31, acl::Attributes::SubjectsPerEntry, dont_care), - attr_data!(0, 31, acl::Attributes::TargetsPerEntry, dont_care), - attr_data!(0, 31, acl::Attributes::EntriesPerFabric, dont_care), + attr_data!(0, 31, acl::AttributesDiscriminants::Acl, dont_care), + attr_data!(0, 31, acl::AttributesDiscriminants::Extension, dont_care), + attr_data!( + 0, + 31, + acl::AttributesDiscriminants::SubjectsPerEntry, + dont_care + ), + attr_data!( + 0, + 31, + acl::AttributesDiscriminants::TargetsPerEntry, + dont_care + ), + attr_data!( + 0, + 31, + acl::AttributesDiscriminants::EntriesPerFabric, + dont_care + ), attr_data!(0, echo::ID, GlobalElements::FeatureMap, dont_care), attr_data!(0, echo::ID, GlobalElements::AttributeList, dont_care), - attr_data!(0, echo::ID, echo::Attributes::Att1, dont_care), - attr_data!(0, echo::ID, echo::Attributes::Att2, dont_care), - attr_data!(0, echo::ID, echo::Attributes::AttCustom, dont_care), + attr_data!(0, echo::ID, echo::AttributesDiscriminants::Att1, dont_care), + attr_data!(0, echo::ID, echo::AttributesDiscriminants::Att2, dont_care), + attr_data!( + 0, + echo::ID, + echo::AttributesDiscriminants::AttCustom, + dont_care + ), attr_data!(1, 29, GlobalElements::FeatureMap, dont_care), attr_data!(1, 29, GlobalElements::AttributeList, dont_care), attr_data!(1, 29, descriptor::Attributes::DeviceTypeList, dont_care), @@ -136,12 +228,17 @@ fn wildcard_read_resp(part: u8) -> Vec> { attr_data!(1, 29, descriptor::Attributes::ClientList, dont_care), attr_data!(1, 6, GlobalElements::FeatureMap, dont_care), attr_data!(1, 6, GlobalElements::AttributeList, dont_care), - attr_data!(1, 6, onoff::Attributes::OnOff, dont_care), + attr_data!(1, 6, onoff::AttributesDiscriminants::OnOff, dont_care), attr_data!(1, echo::ID, GlobalElements::FeatureMap, dont_care), attr_data!(1, echo::ID, GlobalElements::AttributeList, dont_care), - attr_data!(1, echo::ID, echo::Attributes::Att1, dont_care), - attr_data!(1, echo::ID, echo::Attributes::Att2, dont_care), - attr_data!(1, echo::ID, echo::Attributes::AttCustom, dont_care), + attr_data!(1, echo::ID, echo::AttributesDiscriminants::Att1, dont_care), + attr_data!(1, echo::ID, echo::AttributesDiscriminants::Att2, dont_care), + attr_data!( + 1, + echo::ID, + echo::AttributesDiscriminants::AttCustom, + dont_care + ), ]; if part == 1 { @@ -155,7 +252,9 @@ fn wildcard_read_resp(part: u8) -> Vec> { fn test_long_read_success() { // Read the entire attribute database, which requires 2 reads to complete let _ = env_logger::try_init(); - let mut lr = LongRead::new(); + let mut mdns = DummyMdns; + let matter = matter(&mut mdns); + let mut lr = LongRead::new(&matter); let mut output = [0_u8; MAX_RX_BUF_SIZE + 100]; let wc_path = GenericPath::new(None, None, None); @@ -187,7 +286,9 @@ fn test_long_read_success() { fn test_long_read_subscription_success() { // Subscribe to the entire attribute database, which requires 2 reads to complete let _ = env_logger::try_init(); - let mut lr = LongRead::new(); + let mut mdns = DummyMdns; + let matter = matter(&mut mdns); + let mut lr = LongRead::new(&matter); let mut output = [0_u8; MAX_RX_BUF_SIZE + 100]; let wc_path = GenericPath::new(None, None, None); @@ -219,6 +320,6 @@ fn test_long_read_subscription_success() { tlv::print_tlv_list(out_data); let root = tlv::get_root_node_struct(out_data).unwrap(); let subs_resp = SubscribeResp::from_tlv(&root).unwrap(); - assert_eq!(out_code, OpCode::SubscriptResponse as u8); + assert_eq!(out_code, OpCode::SubscribeResponse as u8); assert_eq!(subs_resp.subs_id, 1); } diff --git a/matter/tests/data_model_tests.rs b/matter/tests/data_model_tests.rs index 803c4c5..392909f 100644 --- a/matter/tests/data_model_tests.rs +++ b/matter/tests/data_model_tests.rs @@ -22,6 +22,6 @@ mod data_model { mod attribute_lists; mod attributes; mod commands; - // TODO mod long_reads; + mod long_reads; mod timed_requests; } diff --git a/matter/tests/interaction_model.rs b/matter/tests/interaction_model.rs index 07d114e..b73ab46 100644 --- a/matter/tests/interaction_model.rs +++ b/matter/tests/interaction_model.rs @@ -25,6 +25,8 @@ use matter::transport::exchange::Exchange; use matter::transport::exchange::ExchangeCtx; use matter::transport::network::Address; use matter::transport::packet::Packet; +use matter::transport::packet::MAX_RX_BUF_SIZE; +use matter::transport::packet::MAX_TX_BUF_SIZE; use matter::transport::proto_ctx::ProtoCtx; use matter::transport::session::SessionMgr; use matter::utils::epoch::dummy_epoch; @@ -52,30 +54,27 @@ impl DataModel { impl DataHandler for DataModel { fn handle( &mut self, - interaction: &Interaction, + interaction: Interaction, _tx: &mut Packet, _transaction: &mut Transaction, ) -> Result { - match interaction { - Interaction::Invoke(req) => { - if let Some(inv_requests) = &req.inv_requests { - for i in inv_requests.iter() { - let data = if let Some(data) = i.data.unwrap_tlv() { - data - } else { - continue; - }; - let cmd_path_ib = i.path; - let mut common_data = &mut self.node; - common_data.endpoint = cmd_path_ib.path.endpoint.unwrap_or(1); - common_data.cluster = cmd_path_ib.path.cluster.unwrap_or(0); - common_data.command = cmd_path_ib.path.leaf.unwrap_or(0) as u16; - data.confirm_struct().unwrap(); - common_data.variable = data.find_tag(0).unwrap().u8().unwrap(); - } + if let Interaction::Invoke(req) = interaction { + if let Some(inv_requests) = &req.inv_requests { + for i in inv_requests.iter() { + let data = if let Some(data) = i.data.unwrap_tlv() { + data + } else { + continue; + }; + let cmd_path_ib = i.path; + let mut common_data = &mut self.node; + common_data.endpoint = cmd_path_ib.path.endpoint.unwrap_or(1); + common_data.cluster = cmd_path_ib.path.cluster.unwrap_or(0); + common_data.command = cmd_path_ib.path.leaf.unwrap_or(0) as u16; + data.confirm_struct().unwrap(); + common_data.variable = data.find_tag(0).unwrap().u8().unwrap(); } } - _ => (), } Ok(false) @@ -109,8 +108,8 @@ fn handle_data(action: OpCode, data_in: &[u8], data_out: &mut [u8]) -> (DataMode sess, epoch: dummy_epoch, }; - let mut rx_buf = [0; 1500]; - let mut tx_buf = [0; 1500]; + let mut rx_buf = [0; MAX_RX_BUF_SIZE]; + let mut tx_buf = [0; MAX_TX_BUF_SIZE]; let mut rx = Packet::new_rx(&mut rx_buf); let mut tx = Packet::new_tx(&mut tx_buf); // Create fake rx packet diff --git a/matter_macro_derive/src/lib.rs b/matter_macro_derive/src/lib.rs index 0fc358f..a1fc553 100644 --- a/matter_macro_derive/src/lib.rs +++ b/matter_macro_derive/src/lib.rs @@ -138,11 +138,20 @@ fn gen_totlv_for_struct( let expanded = quote! { impl #generics ToTLV for #struct_name #generics { fn to_tlv(&self, tw: &mut TLVWriter, tag_type: TagType) -> Result<(), Error> { - tw. #datatype (tag_type)?; - #( - self.#idents.to_tlv(tw, TagType::Context(#tags))?; - )* - tw.end_container() + let anchor = tw.get_tail(); + + if let Err(err) = (|| { + tw. #datatype (tag_type)?; + #( + self.#idents.to_tlv(tw, TagType::Context(#tags))?; + )* + tw.end_container() + })() { + tw.rewind_to(anchor); + Err(err) + } else { + Ok(()) + } } } }; @@ -179,17 +188,26 @@ fn gen_totlv_for_enum( } let expanded = quote! { - impl #generics ToTLV for #enum_name #generics { - fn to_tlv(&self, tw: &mut TLVWriter, tag_type: TagType) -> Result<(), Error> { - tw.start_struct(tag_type)?; - match self { - #( - Self::#variant_names(c) => { c.to_tlv(tw, TagType::Context(#tags))?; }, - )* - } - tw.end_container() - } - } + impl #generics ToTLV for #enum_name #generics { + fn to_tlv(&self, tw: &mut TLVWriter, tag_type: TagType) -> Result<(), Error> { + let anchor = tw.get_tail(); + + if let Err(err) = (|| { + tw.start_struct(tag_type)?; + match self { + #( + Self::#variant_names(c) => { c.to_tlv(tw, TagType::Context(#tags))?; }, + )* + } + tw.end_container() + })() { + tw.rewind_to(anchor); + Err(err) + } else { + Ok(()) + } + } + } }; // panic!("Expanded to {}", expanded);