diff --git a/examples/onoff_light/src/lib.rs b/examples/onoff_light/src/lib.rs index 43ca1b1..16264d0 100644 --- a/examples/onoff_light/src/lib.rs +++ b/examples/onoff_light/src/lib.rs @@ -15,4 +15,4 @@ * limitations under the License. */ -pub mod dev_att; +// TODO pub mod dev_att; diff --git a/examples/onoff_light/src/main.rs b/examples/onoff_light/src/main.rs index 1eb5d63..b2bc448 100644 --- a/examples/onoff_light/src/main.rs +++ b/examples/onoff_light/src/main.rs @@ -15,40 +15,41 @@ * limitations under the License. */ -mod dev_att; -use matter::core::{self, CommissioningData}; -use matter::data_model::cluster_basic_information::BasicInfoConfig; -use matter::data_model::device_types::device_type_add_on_off_light; -use matter::secure_channel::spake2p::VerifierData; +// TODO +// mod dev_att; +// use matter::core::{self, CommissioningData}; +// use matter::data_model::cluster_basic_information::BasicInfoConfig; +// use matter::data_model::device_types::device_type_add_on_off_light; +// use matter::secure_channel::spake2p::VerifierData; fn main() { - env_logger::init(); - let comm_data = CommissioningData { - // TODO: Hard-coded for now - verifier: VerifierData::new_with_pw(123456), - discriminator: 250, - }; + // env_logger::init(); + // let comm_data = CommissioningData { + // // TODO: Hard-coded for now + // verifier: VerifierData::new_with_pw(123456), + // discriminator: 250, + // }; - // vid/pid should match those in the DAC - let dev_info = BasicInfoConfig { - vid: 0xFFF1, - pid: 0x8000, - hw_ver: 2, - sw_ver: 1, - sw_ver_str: "1".to_string(), - serial_no: "aabbccdd".to_string(), - device_name: "OnOff Light".to_string(), - }; - let dev_att = Box::new(dev_att::HardCodedDevAtt::new()); + // // vid/pid should match those in the DAC + // let dev_info = BasicInfoConfig { + // vid: 0xFFF1, + // pid: 0x8000, + // hw_ver: 2, + // sw_ver: 1, + // sw_ver_str: "1".to_string(), + // serial_no: "aabbccdd".to_string(), + // device_name: "OnOff Light".to_string(), + // }; + // let dev_att = Box::new(dev_att::HardCodedDevAtt::new()); - let mut matter = core::Matter::new(dev_info, dev_att, comm_data).unwrap(); - let dm = matter.get_data_model(); - { - let mut node = dm.node.write().unwrap(); - let endpoint = device_type_add_on_off_light(&mut node).unwrap(); - println!("Added OnOff Light Device type at endpoint id: {}", endpoint); - println!("Data Model now is: {}", node); - } + // let mut matter = core::Matter::new(dev_info, dev_att, comm_data).unwrap(); + // let dm = matter.get_data_model(); + // { + // let mut node = dm.node.write().unwrap(); + // let endpoint = device_type_add_on_off_light(&mut node).unwrap(); + // println!("Added OnOff Light Device type at endpoint id: {}", endpoint); + // println!("Data Model now is: {}", node); + // } - matter.start_daemon().unwrap(); + // matter.start_daemon().unwrap(); } diff --git a/examples/speaker/src/lib.rs b/examples/speaker/src/lib.rs index 43ca1b1..16264d0 100644 --- a/examples/speaker/src/lib.rs +++ b/examples/speaker/src/lib.rs @@ -15,4 +15,4 @@ * limitations under the License. */ -pub mod dev_att; +// TODO pub mod dev_att; diff --git a/examples/speaker/src/main.rs b/examples/speaker/src/main.rs index de2a605..f3b3f7d 100644 --- a/examples/speaker/src/main.rs +++ b/examples/speaker/src/main.rs @@ -15,55 +15,56 @@ * limitations under the License. */ -mod dev_att; -use matter::core::{self, CommissioningData}; -use matter::data_model::cluster_basic_information::BasicInfoConfig; -use matter::data_model::cluster_media_playback::{Commands, MediaPlaybackCluster}; -use matter::data_model::device_types::DEV_TYPE_ON_SMART_SPEAKER; -use matter::secure_channel::spake2p::VerifierData; +// TODO +// mod dev_att; +// use matter::core::{self, CommissioningData}; +// use matter::data_model::cluster_basic_information::BasicInfoConfig; +// use matter::data_model::cluster_media_playback::{Commands, MediaPlaybackCluster}; +// use matter::data_model::device_types::DEV_TYPE_ON_SMART_SPEAKER; +// use matter::secure_channel::spake2p::VerifierData; fn main() { - env_logger::init(); - let comm_data = CommissioningData { - // TODO: Hard-coded for now - verifier: VerifierData::new_with_pw(123456), - discriminator: 250, - }; + // env_logger::init(); + // let comm_data = CommissioningData { + // // TODO: Hard-coded for now + // verifier: VerifierData::new_with_pw(123456), + // discriminator: 250, + // }; - // vid/pid should match those in the DAC - let dev_info = BasicInfoConfig { - vid: 0xFFF1, - pid: 0x8002, - hw_ver: 2, - sw_ver: 1, - sw_ver_str: "1".to_string(), - serial_no: "aabbccdd".to_string(), - device_name: "Smart Speaker".to_string(), - }; - let dev_att = Box::new(dev_att::HardCodedDevAtt::new()); + // // vid/pid should match those in the DAC + // let dev_info = BasicInfoConfig { + // vid: 0xFFF1, + // pid: 0x8002, + // hw_ver: 2, + // sw_ver: 1, + // sw_ver_str: "1".to_string(), + // serial_no: "aabbccdd".to_string(), + // device_name: "Smart Speaker".to_string(), + // }; + // let dev_att = Box::new(dev_att::HardCodedDevAtt::new()); - let mut matter = core::Matter::new(dev_info, dev_att, comm_data).unwrap(); - let dm = matter.get_data_model(); - { - let mut node = dm.node.write().unwrap(); + // let mut matter = core::Matter::new(dev_info, dev_att, comm_data).unwrap(); + // let dm = matter.get_data_model(); + // { + // let mut node = dm.node.write().unwrap(); - let endpoint_audio = node.add_endpoint(DEV_TYPE_ON_SMART_SPEAKER).unwrap(); - let mut media_playback_cluster = MediaPlaybackCluster::new().unwrap(); + // let endpoint_audio = node.add_endpoint(DEV_TYPE_ON_SMART_SPEAKER).unwrap(); + // let mut media_playback_cluster = MediaPlaybackCluster::new().unwrap(); - // Add some callbacks - let play_callback = Box::new(|| log::info!("Comamnd [Play] handled with callback.")); - let pause_callback = Box::new(|| log::info!("Comamnd [Pause] handled with callback.")); - let stop_callback = Box::new(|| log::info!("Comamnd [Stop] handled with callback.")); - let start_over_callback = - Box::new(|| log::info!("Comamnd [StartOver] handled with callback.")); - media_playback_cluster.add_callback(Commands::Play, play_callback); - media_playback_cluster.add_callback(Commands::Pause, pause_callback); - media_playback_cluster.add_callback(Commands::Stop, stop_callback); - media_playback_cluster.add_callback(Commands::StartOver, start_over_callback); + // // Add some callbacks + // let play_callback = Box::new(|| log::info!("Comamnd [Play] handled with callback.")); + // let pause_callback = Box::new(|| log::info!("Comamnd [Pause] handled with callback.")); + // let stop_callback = Box::new(|| log::info!("Comamnd [Stop] handled with callback.")); + // let start_over_callback = + // Box::new(|| log::info!("Comamnd [StartOver] handled with callback.")); + // media_playback_cluster.add_callback(Commands::Play, play_callback); + // media_playback_cluster.add_callback(Commands::Pause, pause_callback); + // media_playback_cluster.add_callback(Commands::Stop, stop_callback); + // media_playback_cluster.add_callback(Commands::StartOver, start_over_callback); - node.add_cluster(endpoint_audio, media_playback_cluster) - .unwrap(); - println!("Added Speaker type at endpoint id: {}", endpoint_audio) - } - matter.start_daemon().unwrap(); + // node.add_cluster(endpoint_audio, media_playback_cluster) + // .unwrap(); + // println!("Added Speaker type at endpoint id: {}", endpoint_audio) + // } + // matter.start_daemon().unwrap(); } diff --git a/examples/speaker/src/speaker.rs b/examples/speaker/src/speaker.rs new file mode 100644 index 0000000..de2a605 --- /dev/null +++ b/examples/speaker/src/speaker.rs @@ -0,0 +1,69 @@ +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +mod dev_att; +use matter::core::{self, CommissioningData}; +use matter::data_model::cluster_basic_information::BasicInfoConfig; +use matter::data_model::cluster_media_playback::{Commands, MediaPlaybackCluster}; +use matter::data_model::device_types::DEV_TYPE_ON_SMART_SPEAKER; +use matter::secure_channel::spake2p::VerifierData; + +fn main() { + env_logger::init(); + let comm_data = CommissioningData { + // TODO: Hard-coded for now + verifier: VerifierData::new_with_pw(123456), + discriminator: 250, + }; + + // vid/pid should match those in the DAC + let dev_info = BasicInfoConfig { + vid: 0xFFF1, + pid: 0x8002, + hw_ver: 2, + sw_ver: 1, + sw_ver_str: "1".to_string(), + serial_no: "aabbccdd".to_string(), + device_name: "Smart Speaker".to_string(), + }; + let dev_att = Box::new(dev_att::HardCodedDevAtt::new()); + + let mut matter = core::Matter::new(dev_info, dev_att, comm_data).unwrap(); + let dm = matter.get_data_model(); + { + let mut node = dm.node.write().unwrap(); + + let endpoint_audio = node.add_endpoint(DEV_TYPE_ON_SMART_SPEAKER).unwrap(); + let mut media_playback_cluster = MediaPlaybackCluster::new().unwrap(); + + // Add some callbacks + let play_callback = Box::new(|| log::info!("Comamnd [Play] handled with callback.")); + let pause_callback = Box::new(|| log::info!("Comamnd [Pause] handled with callback.")); + let stop_callback = Box::new(|| log::info!("Comamnd [Stop] handled with callback.")); + let start_over_callback = + Box::new(|| log::info!("Comamnd [StartOver] handled with callback.")); + media_playback_cluster.add_callback(Commands::Play, play_callback); + media_playback_cluster.add_callback(Commands::Pause, pause_callback); + media_playback_cluster.add_callback(Commands::Stop, stop_callback); + media_playback_cluster.add_callback(Commands::StartOver, start_over_callback); + + node.add_cluster(endpoint_audio, media_playback_cluster) + .unwrap(); + println!("Added Speaker type at endpoint id: {}", endpoint_audio) + } + matter.start_daemon().unwrap(); +} diff --git a/matter/Cargo.toml b/matter/Cargo.toml index 9de56a4..2769010 100644 --- a/matter/Cargo.toml +++ b/matter/Cargo.toml @@ -15,7 +15,9 @@ name = "matter" path = "src/lib.rs" [features] -default = ["crypto_mbedtls"] +default = ["std", "crypto_mbedtls"] +std = [] +nightly = [] crypto_openssl = ["openssl", "foreign-types", "hmac", "sha2"] crypto_mbedtls = ["mbedtls"] crypto_esp_mbedtls = ["esp-idf-sys"] @@ -34,7 +36,7 @@ num-traits = "0.2.15" log = { version = "0.4.17", features = ["max_level_debug", "release_max_level_debug"] } env_logger = { version = "0.10.0", default-features = false, features = [] } rand = "0.8.5" -esp-idf-sys = { version = "0.32", features = ["binstart"], optional = true } +esp-idf-sys = { version = "0.32", optional = true } subtle = "2.4.1" colored = "2.0.0" smol = "1.3.0" @@ -42,6 +44,7 @@ owning_ref = "0.4.1" safemem = "0.3.3" chrono = { version = "0.4.23", default-features = false, features = ["clock", "std"] } async-channel = "1.8" +strum = { version = "0.24", features = ["derive"], no-default-feature = true } # crypto openssl = { git = "https://github.com/sfackler/rust-openssl", optional = true } diff --git a/matter/src/acl.rs b/matter/src/acl.rs index 708ddee..8f965e1 100644 --- a/matter/src/acl.rs +++ b/matter/src/acl.rs @@ -15,19 +15,16 @@ * limitations under the License. */ -use std::{ - fmt::Display, - sync::{Arc, Mutex, MutexGuard, RwLock}, -}; +use core::{cell::RefCell, fmt::Display}; use crate::{ data_model::objects::{Access, ClusterId, EndptId, Privilege}, error::Error, fabric, interaction_model::messages::GenericPath, - sys::Psm, + persist::Psm, tlv::{FromTLV, TLVElement, TLVList, TLVWriter, TagType, ToTLV}, - transport::session::MAX_CAT_IDS_PER_NOC, + transport::session::{Session, SessionMode, MAX_CAT_IDS_PER_NOC}, utils::writebuf::WriteBuf, }; use log::error; @@ -160,7 +157,7 @@ impl Display for AccessorSubjects { } /// The Accessor Object -pub struct Accessor { +pub struct Accessor<'a> { /// The fabric index of the accessor pub fab_idx: u8, /// Accessor's subject: could be node-id, NoC CAT, group id @@ -168,15 +165,37 @@ pub struct Accessor { /// The Authmode of this session auth_mode: AuthMode, // TODO: Is this the right place for this though, or should we just use a global-acl-handle-get - acl_mgr: Arc, + acl_mgr: &'a RefCell, } -impl Accessor { - pub fn new( +impl<'a> Accessor<'a> { + pub fn for_session(session: &Session, acl_mgr: &'a RefCell) -> Self { + match session.get_session_mode() { + SessionMode::Case(c) => { + let mut subject = + AccessorSubjects::new(session.get_peer_node_id().unwrap_or_default()); + for i in c.cat_ids { + if i != 0 { + let _ = subject.add_catid(i); + } + } + Accessor::new(c.fab_idx, subject, AuthMode::Case, &acl_mgr) + } + SessionMode::Pase => { + Accessor::new(0, AccessorSubjects::new(1), AuthMode::Pase, &acl_mgr) + } + + SessionMode::PlainText => { + Accessor::new(0, AccessorSubjects::new(1), AuthMode::Invalid, &acl_mgr) + } + } + } + + pub const fn new( fab_idx: u8, subjects: AccessorSubjects, auth_mode: AuthMode, - acl_mgr: Arc, + acl_mgr: &'a RefCell, ) -> Self { Self { fab_idx, @@ -188,9 +207,9 @@ impl Accessor { } #[derive(Debug)] -pub struct AccessDesc<'a> { +pub struct AccessDesc { /// The object to be acted upon - path: &'a GenericPath, + path: GenericPath, /// The target permissions target_perms: Option, // The operation being done @@ -200,8 +219,8 @@ pub struct AccessDesc<'a> { /// Access Request Object pub struct AccessReq<'a> { - accessor: &'a Accessor, - object: AccessDesc<'a>, + accessor: &'a Accessor<'a>, + object: AccessDesc, } impl<'a> AccessReq<'a> { @@ -209,7 +228,7 @@ impl<'a> AccessReq<'a> { /// /// An access request specifies the _accessor_ attempting to access _path_ /// with _operation_ - pub fn new(accessor: &'a Accessor, path: &'a GenericPath, operation: Access) -> Self { + pub fn new(accessor: &'a Accessor, path: GenericPath, operation: Access) -> Self { AccessReq { accessor, object: AccessDesc { @@ -220,6 +239,10 @@ impl<'a> AccessReq<'a> { } } + pub fn operation(&self) -> Access { + self.object.operation + } + /// Add target's permissions to the request /// /// The permissions that are associated with the target (identified by the @@ -234,7 +257,7 @@ impl<'a> AccessReq<'a> { /// _accessor_ the necessary privileges to access the target as per its /// permissions pub fn allow(&self) -> bool { - self.accessor.acl_mgr.allow(self) + self.accessor.acl_mgr.borrow().allow(self) } } @@ -369,33 +392,184 @@ impl AclEntry { const MAX_ACL_ENTRIES: usize = ENTRIES_PER_FABRIC * fabric::MAX_SUPPORTED_FABRICS; type AclEntries = [Option; MAX_ACL_ENTRIES]; -#[derive(ToTLV, FromTLV, Debug)] -struct AclMgrInner { - entries: AclEntries, -} - const ACL_KV_ENTRY: &str = "acl"; const ACL_KV_MAX_SIZE: usize = 300; -impl AclMgrInner { - pub fn store(&self, psm: &MutexGuard) -> Result<(), Error> { - let mut acl_tlvs = [0u8; ACL_KV_MAX_SIZE]; - let mut wb = WriteBuf::new(&mut acl_tlvs, ACL_KV_MAX_SIZE); - let mut tw = TLVWriter::new(&mut wb); - self.entries.to_tlv(&mut tw, TagType::Anonymous)?; - psm.set_kv_slice(ACL_KV_ENTRY, wb.as_slice()) + +pub struct AclMgr { + entries: AclEntries, + changed: bool, +} + +impl AclMgr { + pub const fn new() -> Self { + const INIT: Option = None; + + Self { + entries: [INIT; MAX_ACL_ENTRIES], + changed: false, + } } - pub fn load(psm: &MutexGuard) -> Result { - let mut acl_tlvs = Vec::new(); - psm.get_kv_slice(ACL_KV_ENTRY, &mut acl_tlvs)?; + pub fn erase_all(&mut self) -> Result<(), Error> { + for i in 0..MAX_ACL_ENTRIES { + self.entries[i] = None; + } + + self.changed = true; + + Ok(()) + } + + pub fn add(&mut self, entry: AclEntry) -> Result<(), Error> { + let cnt = self + .entries + .iter() + .flatten() + .filter(|a| a.fab_idx == entry.fab_idx) + .count(); + if cnt >= ENTRIES_PER_FABRIC { + return Err(Error::NoSpace); + } + let index = self + .entries + .iter() + .position(|a| a.is_none()) + .ok_or(Error::NoSpace)?; + self.entries[index] = Some(entry); + + self.changed = true; + + Ok(()) + } + + // Since the entries are fabric-scoped, the index is only for entries with the matching fabric index + pub fn edit(&mut self, index: u8, fab_idx: u8, new: AclEntry) -> Result<(), Error> { + let old = self.for_index_in_fabric(index, fab_idx)?; + *old = Some(new); + + self.changed = true; + + Ok(()) + } + + pub fn delete(&mut self, index: u8, fab_idx: u8) -> Result<(), Error> { + let old = self.for_index_in_fabric(index, fab_idx)?; + *old = None; + + self.changed = true; + + Ok(()) + } + + pub fn delete_for_fabric(&mut self, fab_idx: u8) -> Result<(), Error> { + for i in 0..MAX_ACL_ENTRIES { + if self.entries[i] + .filter(|e| e.fab_idx == Some(fab_idx)) + .is_some() + { + self.entries[i] = None; + } + } + + self.changed = true; + + Ok(()) + } + + pub fn for_each_acl(&self, mut f: T) -> Result<(), Error> + where + T: FnMut(&AclEntry) -> Result<(), Error>, + { + for entry in self.entries.iter().flatten() { + f(entry)?; + } + + Ok(()) + } + + pub fn allow(&self, req: &AccessReq) -> bool { + // PASE Sessions have implicit access grant + if req.accessor.auth_mode == AuthMode::Pase { + return true; + } + for e in self.entries.iter().flatten() { + if e.allow(req) { + return true; + } + } + error!( + "ACL Disallow for subjects {} fab idx {}", + req.accessor.subjects, req.accessor.fab_idx + ); + error!("{}", self); + false + } + + pub fn store(&mut self, mut psm: T) -> Result<(), Error> + where + T: Psm, + { + if self.changed { + let mut buf = [0u8; ACL_KV_MAX_SIZE]; + let mut wb = WriteBuf::new(&mut buf); + let mut tw = TLVWriter::new(&mut wb); + self.entries.to_tlv(&mut tw, TagType::Anonymous)?; + psm.set_kv_slice(ACL_KV_ENTRY, wb.into_slice())?; + + self.changed = false; + } + + Ok(()) + } + + pub fn load(&mut self, psm: T) -> Result<(), Error> + where + T: Psm, + { + let mut buf = [0u8; ACL_KV_MAX_SIZE]; + let acl_tlvs = psm.get_kv_slice(ACL_KV_ENTRY, &mut buf)?; + let root = TLVList::new(acl_tlvs).iter().next().ok_or(Error::Invalid)?; + + self.entries = AclEntries::from_tlv(&root)?; + self.changed = false; + + Ok(()) + } + + #[cfg(feature = "nightly")] + pub async fn store_async(&mut self, mut psm: T) -> Result<(), Error> + where + T: crate::persist::asynch::AsyncPsm, + { + if self.changed { + let mut buf = [0u8; ACL_KV_MAX_SIZE]; + let mut wb = WriteBuf::new(&mut buf); + let mut tw = TLVWriter::new(&mut wb); + self.entries.to_tlv(&mut tw, TagType::Anonymous)?; + psm.set_kv_slice(ACL_KV_ENTRY, wb.into_slice()).await?; + + self.changed = false; + } + + Ok(()) + } + + #[cfg(feature = "nightly")] + pub async fn load_async(&mut self, psm: T) -> Result<(), Error> + where + T: crate::persist::asynch::AsyncPsm, + { + let mut buf = [0u8; ACL_KV_MAX_SIZE]; + let acl_tlvs = psm.get_kv_slice(ACL_KV_ENTRY, &mut buf).await?; let root = TLVList::new(&acl_tlvs) .iter() .next() .ok_or(Error::Invalid)?; - Ok(Self { - entries: AclEntries::from_tlv(&root)?, - }) + self.entries = AclEntries::from_tlv(&root)?; + self.changed = false; + + Ok(()) } /// Traverse fabric specific entries to find the index @@ -422,169 +596,10 @@ impl AclMgrInner { } } -pub struct AclMgr { - inner: RwLock, - // The Option<> is solely because test execution is faster - // Doing this here adds the least overhead during ACL verification - psm: Option>>, -} - -impl AclMgr { - pub fn new() -> Result { - AclMgr::new_with(true) - } - - pub fn new_with(psm_support: bool) -> Result { - const INIT: Option = None; - let mut psm = None; - - let inner = if !psm_support { - AclMgrInner { - entries: [INIT; MAX_ACL_ENTRIES], - } - } else { - let psm_handle = Psm::get()?; - let inner = { - let psm_lock = psm_handle.lock().unwrap(); - AclMgrInner::load(&psm_lock) - }; - - psm = Some(psm_handle); - inner.unwrap_or({ - // Error loading from PSM - AclMgrInner { - entries: [INIT; MAX_ACL_ENTRIES], - } - }) - }; - Ok(Self { - inner: RwLock::new(inner), - psm, - }) - } - - pub fn erase_all(&self) { - let mut inner = self.inner.write().unwrap(); - for i in 0..MAX_ACL_ENTRIES { - inner.entries[i] = None; - } - if let Some(psm) = self.psm.as_ref() { - let psm = psm.lock().unwrap(); - let _ = inner.store(&psm).map_err(|e| { - error!("Error in storing ACLs {}", e); - }); - } - } - - pub fn add(&self, entry: AclEntry) -> Result<(), Error> { - let mut inner = self.inner.write().unwrap(); - let cnt = inner - .entries - .iter() - .flatten() - .filter(|a| a.fab_idx == entry.fab_idx) - .count(); - if cnt >= ENTRIES_PER_FABRIC { - return Err(Error::NoSpace); - } - let index = inner - .entries - .iter() - .position(|a| a.is_none()) - .ok_or(Error::NoSpace)?; - inner.entries[index] = Some(entry); - - if let Some(psm) = self.psm.as_ref() { - let psm = psm.lock().unwrap(); - inner.store(&psm) - } else { - Ok(()) - } - } - - // Since the entries are fabric-scoped, the index is only for entries with the matching fabric index - pub fn edit(&self, index: u8, fab_idx: u8, new: AclEntry) -> Result<(), Error> { - let mut inner = self.inner.write().unwrap(); - let old = inner.for_index_in_fabric(index, fab_idx)?; - *old = Some(new); - - if let Some(psm) = self.psm.as_ref() { - let psm = psm.lock().unwrap(); - inner.store(&psm) - } else { - Ok(()) - } - } - - pub fn delete(&self, index: u8, fab_idx: u8) -> Result<(), Error> { - let mut inner = self.inner.write().unwrap(); - let old = inner.for_index_in_fabric(index, fab_idx)?; - *old = None; - - if let Some(psm) = self.psm.as_ref() { - let psm = psm.lock().unwrap(); - inner.store(&psm) - } else { - Ok(()) - } - } - - pub fn delete_for_fabric(&self, fab_idx: u8) -> Result<(), Error> { - let mut inner = self.inner.write().unwrap(); - - for i in 0..MAX_ACL_ENTRIES { - if inner.entries[i] - .filter(|e| e.fab_idx == Some(fab_idx)) - .is_some() - { - inner.entries[i] = None; - } - } - - if let Some(psm) = self.psm.as_ref() { - let psm = psm.lock().unwrap(); - inner.store(&psm) - } else { - Ok(()) - } - } - - pub fn for_each_acl(&self, mut f: T) -> Result<(), Error> - where - T: FnMut(&AclEntry), - { - let inner = self.inner.read().unwrap(); - for entry in inner.entries.iter().flatten() { - f(entry) - } - Ok(()) - } - - pub fn allow(&self, req: &AccessReq) -> bool { - // PASE Sessions have implicit access grant - if req.accessor.auth_mode == AuthMode::Pase { - return true; - } - let inner = self.inner.read().unwrap(); - for e in inner.entries.iter().flatten() { - if e.allow(req) { - return true; - } - } - error!( - "ACL Disallow for subjects {} fab idx {}", - req.accessor.subjects, req.accessor.fab_idx - ); - error!("{}", self); - false - } -} - -impl std::fmt::Display for AclMgr { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let inner = self.inner.read().unwrap(); +impl core::fmt::Display for AclMgr { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { write!(f, "ACLS: [")?; - for i in inner.entries.iter().flatten() { + for i in self.entries.iter().flatten() { write!(f, " {{ {:?} }}, ", i)?; } write!(f, "]") @@ -594,22 +609,23 @@ impl std::fmt::Display for AclMgr { #[cfg(test)] #[allow(clippy::bool_assert_comparison)] mod tests { + use core::cell::RefCell; + use crate::{ acl::{gen_noc_cat, AccessorSubjects}, data_model::objects::{Access, Privilege}, interaction_model::messages::GenericPath, }; - use std::sync::Arc; use super::{AccessReq, Accessor, AclEntry, AclMgr, AuthMode, Target}; #[test] fn test_basic_empty_subject_target() { - let am = Arc::new(AclMgr::new_with(false).unwrap()); - am.erase_all(); - let accessor = Accessor::new(2, AccessorSubjects::new(112233), AuthMode::Case, am.clone()); + let am = RefCell::new(AclMgr::new()); + am.borrow_mut().erase_all().unwrap(); + let accessor = Accessor::new(2, AccessorSubjects::new(112233), AuthMode::Case, &am); let path = GenericPath::new(Some(1), Some(1234), None); - let mut req = AccessReq::new(&accessor, &path, Access::READ); + let mut req = AccessReq::new(&accessor, path, Access::READ); req.set_target_perms(Access::RWVA); // Default deny @@ -617,46 +633,46 @@ mod tests { // Deny for session mode mismatch let new = AclEntry::new(1, Privilege::VIEW, AuthMode::Pase); - am.add(new).unwrap(); + am.borrow_mut().add(new).unwrap(); assert_eq!(req.allow(), false); // Deny for fab idx mismatch let new = AclEntry::new(1, Privilege::VIEW, AuthMode::Case); - am.add(new).unwrap(); + am.borrow_mut().add(new).unwrap(); assert_eq!(req.allow(), false); // Allow let new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); - am.add(new).unwrap(); + am.borrow_mut().add(new).unwrap(); assert_eq!(req.allow(), true); } #[test] fn test_subject() { - let am = Arc::new(AclMgr::new_with(false).unwrap()); - am.erase_all(); - let accessor = Accessor::new(2, AccessorSubjects::new(112233), AuthMode::Case, am.clone()); + let am = RefCell::new(AclMgr::new()); + am.borrow_mut().erase_all().unwrap(); + let accessor = Accessor::new(2, AccessorSubjects::new(112233), AuthMode::Case, &am); let path = GenericPath::new(Some(1), Some(1234), None); - let mut req = AccessReq::new(&accessor, &path, Access::READ); + let mut req = AccessReq::new(&accessor, path, Access::READ); req.set_target_perms(Access::RWVA); // Deny for subject mismatch let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); new.add_subject(112232).unwrap(); - am.add(new).unwrap(); + am.borrow_mut().add(new).unwrap(); assert_eq!(req.allow(), false); // Allow for subject match - target is wildcard let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); new.add_subject(112233).unwrap(); - am.add(new).unwrap(); + am.borrow_mut().add(new).unwrap(); assert_eq!(req.allow(), true); } #[test] fn test_cat() { - let am = Arc::new(AclMgr::new_with(false).unwrap()); - am.erase_all(); + let am = RefCell::new(AclMgr::new()); + am.borrow_mut().erase_all().unwrap(); let allow_cat = 0xABCD; let disallow_cat = 0xCAFE; @@ -666,35 +682,35 @@ mod tests { let mut subjects = AccessorSubjects::new(112233); subjects.add_catid(gen_noc_cat(allow_cat, v2)).unwrap(); - let accessor = Accessor::new(2, subjects, AuthMode::Case, am.clone()); + let accessor = Accessor::new(2, subjects, AuthMode::Case, &am); let path = GenericPath::new(Some(1), Some(1234), None); - let mut req = AccessReq::new(&accessor, &path, Access::READ); + let mut req = AccessReq::new(&accessor, path, Access::READ); req.set_target_perms(Access::RWVA); // Deny for CAT id mismatch let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); new.add_subject_catid(gen_noc_cat(disallow_cat, v2)) .unwrap(); - am.add(new).unwrap(); + am.borrow_mut().add(new).unwrap(); assert_eq!(req.allow(), false); // Deny of CAT version mismatch let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); new.add_subject_catid(gen_noc_cat(allow_cat, v3)).unwrap(); - am.add(new).unwrap(); + am.borrow_mut().add(new).unwrap(); assert_eq!(req.allow(), false); // Allow for CAT match let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); new.add_subject_catid(gen_noc_cat(allow_cat, v2)).unwrap(); - am.add(new).unwrap(); + am.borrow_mut().add(new).unwrap(); assert_eq!(req.allow(), true); } #[test] fn test_cat_version() { - let am = Arc::new(AclMgr::new_with(false).unwrap()); - am.erase_all(); + let am = RefCell::new(AclMgr::new()); + am.borrow_mut().erase_all().unwrap(); let allow_cat = 0xABCD; let disallow_cat = 0xCAFE; @@ -704,32 +720,32 @@ mod tests { let mut subjects = AccessorSubjects::new(112233); subjects.add_catid(gen_noc_cat(allow_cat, v3)).unwrap(); - let accessor = Accessor::new(2, subjects, AuthMode::Case, am.clone()); + let accessor = Accessor::new(2, subjects, AuthMode::Case, &am); let path = GenericPath::new(Some(1), Some(1234), None); - let mut req = AccessReq::new(&accessor, &path, Access::READ); + let mut req = AccessReq::new(&accessor, path, Access::READ); req.set_target_perms(Access::RWVA); // Deny for CAT id mismatch let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); new.add_subject_catid(gen_noc_cat(disallow_cat, v2)) .unwrap(); - am.add(new).unwrap(); + am.borrow_mut().add(new).unwrap(); assert_eq!(req.allow(), false); // Allow for CAT match and version more than ACL version let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); new.add_subject_catid(gen_noc_cat(allow_cat, v2)).unwrap(); - am.add(new).unwrap(); + am.borrow_mut().add(new).unwrap(); assert_eq!(req.allow(), true); } #[test] fn test_target() { - let am = Arc::new(AclMgr::new_with(false).unwrap()); - am.erase_all(); - let accessor = Accessor::new(2, AccessorSubjects::new(112233), AuthMode::Case, am.clone()); + let am = RefCell::new(AclMgr::new()); + am.borrow_mut().erase_all().unwrap(); + let accessor = Accessor::new(2, AccessorSubjects::new(112233), AuthMode::Case, &am); let path = GenericPath::new(Some(1), Some(1234), None); - let mut req = AccessReq::new(&accessor, &path, Access::READ); + let mut req = AccessReq::new(&accessor, path, Access::READ); req.set_target_perms(Access::RWVA); // Deny for target mismatch @@ -740,7 +756,7 @@ mod tests { device_type: None, }) .unwrap(); - am.add(new).unwrap(); + am.borrow_mut().add(new).unwrap(); assert_eq!(req.allow(), false); // Allow for cluster match - subject wildcard @@ -751,11 +767,11 @@ mod tests { device_type: None, }) .unwrap(); - am.add(new).unwrap(); + am.borrow_mut().add(new).unwrap(); assert_eq!(req.allow(), true); // Clean Slate - am.erase_all(); + am.borrow_mut().erase_all().unwrap(); // Allow for endpoint match - subject wildcard let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); @@ -765,11 +781,11 @@ mod tests { device_type: None, }) .unwrap(); - am.add(new).unwrap(); + am.borrow_mut().add(new).unwrap(); assert_eq!(req.allow(), true); // Clean Slate - am.erase_all(); + am.borrow_mut().erase_all().unwrap(); // Allow for exact match let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); @@ -780,16 +796,15 @@ mod tests { }) .unwrap(); new.add_subject(112233).unwrap(); - am.add(new).unwrap(); + am.borrow_mut().add(new).unwrap(); assert_eq!(req.allow(), true); } #[test] fn test_privilege() { - let am = Arc::new(AclMgr::new_with(false).unwrap()); - am.erase_all(); - - let accessor = Accessor::new(2, AccessorSubjects::new(112233), AuthMode::Case, am.clone()); + let am = RefCell::new(AclMgr::new()); + am.borrow_mut().erase_all().unwrap(); + let accessor = Accessor::new(2, AccessorSubjects::new(112233), AuthMode::Case, &am); let path = GenericPath::new(Some(1), Some(1234), None); // Create an Exact Match ACL with View privilege @@ -801,10 +816,10 @@ mod tests { }) .unwrap(); new.add_subject(112233).unwrap(); - am.add(new).unwrap(); + am.borrow_mut().add(new).unwrap(); // Write on an RWVA without admin access - deny - let mut req = AccessReq::new(&accessor, &path, Access::WRITE); + let mut req = AccessReq::new(&accessor, path, Access::WRITE); req.set_target_perms(Access::RWVA); assert_eq!(req.allow(), false); @@ -817,40 +832,40 @@ mod tests { }) .unwrap(); new.add_subject(112233).unwrap(); - am.add(new).unwrap(); + am.borrow_mut().add(new).unwrap(); // Write on an RWVA with admin access - allow - let mut req = AccessReq::new(&accessor, &path, Access::WRITE); + let mut req = AccessReq::new(&accessor, path, Access::WRITE); req.set_target_perms(Access::RWVA); assert_eq!(req.allow(), true); } #[test] fn test_delete_for_fabric() { - let am = Arc::new(AclMgr::new_with(false).unwrap()); - am.erase_all(); + let am = RefCell::new(AclMgr::new()); + am.borrow_mut().erase_all().unwrap(); let path = GenericPath::new(Some(1), Some(1234), None); - let accessor2 = Accessor::new(2, AccessorSubjects::new(112233), AuthMode::Case, am.clone()); - let mut req2 = AccessReq::new(&accessor2, &path, Access::READ); + let accessor2 = Accessor::new(2, AccessorSubjects::new(112233), AuthMode::Case, &am); + let mut req2 = AccessReq::new(&accessor2, path, Access::READ); req2.set_target_perms(Access::RWVA); - let accessor3 = Accessor::new(3, AccessorSubjects::new(112233), AuthMode::Case, am.clone()); - let mut req3 = AccessReq::new(&accessor3, &path, Access::READ); + let accessor3 = Accessor::new(3, AccessorSubjects::new(112233), AuthMode::Case, &am); + let mut req3 = AccessReq::new(&accessor3, path, Access::READ); req3.set_target_perms(Access::RWVA); // Allow for subject match - target is wildcard - Fabric idx 2 let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); new.add_subject(112233).unwrap(); - am.add(new).unwrap(); + am.borrow_mut().add(new).unwrap(); // Allow for subject match - target is wildcard - Fabric idx 3 let mut new = AclEntry::new(3, Privilege::VIEW, AuthMode::Case); new.add_subject(112233).unwrap(); - am.add(new).unwrap(); + am.borrow_mut().add(new).unwrap(); // Req for Fabric idx 2 gets denied, and that for Fabric idx 3 is allowed assert_eq!(req2.allow(), true); assert_eq!(req3.allow(), true); - am.delete_for_fabric(2).unwrap(); + am.borrow_mut().delete_for_fabric(2).unwrap(); assert_eq!(req2.allow(), false); assert_eq!(req3.allow(), true); } diff --git a/matter/src/cert/asn1_writer.rs b/matter/src/cert/asn1_writer.rs index b3bb13c..ae2ced8 100644 --- a/matter/src/cert/asn1_writer.rs +++ b/matter/src/cert/asn1_writer.rs @@ -18,6 +18,7 @@ use super::{CertConsumer, MAX_DEPTH}; use crate::error::Error; use chrono::{Datelike, TimeZone, Utc}; +use core::fmt::Write; use log::warn; #[derive(Debug)] @@ -279,10 +280,12 @@ impl<'a> CertConsumer for ASN1Writer<'a> { if dt.year() >= 2050 { // If year is >= 2050, ASN.1 requires it to be Generalised Time - let time_str = format!("{}Z", dt.format("%Y%m%d%H%M%S")); + let mut time_str = heapless::String::<32>::new(); + write!(&mut time_str, "{}Z", dt.format("%Y%m%d%H%M%S")).unwrap(); self.write_str(0x18, time_str.as_bytes()) } else { - let time_str = format!("{}Z", dt.format("%y%m%d%H%M%S")); + let mut time_str = heapless::String::<32>::new(); + write!(&mut time_str, "{}Z", dt.format("%y%m%d%H%M%S")).unwrap(); self.write_str(0x17, time_str.as_bytes()) } } diff --git a/matter/src/cert/mod.rs b/matter/src/cert/mod.rs index 360ce31..664125b 100644 --- a/matter/src/cert/mod.rs +++ b/matter/src/cert/mod.rs @@ -15,14 +15,17 @@ * limitations under the License. */ -use std::fmt; +use core::fmt; + +extern crate alloc; use crate::{ - crypto::{CryptoKeyPair, KeyPair}, + crypto::KeyPair, error::Error, tlv::{self, FromTLV, TLVArrayOwned, TLVElement, TLVWriter, TagType, ToTLV}, utils::writebuf::WriteBuf, }; +use alloc::{format, string::String, vec::Vec}; use log::error; use num_derive::FromPrimitive; @@ -591,10 +594,10 @@ impl Cert { } pub fn as_tlv(&self, buf: &mut [u8]) -> Result { - let mut wb = WriteBuf::new(buf, buf.len()); + let mut wb = WriteBuf::new(buf); let mut tw = TLVWriter::new(&mut wb); self.to_tlv(&mut tw, TagType::Anonymous)?; - Ok(wb.as_slice().len()) + Ok(wb.into_slice().len()) } pub fn as_asn1(&self, buf: &mut [u8]) -> Result { @@ -731,6 +734,8 @@ mod printer; #[cfg(test)] mod tests { + use log::info; + use crate::cert::Cert; use crate::error::Error; use crate::tlv::{self, FromTLV, TLVWriter, TagType, ToTLV}; @@ -811,15 +816,14 @@ mod tests { ]; for input in test_input.iter() { - println!("Testing next input..."); + info!("Testing next input..."); let root = tlv::get_root_node(input).unwrap(); let cert = Cert::from_tlv(&root).unwrap(); let mut buf = [0u8; 1024]; - let buf_len = buf.len(); - let mut wb = WriteBuf::new(&mut buf, buf_len); + let mut wb = WriteBuf::new(&mut buf); let mut tw = TLVWriter::new(&mut wb); cert.to_tlv(&mut tw, TagType::Anonymous).unwrap(); - assert_eq!(*input, wb.as_slice()); + assert_eq!(*input, wb.into_slice()); } } diff --git a/matter/src/cert/printer.rs b/matter/src/cert/printer.rs index e92dbd4..b933607 100644 --- a/matter/src/cert/printer.rs +++ b/matter/src/cert/printer.rs @@ -18,8 +18,8 @@ use super::{CertConsumer, MAX_DEPTH}; use crate::error::Error; use chrono::{TimeZone, Utc}; +use core::fmt; use log::warn; -use std::fmt; pub struct CertPrinter<'a, 'b> { level: usize, diff --git a/matter/src/codec/base38.rs b/matter/src/codec/base38.rs index d4c69c1..7b7e758 100644 --- a/matter/src/codec/base38.rs +++ b/matter/src/codec/base38.rs @@ -17,6 +17,10 @@ //! Base38 encoding and decoding functions. +extern crate alloc; + +use alloc::{string::String, vec::Vec}; + use crate::error::Error; const BASE38_CHARS: [char; 38] = [ @@ -97,7 +101,7 @@ pub fn encode(bytes: &[u8], length: Option) -> String { while offset < length { let remaining = length - offset; match remaining.cmp(&2) { - std::cmp::Ordering::Greater => { + core::cmp::Ordering::Greater => { result.push_str(&encode_base38( ((bytes[offset + 2] as u32) << 16) | ((bytes[offset + 1] as u32) << 8) @@ -106,14 +110,14 @@ pub fn encode(bytes: &[u8], length: Option) -> String { )); offset += 3; } - std::cmp::Ordering::Equal => { + core::cmp::Ordering::Equal => { result.push_str(&encode_base38( ((bytes[offset + 1] as u32) << 8) | (bytes[offset] as u32), 4, )); break; } - std::cmp::Ordering::Less => { + core::cmp::Ordering::Less => { result.push_str(&encode_base38(bytes[offset] as u32, 2)); break; } diff --git a/matter/src/core.rs b/matter/src/core.rs index 9f1b13b..7b853b9 100644 --- a/matter/src/core.rs +++ b/matter/src/core.rs @@ -15,21 +15,19 @@ * limitations under the License. */ +use core::{borrow::Borrow, cell::RefCell}; + use crate::{ acl::AclMgr, - data_model::{ - cluster_basic_information::BasicInfoConfig, core::DataModel, - sdm::dev_att::DevAttDataFetcher, - }, + data_model::{cluster_basic_information::BasicInfoConfig, sdm::failsafe::FailSafe}, error::*, fabric::FabricMgr, - interaction_model::InteractionModel, - mdns::Mdns, + mdns::{Mdns, MdnsMgr}, pairing::{print_pairing_code_and_qr, DiscoveryCapabilities}, - secure_channel::{core::SecureChannel, pake::PaseMgr, spake2p::VerifierData}, - transport, + secure_channel::{pake::PaseMgr, spake2p::VerifierData}, + transport::udp::MATTER_PORT, + utils::{epoch::Epoch, rand::Rand}, }; -use std::sync::Arc; /// Device Commissioning Data pub struct CommissioningData { @@ -40,13 +38,18 @@ pub struct CommissioningData { } /// The primary Matter Object -pub struct Matter { - transport_mgr: transport::mgr::Mgr, - data_model: DataModel, - fabric_mgr: Arc, +pub struct Matter<'a> { + pub fabric_mgr: RefCell, + pub acl_mgr: RefCell, + pub pase_mgr: RefCell, + pub failsafe: RefCell, + pub mdns_mgr: RefCell>, + pub epoch: Epoch, + pub rand: Rand, + pub dev_det: &'a BasicInfoConfig<'a>, } -impl Matter { +impl<'a> Matter<'a> { /// Creates a new Matter object /// /// # Parameters @@ -54,57 +57,87 @@ impl Matter { /// requires a set of device attestation certificates and keys. It is the responsibility of /// this object to return the device attestation details when queried upon. pub fn new( - dev_det: BasicInfoConfig, - dev_att: Box, - dev_comm: CommissioningData, - ) -> Result, Error> { - let mdns = Mdns::get()?; - mdns.set_values(dev_det.vid, dev_det.pid, &dev_det.device_name); - - let fabric_mgr = Arc::new(FabricMgr::new()?); - let open_comm_window = fabric_mgr.is_empty(); - if open_comm_window { - print_pairing_code_and_qr(&dev_det, &dev_comm, DiscoveryCapabilities::default()); + dev_det: &'a BasicInfoConfig, + mdns: &'a mut dyn Mdns, + epoch: Epoch, + rand: Rand, + ) -> Self { + Self { + fabric_mgr: RefCell::new(FabricMgr::new()), + acl_mgr: RefCell::new(AclMgr::new()), + pase_mgr: RefCell::new(PaseMgr::new(epoch, rand)), + failsafe: RefCell::new(FailSafe::new()), + mdns_mgr: RefCell::new(MdnsMgr::new( + dev_det.vid, + dev_det.pid, + dev_det.device_name, + MATTER_PORT, + mdns, + )), + epoch, + rand, + dev_det, } - - let acl_mgr = Arc::new(AclMgr::new()?); - let mut pase = PaseMgr::new(); - let data_model = - DataModel::new(dev_det, dev_att, fabric_mgr.clone(), acl_mgr, pase.clone())?; - let mut matter = Box::new(Matter { - transport_mgr: transport::mgr::Mgr::new()?, - data_model, - fabric_mgr, - }); - let interaction_model = - Box::new(InteractionModel::new(Box::new(matter.data_model.clone()))); - matter.transport_mgr.register_protocol(interaction_model)?; - - if open_comm_window { - pase.enable_pase_session(dev_comm.verifier, dev_comm.discriminator)?; - } - - let secure_channel = Box::new(SecureChannel::new(pase, matter.fabric_mgr.clone())); - matter.transport_mgr.register_protocol(secure_channel)?; - Ok(matter) } - /// Returns an Arc to [DataModel] - /// - /// The Data Model is where you express what is the type of your device. Typically - /// once you gets this reference, you acquire the write lock and add your device - /// types, clusters, attributes, commands to the data model. - pub fn get_data_model(&self) -> DataModel { - self.data_model.clone() + pub fn dev_det(&self) -> &BasicInfoConfig { + self.dev_det } - /// Starts the Matter daemon - /// - /// This call does NOT return - /// - /// This call starts the Matter daemon that starts communication with other Matter - /// devices on the network. - pub fn start_daemon(&mut self) -> Result<(), Error> { - self.transport_mgr.start() + pub fn start(&mut self, dev_comm: CommissioningData) -> Result<(), Error> { + let open_comm_window = self.fabric_mgr.borrow().is_empty(); + if open_comm_window { + print_pairing_code_and_qr(self.dev_det, &dev_comm, DiscoveryCapabilities::default()); + + self.pase_mgr.borrow_mut().enable_pase_session( + dev_comm.verifier, + dev_comm.discriminator, + &mut self.mdns_mgr.borrow_mut(), + )?; + } + + Ok(()) + } +} + +impl<'a> Borrow> for Matter<'a> { + fn borrow(&self) -> &RefCell { + &self.fabric_mgr + } +} + +impl<'a> Borrow> for Matter<'a> { + fn borrow(&self) -> &RefCell { + &self.acl_mgr + } +} + +impl<'a> Borrow> for Matter<'a> { + fn borrow(&self) -> &RefCell { + &self.pase_mgr + } +} + +impl<'a> Borrow> for Matter<'a> { + fn borrow(&self) -> &RefCell { + &self.failsafe + } +} + +impl<'a> Borrow>> for Matter<'a> { + fn borrow(&self) -> &RefCell> { + &self.mdns_mgr + } +} + +impl<'a> Borrow for Matter<'a> { + fn borrow(&self) -> &Epoch { + &self.epoch + } +} + +impl<'a> Borrow for Matter<'a> { + fn borrow(&self) -> &Rand { + &self.rand } } diff --git a/matter/src/crypto/crypto_dummy.rs b/matter/src/crypto/crypto_dummy.rs index 80c1288..f193b20 100644 --- a/matter/src/crypto/crypto_dummy.rs +++ b/matter/src/crypto/crypto_dummy.rs @@ -19,39 +19,118 @@ use log::error; use crate::error::Error; -use super::CryptoKeyPair; +pub fn hkdf_sha256(_salt: &[u8], _ikm: &[u8], _info: &[u8], _key: &mut [u8]) -> Result<(), Error> { + error!("This API should never get called"); + Ok(()) +} -pub struct KeyPairDummy {} +#[derive(Clone)] +pub struct Sha256 {} -impl KeyPairDummy { +impl Sha256 { pub fn new() -> Result { Ok(Self {}) } + + pub fn update(&mut self, _data: &[u8]) -> Result<(), Error> { + Ok(()) + } + + pub fn finish(self, _digest: &mut [u8]) -> Result<(), Error> { + Ok(()) + } } -impl CryptoKeyPair for KeyPairDummy { - fn get_csr<'a>(&self, _out_csr: &'a mut [u8]) -> Result<&'a [u8], Error> { +pub struct HmacSha256 {} + +impl HmacSha256 { + pub fn new(_key: &[u8]) -> Result { + error!("This API should never get called"); + Ok(Self {}) + } + + pub fn update(&mut self, _data: &[u8]) -> Result<(), Error> { + error!("This API should never get called"); + Ok(()) + } + + pub fn finish(self, _out: &mut [u8]) -> Result<(), Error> { + error!("This API should never get called"); + Ok(()) + } +} + +pub struct KeyPair; + +impl KeyPair { + pub fn new() -> Result { + Ok(Self) + } + + pub fn new_from_components(_pub_key: &[u8], _priv_key: &[u8]) -> Result { + error!("This API should never get called"); + + Ok(Self {}) + } + + pub fn new_from_public(_pub_key: &[u8]) -> Result { + error!("This API should never get called"); + + Ok(Self {}) + } + + pub fn get_csr<'a>(&self, _out_csr: &'a mut [u8]) -> Result<&'a [u8], Error> { error!("This API should never get called"); Err(Error::Invalid) } - fn get_public_key(&self, _pub_key: &mut [u8]) -> Result { + + pub fn get_public_key(&self, _pub_key: &mut [u8]) -> Result { error!("This API should never get called"); Err(Error::Invalid) } - fn get_private_key(&self, _pub_key: &mut [u8]) -> Result { + + pub fn get_private_key(&self, _pub_key: &mut [u8]) -> Result { error!("This API should never get called"); Err(Error::Invalid) } - fn derive_secret(self, _peer_pub_key: &[u8], _secret: &mut [u8]) -> Result { + + pub fn derive_secret(self, _peer_pub_key: &[u8], _secret: &mut [u8]) -> Result { error!("This API should never get called"); Err(Error::Invalid) } - fn sign_msg(&self, _msg: &[u8], _signature: &mut [u8]) -> Result { + + pub fn sign_msg(&self, _msg: &[u8], _signature: &mut [u8]) -> Result { error!("This API should never get called"); Err(Error::Invalid) } - fn verify_msg(&self, _msg: &[u8], _signature: &[u8]) -> Result<(), Error> { + + pub fn verify_msg(&self, _msg: &[u8], _signature: &[u8]) -> Result<(), Error> { error!("This API should never get called"); Err(Error::Invalid) } } + +pub fn pbkdf2_hmac(_pass: &[u8], _iter: usize, _salt: &[u8], _key: &mut [u8]) -> Result<(), Error> { + error!("This API should never get called"); + + Ok(()) +} + +pub fn encrypt_in_place( + _key: &[u8], + _nonce: &[u8], + _ad: &[u8], + _data: &mut [u8], + _data_len: usize, +) -> Result { + Ok(0) +} + +pub fn decrypt_in_place( + _key: &[u8], + _nonce: &[u8], + _ad: &[u8], + _data: &mut [u8], +) -> Result { + Ok(0) +} diff --git a/matter/src/crypto/crypto_esp_mbedtls.rs b/matter/src/crypto/crypto_esp_mbedtls.rs index 9a8495d..fe72337 100644 --- a/matter/src/crypto/crypto_esp_mbedtls.rs +++ b/matter/src/crypto/crypto_esp_mbedtls.rs @@ -19,8 +19,6 @@ use log::error; use crate::error::Error; -use super::CryptoKeyPair; - pub fn hkdf_sha256(_salt: &[u8], _ikm: &[u8], _info: &[u8], _key: &mut [u8]) -> Result<(), Error> { error!("This API should never get called"); Ok(()) @@ -82,26 +80,28 @@ impl KeyPair { Ok(Self {}) } -} -impl CryptoKeyPair for KeyPair { - fn get_csr<'a>(&self, _out_csr: &'a mut [u8]) -> Result<&'a [u8], Error> { + pub fn get_csr<'a>(&self, _out_csr: &'a mut [u8]) -> Result<&'a [u8], Error> { error!("This API should never get called"); Err(Error::Invalid) } - fn get_public_key(&self, _pub_key: &mut [u8]) -> Result { + + pub fn get_public_key(&self, _pub_key: &mut [u8]) -> Result { error!("This API should never get called"); Err(Error::Invalid) } - fn derive_secret(self, _peer_pub_key: &[u8], _secret: &mut [u8]) -> Result { + + pub fn derive_secret(self, _peer_pub_key: &[u8], _secret: &mut [u8]) -> Result { error!("This API should never get called"); Err(Error::Invalid) } - fn sign_msg(&self, _msg: &[u8], _signature: &mut [u8]) -> Result { + + pub fn sign_msg(&self, _msg: &[u8], _signature: &mut [u8]) -> Result { error!("This API should never get called"); Err(Error::Invalid) } - fn verify_msg(&self, _msg: &[u8], _signature: &[u8]) -> Result<(), Error> { + + pub fn verify_msg(&self, _msg: &[u8], _signature: &[u8]) -> Result<(), Error> { error!("This API should never get called"); Err(Error::Invalid) } diff --git a/matter/src/crypto/crypto_mbedtls.rs b/matter/src/crypto/crypto_mbedtls.rs index e818583..2890fd1 100644 --- a/matter/src/crypto/crypto_mbedtls.rs +++ b/matter/src/crypto/crypto_mbedtls.rs @@ -15,9 +15,11 @@ * limitations under the License. */ -use std::sync::Arc; +extern crate alloc; -use log::error; +use alloc::sync::Arc; + +use log::{error, info}; use mbedtls::{ bignum::Mpi, cipher::{Authenticated, Cipher}, @@ -28,7 +30,6 @@ use mbedtls::{ x509, }; -use super::CryptoKeyPair; use crate::{ // TODO: We should move ASN1Writer out of Cert, // so Crypto doesn't have to depend on Cert @@ -85,10 +86,8 @@ impl KeyPair { key: Pk::public_from_ec_components(group, pub_key)?, }) } -} -impl CryptoKeyPair for KeyPair { - fn get_csr<'a>(&self, out_csr: &'a mut [u8]) -> Result<&'a [u8], Error> { + pub fn get_csr<'a>(&self, out_csr: &'a mut [u8]) -> Result<&'a [u8], Error> { let tmp_priv = self.key.ec_private()?; let mut tmp_key = Pk::private_from_ec_components(EcGroup::new(EcGroupId::SecP256R1)?, tmp_priv)?; @@ -112,7 +111,7 @@ impl CryptoKeyPair for KeyPair { } } - fn get_public_key(&self, pub_key: &mut [u8]) -> Result { + pub fn get_public_key(&self, pub_key: &mut [u8]) -> Result { let public_key = self.key.ec_public()?; let group = EcGroup::new(EcGroupId::SecP256R1)?; let vec = public_key.to_binary(&group, false)?; @@ -122,7 +121,7 @@ impl CryptoKeyPair for KeyPair { Ok(len) } - fn get_private_key(&self, priv_key: &mut [u8]) -> Result { + pub fn get_private_key(&self, priv_key: &mut [u8]) -> Result { let priv_key_mpi = self.key.ec_private()?; let vec = priv_key_mpi.to_binary()?; @@ -131,7 +130,7 @@ impl CryptoKeyPair for KeyPair { Ok(len) } - fn derive_secret(self, peer_pub_key: &[u8], secret: &mut [u8]) -> Result { + pub fn derive_secret(self, peer_pub_key: &[u8], secret: &mut [u8]) -> Result { // mbedtls requires a 'mut' key. Instead of making a change in our Trait, // we just clone the key this way @@ -149,7 +148,7 @@ impl CryptoKeyPair for KeyPair { Ok(len) } - fn sign_msg(&self, msg: &[u8], signature: &mut [u8]) -> Result { + pub fn sign_msg(&self, msg: &[u8], signature: &mut [u8]) -> Result { // mbedtls requires a 'mut' key. Instead of making a change in our Trait, // we just clone the key this way let tmp_key = self.key.ec_private()?; @@ -175,7 +174,7 @@ impl CryptoKeyPair for KeyPair { Ok(len) } - fn verify_msg(&self, msg: &[u8], signature: &[u8]) -> Result<(), Error> { + pub fn verify_msg(&self, msg: &[u8], signature: &[u8]) -> Result<(), Error> { // mbedtls requires a 'mut' key. Instead of making a change in our Trait, // we just clone the key this way let tmp_key = self.key.ec_public()?; @@ -192,7 +191,7 @@ impl CryptoKeyPair for KeyPair { let mbedtls_sign = &mbedtls_sign[..len]; if let Err(e) = tmp_key.verify(hash::Type::Sha256, &msg_hash, mbedtls_sign) { - println!("The error is {}", e); + info!("The error is {}", e); Err(Error::InvalidSignature) } else { Ok(()) diff --git a/matter/src/crypto/crypto_openssl.rs b/matter/src/crypto/crypto_openssl.rs index e8b67c1..e448619 100644 --- a/matter/src/crypto/crypto_openssl.rs +++ b/matter/src/crypto/crypto_openssl.rs @@ -17,7 +17,6 @@ use crate::error::Error; -use super::CryptoKeyPair; use foreign_types::ForeignTypeRef; use log::error; use openssl::asn1::Asn1Type; @@ -112,10 +111,8 @@ impl KeyPair { KeyType::Private(k) => Ok(&k), } } -} -impl CryptoKeyPair for KeyPair { - fn get_public_key(&self, pub_key: &mut [u8]) -> Result { + pub fn get_public_key(&self, pub_key: &mut [u8]) -> Result { let group = EcGroup::from_curve_name(Nid::X9_62_PRIME256V1)?; let mut bn_ctx = BigNumContext::new()?; let s = self.public_key_point().to_bytes( @@ -128,14 +125,14 @@ impl CryptoKeyPair for KeyPair { Ok(len) } - fn get_private_key(&self, priv_key: &mut [u8]) -> Result { + pub fn get_private_key(&self, priv_key: &mut [u8]) -> Result { let s = self.private_key()?.private_key().to_vec(); let len = s.len(); priv_key[..len].copy_from_slice(s.as_slice()); Ok(len) } - fn derive_secret(self, peer_pub_key: &[u8], secret: &mut [u8]) -> Result { + pub fn derive_secret(self, peer_pub_key: &[u8], secret: &mut [u8]) -> Result { let self_pkey = PKey::from_ec_key(self.private_key()?.clone())?; let group = EcGroup::from_curve_name(Nid::X9_62_PRIME256V1)?; @@ -149,7 +146,7 @@ impl CryptoKeyPair for KeyPair { Ok(deriver.derive(secret)?) } - fn get_csr<'a>(&self, out_csr: &'a mut [u8]) -> Result<&'a [u8], Error> { + pub fn get_csr<'a>(&self, out_csr: &'a mut [u8]) -> Result<&'a [u8], Error> { let mut builder = X509ReqBuilder::new()?; builder.set_version(0)?; @@ -174,7 +171,7 @@ impl CryptoKeyPair for KeyPair { } } - fn sign_msg(&self, msg: &[u8], signature: &mut [u8]) -> Result { + pub fn sign_msg(&self, msg: &[u8], signature: &mut [u8]) -> Result { // First get the SHA256 of the message let mut h = Hasher::new(MessageDigest::sha256())?; h.update(msg)?; @@ -193,7 +190,7 @@ impl CryptoKeyPair for KeyPair { Ok(64) } - fn verify_msg(&self, msg: &[u8], signature: &[u8]) -> Result<(), Error> { + pub fn verify_msg(&self, msg: &[u8], signature: &[u8]) -> Result<(), Error> { // First get the SHA256 of the message let mut h = Hasher::new(MessageDigest::sha256())?; h.update(msg)?; diff --git a/matter/src/crypto/mod.rs b/matter/src/crypto/mod.rs index 2473fb0..5c73ff2 100644 --- a/matter/src/crypto/mod.rs +++ b/matter/src/crypto/mod.rs @@ -15,8 +15,6 @@ * limitations under the License. */ -use crate::error::Error; - pub const SYMM_KEY_LEN_BITS: usize = 128; pub const SYMM_KEY_LEN_BYTES: usize = SYMM_KEY_LEN_BITS / 8; @@ -35,16 +33,6 @@ pub const ECDH_SHARED_SECRET_LEN_BYTES: usize = 32; pub const EC_SIGNATURE_LEN_BYTES: usize = 64; -// APIs particular to a KeyPair so a KeyPair object can be defined -pub trait CryptoKeyPair { - fn get_csr<'a>(&self, csr: &'a mut [u8]) -> Result<&'a [u8], Error>; - fn get_public_key(&self, pub_key: &mut [u8]) -> Result; - fn get_private_key(&self, priv_key: &mut [u8]) -> Result; - fn derive_secret(self, peer_pub_key: &[u8], secret: &mut [u8]) -> Result; - fn sign_msg(&self, msg: &[u8], signature: &mut [u8]) -> Result; - fn verify_msg(&self, msg: &[u8], signature: &[u8]) -> Result<(), Error>; -} - #[cfg(feature = "crypto_esp_mbedtls")] mod crypto_esp_mbedtls; #[cfg(feature = "crypto_esp_mbedtls")] @@ -65,13 +53,26 @@ mod crypto_rustcrypto; #[cfg(feature = "crypto_rustcrypto")] pub use self::crypto_rustcrypto::*; +#[cfg(not(any( + feature = "crypto_openssl", + feature = "crypto_mbedtls", + feature = "crypto_esp_mbedtls", + feature = "crypto_rustcrypto" +)))] pub mod crypto_dummy; +#[cfg(not(any( + feature = "crypto_openssl", + feature = "crypto_mbedtls", + feature = "crypto_esp_mbedtls", + feature = "crypto_rustcrypto" +)))] +pub use self::crypto_dummy::*; #[cfg(test)] mod tests { use crate::error::Error; - use super::{CryptoKeyPair, KeyPair}; + use super::KeyPair; #[test] fn test_verify_msg_success() { diff --git a/matter/src/data_model/cluster_basic_information.rs b/matter/src/data_model/cluster_basic_information.rs index 7c9cada..71c0722 100644 --- a/matter/src/data_model/cluster_basic_information.rs +++ b/matter/src/data_model/cluster_basic_information.rs @@ -15,100 +15,129 @@ * limitations under the License. */ +use core::convert::TryInto; + use super::objects::*; -use crate::error::*; -use num_derive::FromPrimitive; +use crate::{attribute_enum, error::Error, utils::rand::Rand}; +use strum::{EnumDiscriminants, FromRepr}; pub const ID: u32 = 0x0028; -#[derive(FromPrimitive)] +#[derive(Clone, Copy, Debug, FromRepr, EnumDiscriminants)] +#[repr(u16)] pub enum Attributes { - DMRevision = 0, - VendorId = 2, - ProductId = 4, - HwVer = 7, - SwVer = 9, - SwVerString = 0xa, - SerialNo = 0x0f, + DMRevision(AttrType) = 0, + VendorId(AttrType) = 2, + ProductId(AttrType) = 4, + HwVer(AttrType) = 7, + SwVer(AttrType) = 9, + SwVerString(AttrUtfType) = 0xa, + SerialNo(AttrUtfType) = 0x0f, } +attribute_enum!(Attributes); + #[derive(Default)] -pub struct BasicInfoConfig { +pub struct BasicInfoConfig<'a> { pub vid: u16, pub pid: u16, pub hw_ver: u16, pub sw_ver: u32, - pub sw_ver_str: String, - pub serial_no: String, + pub sw_ver_str: &'a str, + pub serial_no: &'a str, /// Device name; up to 32 characters - pub device_name: String, + pub device_name: &'a str, } -pub struct BasicInfoCluster { - base: Cluster, +pub const CLUSTER: Cluster<'static> = Cluster { + id: ID as _, + feature_map: 0, + attributes: &[ + FEATURE_MAP, + ATTRIBUTE_LIST, + Attribute::new( + AttributesDiscriminants::DMRevision as u16, + Access::RV, + Quality::FIXED, + ), + Attribute::new( + AttributesDiscriminants::VendorId as u16, + Access::RV, + Quality::FIXED, + ), + Attribute::new( + AttributesDiscriminants::ProductId as u16, + Access::RV, + Quality::FIXED, + ), + Attribute::new( + AttributesDiscriminants::HwVer as u16, + Access::RV, + Quality::FIXED, + ), + Attribute::new( + AttributesDiscriminants::SwVer as u16, + Access::RV, + Quality::FIXED, + ), + Attribute::new( + AttributesDiscriminants::SwVerString as u16, + Access::RV, + Quality::FIXED, + ), + Attribute::new( + AttributesDiscriminants::SerialNo as u16, + Access::RV, + Quality::FIXED, + ), + ], + commands: &[], +}; + +pub struct BasicInfoCluster<'a> { + data_ver: Dataver, + cfg: &'a BasicInfoConfig<'a>, } -impl BasicInfoCluster { - pub fn new(cfg: BasicInfoConfig) -> Result, Error> { - let mut cluster = Box::new(BasicInfoCluster { - base: Cluster::new(ID)?, - }); +impl<'a> BasicInfoCluster<'a> { + pub fn new(cfg: &'a BasicInfoConfig<'a>, rand: Rand) -> Self { + Self { + data_ver: Dataver::new(rand), + cfg, + } + } - let attrs = [ - Attribute::new( - Attributes::DMRevision as u16, - AttrValue::Uint8(1), - Access::RV, - Quality::FIXED, - ), - Attribute::new( - Attributes::VendorId as u16, - AttrValue::Uint16(cfg.vid), - Access::RV, - Quality::FIXED, - ), - Attribute::new( - Attributes::ProductId as u16, - AttrValue::Uint16(cfg.pid), - Access::RV, - Quality::FIXED, - ), - Attribute::new( - Attributes::HwVer as u16, - AttrValue::Uint16(cfg.hw_ver), - Access::RV, - Quality::FIXED, - ), - Attribute::new( - Attributes::SwVer as u16, - AttrValue::Uint32(cfg.sw_ver), - Access::RV, - Quality::FIXED, - ), - Attribute::new( - Attributes::SwVerString as u16, - AttrValue::Utf8(cfg.sw_ver_str), - Access::RV, - Quality::FIXED, - ), - Attribute::new( - Attributes::SerialNo as u16, - AttrValue::Utf8(cfg.serial_no), - Access::RV, - Quality::FIXED, - ), - ]; - cluster.base.add_attributes(&attrs[..])?; - - Ok(cluster) + pub fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> { + if let Some(writer) = encoder.with_dataver(self.data_ver.get())? { + if attr.is_system() { + CLUSTER.read(attr.attr_id, writer) + } else { + match attr.attr_id.try_into()? { + Attributes::DMRevision(codec) => codec.encode(writer, 1), + Attributes::VendorId(codec) => codec.encode(writer, self.cfg.vid), + Attributes::ProductId(codec) => codec.encode(writer, self.cfg.pid), + Attributes::HwVer(codec) => codec.encode(writer, self.cfg.hw_ver), + Attributes::SwVer(codec) => codec.encode(writer, self.cfg.sw_ver), + Attributes::SwVerString(codec) => codec.encode(writer, self.cfg.sw_ver_str), + Attributes::SerialNo(codec) => codec.encode(writer, self.cfg.serial_no), + } + } + } else { + Ok(()) + } } } -impl ClusterType for BasicInfoCluster { - fn base(&self) -> &Cluster { - &self.base - } - fn base_mut(&mut self) -> &mut Cluster { - &mut self.base +impl<'a> Handler for BasicInfoCluster<'a> { + fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> { + BasicInfoCluster::read(self, attr, encoder) + } +} + +impl<'a> NonBlockingHandler for BasicInfoCluster<'a> {} + +impl<'a> ChangeNotifier<()> for BasicInfoCluster<'a> { + fn consume_change(&mut self) -> Option<()> { + self.data_ver.consume_change(()) } } diff --git a/matter/src/data_model/cluster_on_off.rs b/matter/src/data_model/cluster_on_off.rs index 6864a68..9b17367 100644 --- a/matter/src/data_model/cluster_on_off.rs +++ b/matter/src/data_model/cluster_on_off.rs @@ -15,114 +15,153 @@ * limitations under the License. */ +use core::convert::TryInto; + use super::objects::*; use crate::{ - cmd_enter, - error::*, - interaction_model::{command::CommandReq, core::IMStatusCode}, + attribute_enum, cmd_enter, command_enum, error::Error, interaction_model::core::Transaction, + tlv::TLVElement, utils::rand::Rand, }; use log::info; -use num_derive::FromPrimitive; +use strum::{EnumDiscriminants, FromRepr}; pub const ID: u32 = 0x0006; +#[derive(FromRepr, EnumDiscriminants)] +#[repr(u16)] pub enum Attributes { - OnOff = 0x0, + OnOff(AttrType) = 0x0, } -#[derive(FromPrimitive)] +attribute_enum!(Attributes); + +#[derive(FromRepr, EnumDiscriminants)] +#[repr(u32)] pub enum Commands { Off = 0x0, On = 0x01, Toggle = 0x02, } -fn attr_on_off_new() -> Attribute { - // OnOff, Value: false - Attribute::new( - Attributes::OnOff as u16, - AttrValue::Bool(false), - Access::RV, - Quality::PERSISTENT, - ) -} +command_enum!(Commands); + +pub const CLUSTER: Cluster<'static> = Cluster { + id: ID as _, + feature_map: 0, + attributes: &[ + FEATURE_MAP, + ATTRIBUTE_LIST, + Attribute::new( + AttributesDiscriminants::OnOff as u16, + Access::RV, + Quality::PERSISTENT, + ), + ], + commands: &[ + CommandsDiscriminants::Off as _, + CommandsDiscriminants::On as _, + CommandsDiscriminants::Toggle as _, + ], +}; pub struct OnOffCluster { - base: Cluster, + data_ver: Dataver, + on: bool, } impl OnOffCluster { - pub fn new() -> Result, Error> { - let mut cluster = Box::new(OnOffCluster { - base: Cluster::new(ID)?, - }); - cluster.base.add_attribute(attr_on_off_new())?; - Ok(cluster) - } -} - -impl ClusterType for OnOffCluster { - fn base(&self) -> &Cluster { - &self.base - } - fn base_mut(&mut self) -> &mut Cluster { - &mut self.base + pub fn new(rand: Rand) -> Self { + Self { + data_ver: Dataver::new(rand), + on: false, + } } - fn handle_command(&mut self, cmd_req: &mut CommandReq) -> Result<(), IMStatusCode> { - let cmd = cmd_req - .cmd - .path - .leaf - .map(num::FromPrimitive::from_u32) - .ok_or(IMStatusCode::UnsupportedCommand)? - .ok_or(IMStatusCode::UnsupportedCommand)?; - match cmd { + pub fn set(&mut self, on: bool) { + if self.on != on { + self.on = on; + self.data_ver.changed(); + } + } + + pub fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> { + if let Some(writer) = encoder.with_dataver(self.data_ver.get())? { + if attr.is_system() { + CLUSTER.read(attr.attr_id, writer) + } else { + match attr.attr_id.try_into()? { + Attributes::OnOff(codec) => codec.encode(writer, self.on), + } + } + } else { + Ok(()) + } + } + + pub fn write(&mut self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { + let data = data.with_dataver(self.data_ver.get())?; + + match attr.attr_id.try_into()? { + Attributes::OnOff(codec) => self.set(codec.decode(data)?), + } + + self.data_ver.changed(); + + Ok(()) + } + + pub fn invoke( + &mut self, + cmd: &CmdDetails, + _data: &TLVElement, + _encoder: CmdDataEncoder, + ) -> Result<(), Error> { + match cmd.cmd_id.try_into()? { Commands::Off => { cmd_enter!("Off"); - let value = self - .base - .read_attribute_raw(Attributes::OnOff as u16) - .unwrap(); - if AttrValue::Bool(true) == *value { - self.base - .write_attribute_raw(Attributes::OnOff as u16, AttrValue::Bool(false)) - .map_err(|_| IMStatusCode::Failure)?; - } - cmd_req.trans.complete(); - Err(IMStatusCode::Success) + self.set(false); } Commands::On => { cmd_enter!("On"); - let value = self - .base - .read_attribute_raw(Attributes::OnOff as u16) - .unwrap(); - if AttrValue::Bool(false) == *value { - self.base - .write_attribute_raw(Attributes::OnOff as u16, AttrValue::Bool(true)) - .map_err(|_| IMStatusCode::Failure)?; - } - - cmd_req.trans.complete(); - Err(IMStatusCode::Success) + self.set(true); } Commands::Toggle => { cmd_enter!("Toggle"); - let value = match self - .base - .read_attribute_raw(Attributes::OnOff as u16) - .unwrap() - { - &AttrValue::Bool(v) => v, - _ => false, - }; - self.base - .write_attribute_raw(Attributes::OnOff as u16, AttrValue::Bool(!value)) - .map_err(|_| IMStatusCode::Failure)?; - cmd_req.trans.complete(); - Err(IMStatusCode::Success) + self.set(!self.on); } } + + self.data_ver.changed(); + + Ok(()) + } +} + +impl Handler for OnOffCluster { + fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> { + OnOffCluster::read(self, attr, encoder) + } + + fn write(&mut self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { + OnOffCluster::write(self, attr, data) + } + + fn invoke( + &mut self, + _transaction: &mut Transaction, + cmd: &CmdDetails, + data: &TLVElement, + encoder: CmdDataEncoder, + ) -> Result<(), Error> { + OnOffCluster::invoke(self, cmd, data, encoder) + } +} + +// TODO: Might be removed once the `on` member is externalized +impl NonBlockingHandler for OnOffCluster {} + +impl ChangeNotifier<()> for OnOffCluster { + fn consume_change(&mut self) -> Option<()> { + self.data_ver.consume_change(()) } } diff --git a/matter/src/data_model/cluster_template.rs b/matter/src/data_model/cluster_template.rs index 6555e9e..c103812 100644 --- a/matter/src/data_model/cluster_template.rs +++ b/matter/src/data_model/cluster_template.rs @@ -16,29 +16,59 @@ */ use crate::{ - data_model::objects::{Cluster, ClusterType}, + data_model::objects::{Cluster, Handler}, error::Error, + utils::rand::Rand, +}; + +use super::objects::{ + AttrDataEncoder, AttrDetails, ChangeNotifier, Dataver, NonBlockingHandler, ATTRIBUTE_LIST, + FEATURE_MAP, }; const CLUSTER_NETWORK_COMMISSIONING_ID: u32 = 0x0031; -pub struct TemplateCluster { - base: Cluster, -} +pub const CLUSTER: Cluster<'static> = Cluster { + id: CLUSTER_NETWORK_COMMISSIONING_ID as _, + feature_map: 0, + attributes: &[FEATURE_MAP, ATTRIBUTE_LIST], + commands: &[], +}; -impl ClusterType for TemplateCluster { - fn base(&self) -> &Cluster { - &self.base - } - fn base_mut(&mut self) -> &mut Cluster { - &mut self.base - } +pub struct TemplateCluster { + data_ver: Dataver, } impl TemplateCluster { - pub fn new() -> Result, Error> { - Ok(Box::new(Self { - base: Cluster::new(CLUSTER_NETWORK_COMMISSIONING_ID)?, - })) + pub fn new(rand: Rand) -> Self { + Self { + data_ver: Dataver::new(rand), + } + } + + pub fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> { + if let Some(writer) = encoder.with_dataver(self.data_ver.get())? { + if attr.is_system() { + CLUSTER.read(attr.attr_id, writer) + } else { + Err(Error::AttributeNotFound) + } + } else { + Ok(()) + } + } +} + +impl Handler for TemplateCluster { + fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> { + TemplateCluster::read(self, attr, encoder) + } +} + +impl NonBlockingHandler for TemplateCluster {} + +impl ChangeNotifier<()> for TemplateCluster { + fn consume_change(&mut self) -> Option<()> { + self.data_ver.consume_change(()) } } diff --git a/matter/src/data_model/core.rs b/matter/src/data_model/core.rs new file mode 100644 index 0000000..005871d --- /dev/null +++ b/matter/src/data_model/core.rs @@ -0,0 +1,199 @@ +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +use core::cell::RefCell; + +use super::objects::*; +use crate::{ + acl::{Accessor, AclMgr}, + error::*, + interaction_model::core::{Interaction, Transaction}, + tlv::TLVWriter, + transport::packet::Packet, +}; + +pub struct DataModel<'a, T> { + pub acl_mgr: &'a RefCell, + pub node: &'a Node<'a>, + pub handler: T, +} + +impl<'a, T> DataModel<'a, T> { + pub const fn new(acl_mgr: &'a RefCell, node: &'a Node<'a>, handler: T) -> Self { + Self { + acl_mgr, + node, + handler, + } + } + + pub fn handle( + &mut self, + interaction: &Interaction, + tx: &mut Packet, + transaction: &mut Transaction, + ) -> Result + where + T: Handler, + { + let accessor = Accessor::for_session(transaction.session(), self.acl_mgr); + let mut tw = TLVWriter::new(tx.get_writebuf()?); + + match interaction { + Interaction::Read(req) => { + for item in self.node.read(req, &accessor) { + AttrDataEncoder::handle_read(item, &self.handler, &mut tw)?; + } + } + Interaction::Write(req) => { + for item in self.node.write(req, &accessor) { + AttrDataEncoder::handle_write(item, &mut self.handler, &mut tw)?; + } + } + Interaction::Invoke(req) => { + for item in self.node.invoke(req, &accessor) { + CmdDataEncoder::handle(item, &mut self.handler, transaction, &mut tw)?; + } + } + Interaction::Timed(_) => (), + } + + interaction.complete_tx(tx, transaction) + } + + #[cfg(feature = "nightly")] + pub async fn handle_async<'p>( + &mut self, + interaction: &Interaction<'_>, + tx: &'p mut Packet<'_>, + transaction: &mut Transaction<'_, '_>, + ) -> Result, Error> + where + T: super::objects::asynch::AsyncHandler, + { + let accessor = Accessor::for_session(transaction.session(), self.acl_mgr); + let mut tw = TLVWriter::new(tx.get_writebuf()?); + + match interaction { + Interaction::Read(req) => { + for item in self.node.read(req, &accessor) { + AttrDataEncoder::handle_read_async(item, &self.handler, &mut tw).await?; + } + } + Interaction::Write(req) => { + for item in self.node.write(req, &accessor) { + AttrDataEncoder::handle_write_async(item, &mut self.handler, &mut tw).await?; + } + } + Interaction::Invoke(req) => { + for item in self.node.invoke(req, &accessor) { + CmdDataEncoder::handle_async(item, &mut self.handler, transaction, &mut tw) + .await?; + } + } + Interaction::Timed(_) => (), + } + + interaction.complete_tx(tx, transaction) + } +} + +pub trait DataHandler { + fn handle( + &mut self, + interaction: &Interaction, + tx: &mut Packet, + transaction: &mut Transaction, + ) -> Result; +} + +impl DataHandler for &mut T +where + T: DataHandler, +{ + fn handle( + &mut self, + interaction: &Interaction, + tx: &mut Packet, + transaction: &mut Transaction, + ) -> Result { + (**self).handle(interaction, tx, transaction) + } +} + +impl<'a, T> DataHandler for DataModel<'a, T> +where + T: Handler, +{ + fn handle( + &mut self, + interaction: &Interaction, + tx: &mut Packet, + transaction: &mut Transaction, + ) -> Result { + DataModel::handle(self, interaction, tx, transaction) + } +} + +#[cfg(feature = "nightly")] +pub mod asynch { + use crate::{ + data_model::objects::asynch::AsyncHandler, + error::Error, + interaction_model::core::{Interaction, Transaction}, + transport::packet::Packet, + }; + + use super::DataModel; + + pub trait AsyncDataHandler { + async fn handle<'p>( + &mut self, + interaction: &Interaction, + tx: &'p mut Packet, + transaction: &mut Transaction, + ) -> Result, Error>; + } + + impl AsyncDataHandler for &mut T + where + T: AsyncDataHandler, + { + async fn handle<'p>( + &mut self, + interaction: &Interaction<'_>, + tx: &'p mut Packet<'_>, + transaction: &mut Transaction<'_, '_>, + ) -> Result, Error> { + (**self).handle(interaction, tx, transaction).await + } + } + + impl<'a, T> AsyncDataHandler for DataModel<'a, T> + where + T: AsyncHandler, + { + async fn handle<'p>( + &mut self, + interaction: &Interaction<'_>, + tx: &'p mut Packet<'_>, + transaction: &mut Transaction<'_, '_>, + ) -> Result, Error> { + DataModel::handle_async(self, interaction, tx, transaction).await + } + } +} diff --git a/matter/src/data_model/core/mod.rs b/matter/src/data_model/core/mod.rs deleted file mode 100644 index 4386cab..0000000 --- a/matter/src/data_model/core/mod.rs +++ /dev/null @@ -1,394 +0,0 @@ -/* - * - * Copyright (c) 2020-2022 Project CHIP Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -use self::subscribe::SubsCtx; - -use super::{ - cluster_basic_information::BasicInfoConfig, - device_types::device_type_add_root_node, - objects::{self, *}, - sdm::dev_att::DevAttDataFetcher, - system_model::descriptor::DescriptorCluster, -}; -use crate::{ - acl::{AccessReq, Accessor, AccessorSubjects, AclMgr, AuthMode}, - error::*, - fabric::FabricMgr, - interaction_model::{ - command::CommandReq, - core::{IMStatusCode, OpCode}, - messages::{ - ib::{self, AttrData, DataVersionFilter}, - msg::{self, InvReq, ReadReq, WriteReq}, - GenericPath, - }, - InteractionConsumer, Transaction, - }, - secure_channel::pake::PaseMgr, - tlv::{self, FromTLV, TLVArray, TLVWriter, TagType, ToTLV}, - transport::{ - proto_demux::ResponseRequired, - session::{Session, SessionMode}, - }, -}; -use log::{error, info}; -use std::sync::{Arc, RwLock}; - -#[derive(Clone)] -pub struct DataModel { - pub node: Arc>>, - acl_mgr: Arc, -} - -impl DataModel { - pub fn new( - dev_details: BasicInfoConfig, - dev_att: Box, - fabric_mgr: Arc, - acl_mgr: Arc, - pase_mgr: PaseMgr, - ) -> Result { - let dm = DataModel { - node: Arc::new(RwLock::new(Node::new()?)), - acl_mgr: acl_mgr.clone(), - }; - { - let mut node = dm.node.write()?; - node.set_changes_cb(Box::new(dm.clone())); - device_type_add_root_node( - &mut node, - dev_details, - dev_att, - fabric_mgr, - acl_mgr, - pase_mgr, - )?; - } - Ok(dm) - } - - // Encode a write attribute from a path that may or may not be wildcard - fn handle_write_attr_path( - node: &mut Node, - accessor: &Accessor, - attr_data: &AttrData, - tw: &mut TLVWriter, - ) { - let gen_path = attr_data.path.to_gp(); - let mut encoder = AttrWriteEncoder::new(tw, TagType::Anonymous); - encoder.set_path(gen_path); - - // The unsupported pieces of the wildcard path - if attr_data.path.cluster.is_none() { - encoder.encode_status(IMStatusCode::UnsupportedCluster, 0); - return; - } - if attr_data.path.attr.is_none() { - encoder.encode_status(IMStatusCode::UnsupportedAttribute, 0); - return; - } - - // Get the data - let write_data = match &attr_data.data { - EncodeValue::Closure(_) | EncodeValue::Value(_) => { - error!("Not supported"); - return; - } - EncodeValue::Tlv(t) => t, - }; - - if gen_path.is_wildcard() { - // This is a wildcard path, skip error - // This is required because there could be access control errors too that need - // to be taken care of. - encoder.skip_error(); - } - let mut attr = AttrDetails { - // will be udpated in the loop below - attr_id: 0, - list_index: attr_data.path.list_index, - fab_filter: false, - fab_idx: accessor.fab_idx, - }; - - let result = node.for_each_cluster_mut(&gen_path, |path, c| { - if attr_data.data_ver.is_some() && Some(c.base().get_dataver()) != attr_data.data_ver { - encoder.encode_status(IMStatusCode::DataVersionMismatch, 0); - return Ok(()); - } - - attr.attr_id = path.leaf.unwrap_or_default() as u16; - encoder.set_path(*path); - let mut access_req = AccessReq::new(accessor, path, Access::WRITE); - let r = match Cluster::write_attribute(c, &mut access_req, write_data, &attr) { - Ok(_) => IMStatusCode::Success, - Err(e) => e, - }; - encoder.encode_status(r, 0); - Ok(()) - }); - if let Err(e) = result { - // We hit this only if this is a non-wildcard path and some parts of the path are missing - encoder.encode_status(e, 0); - } - } - - // Handle command from a path that may or may not be wildcard - fn handle_command_path(node: &mut Node, cmd_req: &mut CommandReq) { - let wildcard = cmd_req.cmd.path.is_wildcard(); - let path = cmd_req.cmd.path; - - let result = node.for_each_cluster_mut(&path, |path, c| { - cmd_req.cmd.path = *path; - let result = c.handle_command(cmd_req); - if let Err(e) = result { - // It is likely that we might have to do an 'Access' aware traversal - // if there are other conditions in the wildcard scenario that shouldn't be - // encoded as CmdStatus - if !(wildcard && e == IMStatusCode::UnsupportedCommand) { - let invoke_resp = ib::InvResp::status_new(cmd_req.cmd, e, 0); - let _ = invoke_resp.to_tlv(cmd_req.resp, TagType::Anonymous); - } - } - Ok(()) - }); - if !wildcard { - if let Err(e) = result { - // We hit this only if this is a non-wildcard path - let invoke_resp = ib::InvResp::status_new(cmd_req.cmd, e, 0); - let _ = invoke_resp.to_tlv(cmd_req.resp, TagType::Anonymous); - } - } - } - - fn sess_to_accessor(&self, sess: &Session) -> Accessor { - match sess.get_session_mode() { - SessionMode::Case(c) => { - let mut subject = - AccessorSubjects::new(sess.get_peer_node_id().unwrap_or_default()); - for i in c.cat_ids { - if i != 0 { - let _ = subject.add_catid(i); - } - } - Accessor::new(c.fab_idx, subject, AuthMode::Case, self.acl_mgr.clone()) - } - SessionMode::Pase => Accessor::new( - 0, - AccessorSubjects::new(1), - AuthMode::Pase, - self.acl_mgr.clone(), - ), - - SessionMode::PlainText => Accessor::new( - 0, - AccessorSubjects::new(1), - AuthMode::Invalid, - self.acl_mgr.clone(), - ), - } - } - - /// Returns true if the path matches the cluster path and the data version is a match - fn data_filter_matches( - filters: &Option<&TLVArray>, - path: &GenericPath, - data_ver: u32, - ) -> bool { - if let Some(filters) = *filters { - for filter in filters.iter() { - // TODO: No handling of 'node' comparision yet - if Some(filter.path.endpoint) == path.endpoint - && Some(filter.path.cluster) == path.cluster - && filter.data_ver == data_ver - { - return true; - } - } - } - false - } -} - -pub mod read; -pub mod subscribe; - -/// Type of Resume Request -enum ResumeReq { - Subscribe(subscribe::SubsCtx), - Read(read::ResumeReadReq), -} - -impl objects::ChangeConsumer for DataModel { - fn endpoint_added(&self, id: EndptId, endpoint: &mut Endpoint) -> Result<(), Error> { - endpoint.add_cluster(DescriptorCluster::new(id, self.clone())?)?; - Ok(()) - } -} - -impl InteractionConsumer for DataModel { - fn consume_write_attr( - &self, - write_req: &WriteReq, - trans: &mut Transaction, - tw: &mut TLVWriter, - ) -> Result<(), Error> { - let accessor = self.sess_to_accessor(trans.session); - - tw.start_array(TagType::Context(msg::WriteRespTag::WriteResponses as u8))?; - let mut node = self.node.write().unwrap(); - for attr_data in write_req.write_requests.iter() { - DataModel::handle_write_attr_path(&mut node, &accessor, &attr_data, tw); - } - tw.end_container()?; - - Ok(()) - } - - fn consume_read_attr( - &self, - rx_buf: &[u8], - trans: &mut Transaction, - tw: &mut TLVWriter, - ) -> Result<(), Error> { - let mut resume_from = None; - let root = tlv::get_root_node(rx_buf)?; - let req = ReadReq::from_tlv(&root)?; - self.handle_read_req(&req, trans, tw, &mut resume_from)?; - if resume_from.is_some() { - // This is a multi-hop read transaction, remember this read request - let resume = read::ResumeReadReq::new(rx_buf, &resume_from)?; - if !trans.exch.is_data_none() { - error!("Exchange data already set, and multi-hop read"); - return Err(Error::InvalidState); - } - trans.exch.set_data_boxed(Box::new(ResumeReq::Read(resume))); - } - Ok(()) - } - - fn consume_invoke_cmd( - &self, - inv_req_msg: &InvReq, - trans: &mut Transaction, - tw: &mut TLVWriter, - ) -> Result<(), Error> { - let mut node = self.node.write().unwrap(); - if let Some(inv_requests) = &inv_req_msg.inv_requests { - // Array of InvokeResponse IBs - tw.start_array(TagType::Context(msg::InvRespTag::InvokeResponses as u8))?; - for i in inv_requests.iter() { - let data = if let Some(data) = i.data.unwrap_tlv() { - data - } else { - continue; - }; - info!("Invoke Commmand Handler executing: {:?}", i.path); - let mut cmd_req = CommandReq { - cmd: i.path, - data, - trans, - resp: tw, - }; - DataModel::handle_command_path(&mut node, &mut cmd_req); - } - tw.end_container()?; - } - - Ok(()) - } - - fn consume_status_report( - &self, - req: &msg::StatusResp, - trans: &mut Transaction, - tw: &mut TLVWriter, - ) -> Result<(OpCode, ResponseRequired), Error> { - if let Some(mut resume) = trans.exch.take_data_boxed::() { - let result = match *resume { - ResumeReq::Read(ref mut read) => self.handle_resume_read(read, trans, tw)?, - - ResumeReq::Subscribe(ref mut ctx) => ctx.handle_status_report(trans, tw, self)?, - }; - trans.exch.set_data_boxed(resume); - Ok(result) - } else { - // Nothing to do for now - trans.complete(); - info!("Received status report with status {:?}", req.status); - Ok((OpCode::Reserved, ResponseRequired::No)) - } - } - - fn consume_subscribe( - &self, - rx_buf: &[u8], - trans: &mut Transaction, - tw: &mut TLVWriter, - ) -> Result<(OpCode, ResponseRequired), Error> { - if !trans.exch.is_data_none() { - error!("Exchange data already set!"); - return Err(Error::InvalidState); - } - let ctx = SubsCtx::new(rx_buf, trans, tw, self)?; - trans - .exch - .set_data_boxed(Box::new(ResumeReq::Subscribe(ctx))); - Ok((OpCode::ReportData, ResponseRequired::Yes)) - } -} - -/// Encoder for generating a response to a write request -pub struct AttrWriteEncoder<'a, 'b, 'c> { - tw: &'a mut TLVWriter<'b, 'c>, - tag: TagType, - path: GenericPath, - skip_error: bool, -} -impl<'a, 'b, 'c> AttrWriteEncoder<'a, 'b, 'c> { - pub fn new(tw: &'a mut TLVWriter<'b, 'c>, tag: TagType) -> Self { - Self { - tw, - tag, - path: Default::default(), - skip_error: false, - } - } - - pub fn skip_error(&mut self) { - self.skip_error = true; - } - - pub fn set_path(&mut self, path: GenericPath) { - self.path = path; - } -} - -impl<'a, 'b, 'c> Encoder for AttrWriteEncoder<'a, 'b, 'c> { - fn encode(&mut self, _value: EncodeValue) { - // Only status encodes for AttrWriteResponse - } - - fn encode_status(&mut self, status: IMStatusCode, cluster_status: u16) { - if self.skip_error && status != IMStatusCode::Success { - // Don't encode errors - return; - } - let resp = ib::AttrStatus::new(&self.path, status, cluster_status); - let _ = resp.to_tlv(self.tw, self.tag); - } -} diff --git a/matter/src/data_model/core/read.rs b/matter/src/data_model/core/read.rs deleted file mode 100644 index 07eb1a3..0000000 --- a/matter/src/data_model/core/read.rs +++ /dev/null @@ -1,319 +0,0 @@ -/* - * - * Copyright (c) 2020-2022 Project CHIP Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -use crate::{ - acl::{AccessReq, Accessor}, - data_model::{core::DataModel, objects::*}, - error::*, - interaction_model::{ - core::{IMStatusCode, OpCode}, - messages::{ - ib::{self, DataVersionFilter}, - msg::{self, ReadReq, ReportDataTag::MoreChunkedMsgs, ReportDataTag::SupressResponse}, - GenericPath, - }, - Transaction, - }, - tlv::{self, FromTLV, TLVArray, TLVWriter, TagType, ToTLV}, - transport::{packet::Packet, proto_demux::ResponseRequired}, - utils::writebuf::WriteBuf, - wb_shrink, wb_unshrink, -}; -use log::error; - -/// Encoder for generating a response to a read request -pub struct AttrReadEncoder<'a, 'b, 'c> { - tw: &'a mut TLVWriter<'b, 'c>, - data_ver: u32, - path: GenericPath, - skip_error: bool, - data_ver_filters: Option<&'a TLVArray<'a, DataVersionFilter>>, - is_buffer_full: bool, -} - -impl<'a, 'b, 'c> AttrReadEncoder<'a, 'b, 'c> { - pub fn new(tw: &'a mut TLVWriter<'b, 'c>) -> Self { - Self { - tw, - data_ver: 0, - skip_error: false, - path: Default::default(), - data_ver_filters: None, - is_buffer_full: false, - } - } - - pub fn skip_error(&mut self, skip: bool) { - self.skip_error = skip; - } - - pub fn set_data_ver(&mut self, data_ver: u32) { - self.data_ver = data_ver; - } - - pub fn set_data_ver_filters(&mut self, filters: &'a TLVArray<'a, DataVersionFilter>) { - self.data_ver_filters = Some(filters); - } - - pub fn set_path(&mut self, path: GenericPath) { - self.path = path; - } - - pub fn is_buffer_full(&self) -> bool { - self.is_buffer_full - } -} - -impl<'a, 'b, 'c> Encoder for AttrReadEncoder<'a, 'b, 'c> { - fn encode(&mut self, value: EncodeValue) { - let resp = ib::AttrResp::Data(ib::AttrData::new( - Some(self.data_ver), - ib::AttrPath::new(&self.path), - value, - )); - - let anchor = self.tw.get_tail(); - if resp.to_tlv(self.tw, TagType::Anonymous).is_err() { - self.is_buffer_full = true; - self.tw.rewind_to(anchor); - } - } - - fn encode_status(&mut self, status: IMStatusCode, cluster_status: u16) { - if !self.skip_error { - let resp = - ib::AttrResp::Status(ib::AttrStatus::new(&self.path, status, cluster_status)); - let _ = resp.to_tlv(self.tw, TagType::Anonymous); - } - } -} - -/// State to maintain when a Read Request needs to be resumed -/// resumed - the next chunk of the read needs to be returned -#[derive(Default)] -pub struct ResumeReadReq { - /// The Read Request Attribute Path that caused chunking, and this is the path - /// that needs to be resumed. - pub pending_req: Option>, - - /// The Attribute that couldn't be encoded because our buffer got full. The next chunk - /// will start encoding from this attribute onwards. - /// Note that given wildcard reads, one PendingPath in the member above can generated - /// multiple encode paths. Hence this has to be maintained separately. - pub resume_from: Option, -} -impl ResumeReadReq { - pub fn new(rx_buf: &[u8], resume_from: &Option) -> Result { - let mut packet = Packet::new_rx()?; - let dst = packet.as_borrow_slice(); - - let src_len = rx_buf.len(); - dst[..src_len].copy_from_slice(rx_buf); - packet.get_parsebuf()?.set_len(src_len); - Ok(ResumeReadReq { - pending_req: Some(packet), - resume_from: *resume_from, - }) - } -} - -impl DataModel { - pub fn read_attribute_raw( - &self, - endpoint: EndptId, - cluster: ClusterId, - attr: AttrId, - ) -> Result { - let node = self.node.read().unwrap(); - let cluster = node.get_cluster(endpoint, cluster)?; - cluster.base().read_attribute_raw(attr).map(|a| a.clone()) - } - /// Encode a read attribute from a path that may or may not be wildcard - /// - /// If the buffer gets full while generating the read response, we will return - /// an Err(path), where the path is the path that we should resume from, for the next chunk. - /// This facilitates chunk management - fn handle_read_attr_path( - node: &Node, - accessor: &Accessor, - attr_encoder: &mut AttrReadEncoder, - attr_details: &mut AttrDetails, - resume_from: &mut Option, - ) -> Result<(), Error> { - let mut status = Ok(()); - let path = attr_encoder.path; - - // Skip error reporting for wildcard paths, don't for concrete paths - attr_encoder.skip_error(path.is_wildcard()); - - let result = node.for_each_attribute(&path, |path, c| { - // Ignore processing if data filter matches. - // For a wildcard attribute, this may end happening unnecessarily for all attributes, although - // a single skip for the cluster is sufficient. That requires us to replace this for_each with a - // for_each_cluster - let cluster_data_ver = c.base().get_dataver(); - if Self::data_filter_matches(&attr_encoder.data_ver_filters, path, cluster_data_ver) { - return Ok(()); - } - - // The resume_from indicates that this is the next chunk of a previous Read Request. In such cases, we - // need to skip until we hit this path. - if let Some(r) = resume_from { - // If resume_from is valid, and we haven't hit the resume_from yet, skip encoding - if r != path { - return Ok(()); - } else { - // Else, wipe out the resume_from so subsequent paths can be encoded - *resume_from = None; - } - } - - attr_details.attr_id = path.leaf.unwrap_or_default() as u16; - // Overwrite the previous path with the concrete path - attr_encoder.set_path(*path); - // Set the cluster's data version - attr_encoder.set_data_ver(cluster_data_ver); - let mut access_req = AccessReq::new(accessor, path, Access::READ); - Cluster::read_attribute(c, &mut access_req, attr_encoder, attr_details); - if attr_encoder.is_buffer_full() { - // Buffer is full, next time resume from this attribute - *resume_from = Some(*path); - status = Err(Error::NoSpace); - } - Ok(()) - }); - if let Err(e) = result { - // We hit this only if this is a non-wildcard path - attr_encoder.encode_status(e, 0); - } - status - } - - /// Process an array of Attribute Read Requests - /// - /// When the API returns the chunked read is on, if *resume_from is Some(x) otherwise - /// the read is complete - pub(super) fn handle_read_attr_array( - &self, - read_req: &ReadReq, - trans: &mut Transaction, - old_tw: &mut TLVWriter, - resume_from: &mut Option, - ) -> Result<(), Error> { - let old_wb = old_tw.get_buf(); - // Note, this function may be called from multiple places: a) an actual read - // request, a b) resumed read request, c) subscribe request or d) resumed subscribe - // request. Hopefully 18 is sufficient to address all those scenarios. - // - // This is the amount of space we reserve for other things to be attached towards - // the end - const RESERVE_SIZE: usize = 24; - let mut new_wb = wb_shrink!(old_wb, RESERVE_SIZE); - let mut tw = TLVWriter::new(&mut new_wb); - - let mut attr_encoder = AttrReadEncoder::new(&mut tw); - if let Some(filters) = &read_req.dataver_filters { - attr_encoder.set_data_ver_filters(filters); - } - - if let Some(attr_requests) = &read_req.attr_requests { - let accessor = self.sess_to_accessor(trans.session); - let mut attr_details = AttrDetails::new(accessor.fab_idx, read_req.fabric_filtered); - let node = self.node.read().unwrap(); - attr_encoder - .tw - .start_array(TagType::Context(msg::ReportDataTag::AttributeReports as u8))?; - - let mut result = Ok(()); - for attr_path in attr_requests.iter() { - attr_encoder.set_path(attr_path.to_gp()); - // Extract the attr_path fields into various structures - attr_details.list_index = attr_path.list_index; - result = DataModel::handle_read_attr_path( - &node, - &accessor, - &mut attr_encoder, - &mut attr_details, - resume_from, - ); - if result.is_err() { - break; - } - } - // Now that all the read reports are captured, let's use the old_tw that is - // the full writebuf, and hopefully as all the necessary space to store this - wb_unshrink!(old_wb, new_wb); - old_tw.end_container()?; // Finish the AttrReports - - if result.is_err() { - // If there was an error, indicate chunking. The resume_read_req would have been - // already populated in the loop above. - old_tw.bool(TagType::Context(MoreChunkedMsgs as u8), true)?; - } else { - // A None resume_from indicates no chunking - *resume_from = None; - } - } - Ok(()) - } - - /// Handle a read request - /// - /// This could be called from an actual read request or a resumed read request. Subscription - /// requests do not come to this function. - /// When the API returns the chunked read is on, if *resume_from is Some(x) otherwise - /// the read is complete - pub fn handle_read_req( - &self, - read_req: &ReadReq, - trans: &mut Transaction, - tw: &mut TLVWriter, - resume_from: &mut Option, - ) -> Result<(OpCode, ResponseRequired), Error> { - tw.start_struct(TagType::Anonymous)?; - - self.handle_read_attr_array(read_req, trans, tw, resume_from)?; - - if resume_from.is_none() { - tw.bool(TagType::Context(SupressResponse as u8), true)?; - // Mark transaction complete, if not chunked - trans.complete(); - } - tw.end_container()?; - Ok((OpCode::ReportData, ResponseRequired::Yes)) - } - - /// Handle a resumed read request - pub fn handle_resume_read( - &self, - resume_read_req: &mut ResumeReadReq, - trans: &mut Transaction, - tw: &mut TLVWriter, - ) -> Result<(OpCode, ResponseRequired), Error> { - if let Some(packet) = resume_read_req.pending_req.as_mut() { - let rx_buf = packet.get_parsebuf()?.as_borrow_slice(); - let root = tlv::get_root_node(rx_buf)?; - let req = ReadReq::from_tlv(&root)?; - - self.handle_read_req(&req, trans, tw, &mut resume_read_req.resume_from) - } else { - // No pending req, is that even possible? - error!("This shouldn't have happened"); - Ok((OpCode::Reserved, ResponseRequired::No)) - } - } -} diff --git a/matter/src/data_model/core/subscribe.rs b/matter/src/data_model/core/subscribe.rs deleted file mode 100644 index a65ee1f..0000000 --- a/matter/src/data_model/core/subscribe.rs +++ /dev/null @@ -1,142 +0,0 @@ -/* - * - * Copyright (c) 2023 Project CHIP Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -use std::sync::atomic::{AtomicU32, Ordering}; - -use crate::{ - error::Error, - interaction_model::{ - core::OpCode, - messages::{ - msg::{self, SubscribeReq, SubscribeResp}, - GenericPath, - }, - }, - tlv::{self, get_root_node_struct, FromTLV, TLVWriter, TagType, ToTLV}, - transport::proto_demux::ResponseRequired, -}; - -use super::{read::ResumeReadReq, DataModel, Transaction}; - -static SUBS_ID: AtomicU32 = AtomicU32::new(1); - -#[derive(PartialEq)] -enum SubsState { - Confirming, - Confirmed, -} - -pub struct SubsCtx { - state: SubsState, - id: u32, - resume_read_req: Option, -} - -impl SubsCtx { - pub fn new( - rx_buf: &[u8], - trans: &mut Transaction, - tw: &mut TLVWriter, - dm: &DataModel, - ) -> Result { - let root = get_root_node_struct(rx_buf)?; - let req = SubscribeReq::from_tlv(&root)?; - - let mut ctx = SubsCtx { - state: SubsState::Confirming, - // TODO - id: SUBS_ID.fetch_add(1, Ordering::SeqCst), - resume_read_req: None, - }; - - let mut resume_from = None; - ctx.do_read(&req, trans, tw, dm, &mut resume_from)?; - if resume_from.is_some() { - // This is a multi-hop read transaction, remember this read request - ctx.resume_read_req = Some(ResumeReadReq::new(rx_buf, &resume_from)?); - } - Ok(ctx) - } - - pub fn handle_status_report( - &mut self, - trans: &mut Transaction, - tw: &mut TLVWriter, - dm: &DataModel, - ) -> Result<(OpCode, ResponseRequired), Error> { - if self.state != SubsState::Confirming { - // Not relevant for us - trans.complete(); - return Err(Error::Invalid); - } - - // Is there a previous resume read pending - if self.resume_read_req.is_some() { - let mut resume_read_req = self.resume_read_req.take().unwrap(); - if let Some(packet) = resume_read_req.pending_req.as_mut() { - let rx_buf = packet.get_parsebuf()?.as_borrow_slice(); - let root = tlv::get_root_node(rx_buf)?; - let req = SubscribeReq::from_tlv(&root)?; - - self.do_read(&req, trans, tw, dm, &mut resume_read_req.resume_from)?; - if resume_read_req.resume_from.is_some() { - // More chunks are pending, setup resume_read_req again - self.resume_read_req = Some(resume_read_req); - } - - return Ok((OpCode::ReportData, ResponseRequired::Yes)); - } - } - - // We are here implies that the read is now complete - self.confirm_subscription(trans, tw) - } - - fn confirm_subscription( - &mut self, - trans: &mut Transaction, - tw: &mut TLVWriter, - ) -> Result<(OpCode, ResponseRequired), Error> { - self.state = SubsState::Confirmed; - - // TODO - let resp = SubscribeResp::new(self.id, 40); - resp.to_tlv(tw, TagType::Anonymous)?; - trans.complete(); - Ok((OpCode::SubscriptResponse, ResponseRequired::Yes)) - } - - fn do_read( - &mut self, - req: &SubscribeReq, - trans: &mut Transaction, - tw: &mut TLVWriter, - dm: &DataModel, - resume_from: &mut Option, - ) -> Result<(), Error> { - let read_req = req.to_read_req(); - tw.start_struct(TagType::Anonymous)?; - tw.u32( - TagType::Context(msg::ReportDataTag::SubscriptionId as u8), - self.id, - )?; - dm.handle_read_attr_array(&read_req, trans, tw, resume_from)?; - tw.end_container()?; - - Ok(()) - } -} diff --git a/matter/src/data_model/device_types.rs b/matter/src/data_model/device_types.rs index 9c37971..fbce494 100644 --- a/matter/src/data_model/device_types.rs +++ b/matter/src/data_model/device_types.rs @@ -15,60 +15,14 @@ * limitations under the License. */ -use super::cluster_basic_information::BasicInfoCluster; -use super::cluster_basic_information::BasicInfoConfig; -use super::cluster_on_off::OnOffCluster; -use super::objects::*; -use super::sdm::admin_commissioning::AdminCommCluster; -use super::sdm::dev_att::DevAttDataFetcher; -use super::sdm::general_commissioning::GenCommCluster; -use super::sdm::noc::NocCluster; -use super::sdm::nw_commissioning::NwCommCluster; -use super::system_model::access_control::AccessControlCluster; -use crate::acl::AclMgr; -use crate::error::*; -use crate::fabric::FabricMgr; -use crate::secure_channel::pake::PaseMgr; -use std::sync::Arc; -use std::sync::RwLockWriteGuard; +use super::objects::DeviceType; pub const DEV_TYPE_ROOT_NODE: DeviceType = DeviceType { dtype: 0x0016, drev: 1, }; -type WriteNode<'a> = RwLockWriteGuard<'a, Box>; - -pub fn device_type_add_root_node( - node: &mut WriteNode, - dev_info: BasicInfoConfig, - dev_att: Box, - fabric_mgr: Arc, - acl_mgr: Arc, - pase_mgr: PaseMgr, -) -> Result { - // Add the root endpoint - let endpoint = node.add_endpoint(DEV_TYPE_ROOT_NODE)?; - if endpoint != 0 { - // Somehow endpoint 0 was already added, this shouldn't be the case - return Err(Error::Invalid); - }; - // Add the mandatory clusters - node.add_cluster(0, BasicInfoCluster::new(dev_info)?)?; - let general_commissioning = GenCommCluster::new()?; - let failsafe = general_commissioning.failsafe(); - node.add_cluster(0, general_commissioning)?; - node.add_cluster(0, NwCommCluster::new()?)?; - node.add_cluster(0, AdminCommCluster::new(pase_mgr)?)?; - node.add_cluster( - 0, - NocCluster::new(dev_att, fabric_mgr, acl_mgr.clone(), failsafe)?, - )?; - node.add_cluster(0, AccessControlCluster::new(acl_mgr)?)?; - Ok(endpoint) -} - -const DEV_TYPE_ON_OFF_LIGHT: DeviceType = DeviceType { +pub const DEV_TYPE_ON_OFF_LIGHT: DeviceType = DeviceType { dtype: 0x0100, drev: 2, }; @@ -77,9 +31,3 @@ pub const DEV_TYPE_ON_SMART_SPEAKER: DeviceType = DeviceType { dtype: 0x0022, drev: 2, }; - -pub fn device_type_add_on_off_light(node: &mut WriteNode) -> Result { - let endpoint = node.add_endpoint(DEV_TYPE_ON_OFF_LIGHT)?; - node.add_cluster(endpoint, OnOffCluster::new()?)?; - Ok(endpoint) -} diff --git a/matter/src/data_model/mod.rs b/matter/src/data_model/mod.rs index c347941..c76e07c 100644 --- a/matter/src/data_model/mod.rs +++ b/matter/src/data_model/mod.rs @@ -20,8 +20,9 @@ pub mod device_types; pub mod objects; pub mod cluster_basic_information; -pub mod cluster_media_playback; +// TODO pub mod cluster_media_playback; pub mod cluster_on_off; pub mod cluster_template; +pub mod root_endpoint; pub mod sdm; pub mod system_model; diff --git a/matter/src/data_model/objects/attribute.rs b/matter/src/data_model/objects/attribute.rs index 28be5c6..69fd544 100644 --- a/matter/src/data_model/objects/attribute.rs +++ b/matter/src/data_model/objects/attribute.rs @@ -15,15 +15,11 @@ * limitations under the License. */ -use super::{AttrId, GlobalElements, Privilege}; -use crate::{ - error::*, - // TODO: This layer shouldn't really depend on the TLV layer, should create an abstraction layer - tlv::{TLVElement, TLVWriter, TagType, ToTLV}, -}; +use crate::data_model::objects::GlobalElements; + +use super::{AttrId, Privilege}; use bitflags::bitflags; -use log::error; -use std::fmt::{self, Debug, Formatter}; +use core::fmt::{self, Debug}; bitflags! { #[derive(Default)] @@ -83,110 +79,24 @@ bitflags! { } } -/* This file needs some major revamp. - * - instead of allocating all over the heap, we should use some kind of slab/block allocator - * - instead of arrays, can use linked-lists to conserve space and avoid the internal fragmentation - */ - -#[derive(PartialEq, PartialOrd, Clone)] -pub enum AttrValue { - Int64(i64), - Uint8(u8), - Uint16(u16), - Uint32(u32), - Uint64(u64), - Bool(bool), - Utf8(String), - Custom, -} - -impl Debug for AttrValue { - fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> { - match &self { - AttrValue::Int64(v) => write!(f, "{:?}", *v), - AttrValue::Uint8(v) => write!(f, "{:?}", *v), - AttrValue::Uint16(v) => write!(f, "{:?}", *v), - AttrValue::Uint32(v) => write!(f, "{:?}", *v), - AttrValue::Uint64(v) => write!(f, "{:?}", *v), - AttrValue::Bool(v) => write!(f, "{:?}", *v), - AttrValue::Utf8(v) => write!(f, "{:?}", *v), - AttrValue::Custom => write!(f, "custom-attribute"), - }?; - Ok(()) - } -} - -impl ToTLV for AttrValue { - fn to_tlv(&self, tw: &mut TLVWriter, tag_type: TagType) -> Result<(), Error> { - // What is the time complexity of such long match statements? - match self { - AttrValue::Bool(v) => tw.bool(tag_type, *v), - AttrValue::Uint8(v) => tw.u8(tag_type, *v), - AttrValue::Uint16(v) => tw.u16(tag_type, *v), - AttrValue::Uint32(v) => tw.u32(tag_type, *v), - AttrValue::Uint64(v) => tw.u64(tag_type, *v), - AttrValue::Utf8(v) => tw.utf8(tag_type, v.as_bytes()), - _ => { - error!("Attribute type not yet supported"); - Err(Error::AttributeNotFound) - } - } - } -} - -impl AttrValue { - pub fn update_from_tlv(&mut self, tr: &TLVElement) -> Result<(), Error> { - match self { - AttrValue::Bool(v) => *v = tr.bool()?, - AttrValue::Uint8(v) => *v = tr.u8()?, - AttrValue::Uint16(v) => *v = tr.u16()?, - AttrValue::Uint32(v) => *v = tr.u32()?, - AttrValue::Uint64(v) => *v = tr.u64()?, - _ => { - error!("Attribute type not yet supported"); - return Err(Error::AttributeNotFound); - } - } - Ok(()) - } -} - #[derive(Debug, Clone)] pub struct Attribute { - pub(super) id: AttrId, - pub(super) value: AttrValue, - pub(super) quality: Quality, - pub(super) access: Access, -} - -impl Default for Attribute { - fn default() -> Attribute { - Attribute { - id: 0, - value: AttrValue::Bool(true), - quality: Default::default(), - access: Default::default(), - } - } + pub id: AttrId, + pub quality: Quality, + pub access: Access, } impl Attribute { - pub fn new(id: AttrId, value: AttrValue, access: Access, quality: Quality) -> Self { - Attribute { + pub const fn new(id: AttrId, access: Access, quality: Quality) -> Self { + Self { id, - value, access, quality, } } - pub fn set_value(&mut self, value: AttrValue) -> Result<(), Error> { - if !self.quality.contains(Quality::FIXED) { - self.value = value; - Ok(()) - } else { - Err(Error::Invalid) - } + pub fn is_system(&self) -> bool { + Self::is_system_attr(self.id) } pub fn is_system_attr(attr_id: AttrId) -> bool { @@ -194,9 +104,9 @@ impl Attribute { } } -impl std::fmt::Display for Attribute { +impl core::fmt::Display for Attribute { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}: {:?}", self.id, self.value) + write!(f, "{}", self.id) } } diff --git a/matter/src/data_model/objects/cluster.rs b/matter/src/data_model/objects/cluster.rs index 7ca8350..90c6835 100644 --- a/matter/src/data_model/objects/cluster.rs +++ b/matter/src/data_model/objects/cluster.rs @@ -15,25 +15,31 @@ * limitations under the License. */ -use crate::{ - acl::AccessReq, - data_model::objects::{Access, AttrValue, Attribute, EncodeValue, Quality}, - error::*, - interaction_model::{command::CommandReq, core::IMStatusCode}, - // TODO: This layer shouldn't really depend on the TLV layer, should create an abstraction layer - tlv::{Nullable, TLVElement, TLVWriter, TagType}, -}; use log::error; -use num_derive::FromPrimitive; -use rand::Rng; -use std::fmt::{self, Debug}; +use strum::FromRepr; -use super::{AttrId, ClusterId, Encoder}; +use crate::{ + acl::{AccessReq, Accessor}, + attribute_enum, + data_model::objects::*, + error::Error, + interaction_model::{ + core::IMStatusCode, + messages::{ + ib::{AttrPath, AttrStatus, CmdPath, CmdStatus}, + GenericPath, + }, + }, + // TODO: This layer shouldn't really depend on the TLV layer, should create an abstraction layer + tlv::{Nullable, TLVWriter, TagType}, +}; +use core::{ + convert::TryInto, + fmt::{self, Debug}, +}; -pub const ATTRS_PER_CLUSTER: usize = 10; -pub const CMDS_PER_CLUSTER: usize = 8; - -#[derive(FromPrimitive, Debug)] +#[derive(Clone, Copy, Debug, Eq, PartialEq, FromRepr)] +#[repr(u16)] pub enum GlobalElements { _ClusterRevision = 0xFFFD, FeatureMap = 0xFFFC, @@ -44,297 +50,308 @@ pub enum GlobalElements { FabricIndex = 0xFE, } +attribute_enum!(GlobalElements); + +pub const FEATURE_MAP: Attribute = + Attribute::new(GlobalElements::FeatureMap as _, Access::RV, Quality::NONE); + +pub const ATTRIBUTE_LIST: Attribute = Attribute::new( + GlobalElements::AttributeList as _, + Access::RV, + Quality::NONE, +); + // TODO: What if we instead of creating this, we just pass the AttrData/AttrPath to the read/write // methods? /// The Attribute Details structure records the details about the attribute under consideration. -/// Typically this structure is progressively built as we proceed through the request processing. -pub struct AttrDetails { - /// Fabric Filtering Activated - pub fab_filter: bool, - /// The current Fabric Index - pub fab_idx: u8, - /// List Index, if any - pub list_index: Option>, +pub struct AttrDetails<'a> { + pub node: &'a Node<'a>, + /// The actual endpoint ID + pub endpoint_id: EndptId, + /// The actual cluster ID + pub cluster_id: ClusterId, /// The actual attribute ID pub attr_id: AttrId, + /// List Index, if any + pub list_index: Option>, + /// The current Fabric Index + pub fab_idx: u8, + /// Fabric Filtering Activated + pub fab_filter: bool, + pub dataver: Option, + pub wildcard: bool, } -impl AttrDetails { - pub fn new(fab_idx: u8, fab_filter: bool) -> Self { +impl<'a> AttrDetails<'a> { + pub fn is_system(&self) -> bool { + Attribute::is_system_attr(self.attr_id) + } + + pub fn path(&self) -> AttrPath { + AttrPath { + endpoint: Some(self.endpoint_id), + cluster: Some(self.cluster_id), + attr: Some(self.attr_id), + list_index: self.list_index, + ..Default::default() + } + } + + pub fn status(&self, status: IMStatusCode) -> Result, Error> { + if self.should_report(status) { + Ok(Some(AttrStatus::new( + &GenericPath { + endpoint: Some(self.endpoint_id), + cluster: Some(self.cluster_id), + leaf: Some(self.attr_id as _), + }, + status, + 0, + ))) + } else { + Ok(None) + } + } + + fn should_report(&self, status: IMStatusCode) -> bool { + !self.wildcard + || !matches!( + status, + IMStatusCode::UnsupportedEndpoint + | IMStatusCode::UnsupportedCluster + | IMStatusCode::UnsupportedAttribute + | IMStatusCode::UnsupportedCommand + | IMStatusCode::UnsupportedAccess + | IMStatusCode::UnsupportedRead + | IMStatusCode::UnsupportedWrite + | IMStatusCode::DataVersionMismatch + ) + } +} + +pub struct CmdDetails<'a> { + pub node: &'a Node<'a>, + pub endpoint_id: EndptId, + pub cluster_id: ClusterId, + pub cmd_id: CmdId, + pub wildcard: bool, +} + +impl<'a> CmdDetails<'a> { + pub fn path(&self) -> CmdPath { + CmdPath::new( + Some(self.endpoint_id), + Some(self.cluster_id), + Some(self.cmd_id), + ) + } + + pub fn success(&self, tracker: &CmdDataTracker) -> Option { + if tracker.needs_status() { + self.status(IMStatusCode::Success) + } else { + None + } + } + + pub fn status(&self, status: IMStatusCode) -> Option { + if self.should_report(status) { + Some(CmdStatus::new( + CmdPath::new( + Some(self.endpoint_id), + Some(self.cluster_id), + Some(self.cmd_id), + ), + status, + 0, + )) + } else { + None + } + } + + fn should_report(&self, status: IMStatusCode) -> bool { + !self.wildcard + || !matches!( + status, + IMStatusCode::UnsupportedEndpoint + | IMStatusCode::UnsupportedCluster + | IMStatusCode::UnsupportedAttribute + | IMStatusCode::UnsupportedCommand + | IMStatusCode::UnsupportedAccess + | IMStatusCode::UnsupportedRead + | IMStatusCode::UnsupportedWrite + ) + } +} + +#[derive(Debug, Clone)] +pub struct Cluster<'a> { + pub id: ClusterId, + pub feature_map: u32, + pub attributes: &'a [Attribute], + pub commands: &'a [CmdId], +} + +impl<'a> Cluster<'a> { + pub const fn new( + id: ClusterId, + feature_map: u32, + attributes: &'a [Attribute], + commands: &'a [CmdId], + ) -> Self { Self { - fab_filter, - fab_idx, - list_index: None, - attr_id: 0, - } - } -} - -pub trait ClusterType { - // TODO: 5 methods is going to be quite expensive for vtables of all the clusters - fn base(&self) -> &Cluster; - fn base_mut(&mut self) -> &mut Cluster; - fn read_custom_attribute(&self, _encoder: &mut dyn Encoder, _attr: &AttrDetails) {} - - fn handle_command(&mut self, cmd_req: &mut CommandReq) -> Result<(), IMStatusCode> { - let cmd = cmd_req.cmd.path.leaf.map(|a| a as u16); - println!("Received command: {:?}", cmd); - - Err(IMStatusCode::UnsupportedCommand) - } - - /// Write an attribute - /// - /// Note that if this method is defined, you must handle the write for all the attributes. Even those - /// that are not 'custom'. This is different from how you handle the read_custom_attribute() method. - /// The reason for this being, you may want to handle an attribute write request even though it is a - /// standard attribute like u16, u32 etc. - /// - /// If you wish to update the standard attribute in the data model database, you must call the - /// write_attribute_from_tlv() method from the base cluster, as is shown here in the default case - fn write_attribute( - &mut self, - attr: &AttrDetails, - data: &TLVElement, - ) -> Result<(), IMStatusCode> { - self.base_mut().write_attribute_from_tlv(attr.attr_id, data) - } -} - -pub struct Cluster { - pub(super) id: ClusterId, - attributes: Vec, - data_ver: u32, -} - -impl Cluster { - pub fn new(id: ClusterId) -> Result { - let mut c = Cluster { id, - attributes: Vec::with_capacity(ATTRS_PER_CLUSTER), - data_ver: rand::thread_rng().gen_range(0..0xFFFFFFFF), - }; - c.add_default_attributes()?; - Ok(c) - } - - pub fn id(&self) -> ClusterId { - self.id - } - - pub fn get_dataver(&self) -> u32 { - self.data_ver - } - - pub fn set_feature_map(&mut self, map: u32) -> Result<(), Error> { - self.write_attribute_raw(GlobalElements::FeatureMap as u16, AttrValue::Uint32(map)) - .map_err(|_| Error::Invalid)?; - Ok(()) - } - - fn add_default_attributes(&mut self) -> Result<(), Error> { - // Default feature map is 0 - self.add_attribute(Attribute::new( - GlobalElements::FeatureMap as u16, - AttrValue::Uint32(0), - Access::RV, - Quality::NONE, - ))?; - - self.add_attribute(Attribute::new( - GlobalElements::AttributeList as u16, - AttrValue::Custom, - Access::RV, - Quality::NONE, - )) - } - - pub fn add_attributes(&mut self, attrs: &[Attribute]) -> Result<(), Error> { - if self.attributes.len() + attrs.len() <= self.attributes.capacity() { - self.attributes.extend_from_slice(attrs); - Ok(()) - } else { - Err(Error::NoSpace) + feature_map, + attributes, + commands, } } - pub fn add_attribute(&mut self, attr: Attribute) -> Result<(), Error> { - if self.attributes.len() < self.attributes.capacity() { - self.attributes.push(attr); - Ok(()) - } else { - Err(Error::NoSpace) - } + pub(crate) fn match_attributes<'m>( + &'m self, + accessor: &'m Accessor<'m>, + ep: EndptId, + attr: Option, + write: bool, + ) -> impl Iterator + 'm { + self.attributes + .iter() + .filter(move |attribute| attr.map(|attr| attr == attribute.id).unwrap_or(true)) + .filter(move |attribute| { + let mut access_req = AccessReq::new( + accessor, + GenericPath::new(Some(ep), Some(self.id), Some(attribute.id as _)), + if write { Access::WRITE } else { Access::READ }, + ); + self.check_attr_access(&mut access_req, attribute.access) + .is_ok() + }) + .map(|attribute| attribute.id) } - fn get_attribute_index(&self, attr_id: AttrId) -> Option { - self.attributes.iter().position(|c| c.id == attr_id) + pub fn match_commands<'m>( + &'m self, + accessor: &'m Accessor<'m>, + ep: EndptId, + cmd: Option, + ) -> impl Iterator + 'm { + self.commands + .iter() + .filter(move |id| cmd.map(|cmd| **id == cmd).unwrap_or(true)) + .filter(move |id| { + let mut access_req = AccessReq::new( + accessor, + GenericPath::new(Some(ep), Some(self.id), Some(**id as _)), + Access::WRITE, + ); + self.check_cmd_access(&mut access_req).is_ok() + }) + .copied() } - fn get_attribute(&self, attr_id: AttrId) -> Result<&Attribute, Error> { - let index = self - .get_attribute_index(attr_id) - .ok_or(Error::AttributeNotFound)?; - Ok(&self.attributes[index]) - } - - fn get_attribute_mut(&mut self, attr_id: AttrId) -> Result<&mut Attribute, Error> { - let index = self - .get_attribute_index(attr_id) - .ok_or(Error::AttributeNotFound)?; - Ok(&mut self.attributes[index]) - } - - // Returns a slice of attribute, with either a single attribute or all (wildcard) - pub fn get_wildcard_attribute( + pub(crate) fn check_attribute( &self, - attribute: Option, - ) -> Result<(&[Attribute], bool), IMStatusCode> { - if let Some(a) = attribute { - if let Some(i) = self.get_attribute_index(a) { - Ok((&self.attributes[i..i + 1], false)) + accessor: &Accessor, + ep: EndptId, + attr: AttrId, + write: bool, + ) -> Result<(), IMStatusCode> { + let attribute = self + .attributes + .iter() + .find(|attribute| attribute.id == attr) + .ok_or(IMStatusCode::UnsupportedAttribute)?; + + let mut access_req = AccessReq::new( + accessor, + GenericPath::new(Some(ep), Some(self.id), Some(attr as _)), + if write { Access::WRITE } else { Access::READ }, + ); + + self.check_attr_access(&mut access_req, attribute.access) + } + + pub(crate) fn check_command( + &self, + accessor: &Accessor, + ep: EndptId, + cmd: CmdId, + ) -> Result<(), IMStatusCode> { + self.commands + .iter() + .find(|id| **id == cmd) + .ok_or(IMStatusCode::UnsupportedCommand)?; + + let mut access_req = AccessReq::new( + accessor, + GenericPath::new(Some(ep), Some(self.id), Some(cmd as _)), + Access::WRITE, + ); + + self.check_cmd_access(&mut access_req) + } + + fn check_attr_access( + &self, + access_req: &mut AccessReq, + target_perms: Access, + ) -> Result<(), IMStatusCode> { + if !target_perms.contains(access_req.operation()) { + Err(if matches!(access_req.operation(), Access::WRITE) { + IMStatusCode::UnsupportedWrite } else { - Err(IMStatusCode::UnsupportedAttribute) + IMStatusCode::UnsupportedRead + })?; + } + + access_req.set_target_perms(target_perms); + if access_req.allow() { + Ok(()) + } else { + Err(IMStatusCode::UnsupportedAccess) + } + } + + fn check_cmd_access(&self, access_req: &mut AccessReq) -> Result<(), IMStatusCode> { + access_req.set_target_perms( + Access::WRITE + .union(Access::NEED_OPERATE) + .union(Access::NEED_MANAGE) + .union(Access::NEED_ADMIN), + ); // TODO + if access_req.allow() { + Ok(()) + } else { + Err(IMStatusCode::UnsupportedAccess) + } + } + + pub fn read(&self, attr: AttrId, mut writer: AttrDataWriter) -> Result<(), Error> { + match attr.try_into()? { + GlobalElements::AttributeList => { + self.encode_attribute_ids(AttrDataWriter::TAG, &mut writer)?; + writer.complete() } - } else { - Ok((&self.attributes[..], true)) - } - } - - pub fn read_attribute( - c: &dyn ClusterType, - access_req: &mut AccessReq, - encoder: &mut dyn Encoder, - attr: &AttrDetails, - ) { - let mut error = IMStatusCode::Success; - let base = c.base(); - let a = if let Ok(a) = base.get_attribute(attr.attr_id) { - a - } else { - encoder.encode_status(IMStatusCode::UnsupportedAttribute, 0); - return; - }; - - if !a.access.contains(Access::READ) { - error = IMStatusCode::UnsupportedRead; - } - - access_req.set_target_perms(a.access); - if !access_req.allow() { - error = IMStatusCode::UnsupportedAccess; - } - - if error != IMStatusCode::Success { - encoder.encode_status(error, 0); - } else if Attribute::is_system_attr(attr.attr_id) { - c.base().read_system_attribute(encoder, a) - } else if a.value != AttrValue::Custom { - encoder.encode(EncodeValue::Value(&a.value)) - } else { - c.read_custom_attribute(encoder, attr) - } - } - - fn encode_attribute_ids(&self, tag: TagType, tw: &mut TLVWriter) { - let _ = tw.start_array(tag); - for a in &self.attributes { - let _ = tw.u16(TagType::Anonymous, a.id); - } - let _ = tw.end_container(); - } - - fn read_system_attribute(&self, encoder: &mut dyn Encoder, attr: &Attribute) { - let global_attr: Option = num::FromPrimitive::from_u16(attr.id); - if let Some(global_attr) = global_attr { - match global_attr { - GlobalElements::AttributeList => { - encoder.encode(EncodeValue::Closure(&|tag, tw| { - self.encode_attribute_ids(tag, tw) - })); - return; - } - GlobalElements::FeatureMap => { - encoder.encode(EncodeValue::Value(&attr.value)); - return; - } - _ => { - error!("This attribute not yet handled {:?}", global_attr); - } + GlobalElements::FeatureMap => writer.set(self.feature_map), + other => { + error!("This attribute is not yet handled {:?}", other); + Err(Error::AttributeNotFound) } } - encoder.encode_status(IMStatusCode::UnsupportedAttribute, 0) } - pub fn read_attribute_raw(&self, attr_id: AttrId) -> Result<&AttrValue, IMStatusCode> { - let a = self - .get_attribute(attr_id) - .map_err(|_| IMStatusCode::UnsupportedAttribute)?; - Ok(&a.value) - } - - pub fn write_attribute( - c: &mut dyn ClusterType, - access_req: &mut AccessReq, - data: &TLVElement, - attr: &AttrDetails, - ) -> Result<(), IMStatusCode> { - let base = c.base_mut(); - let a = if let Ok(a) = base.get_attribute_mut(attr.attr_id) { - a - } else { - return Err(IMStatusCode::UnsupportedAttribute); - }; - - if !a.access.contains(Access::WRITE) { - return Err(IMStatusCode::UnsupportedWrite); + fn encode_attribute_ids(&self, tag: TagType, tw: &mut TLVWriter) -> Result<(), Error> { + tw.start_array(tag)?; + for a in self.attributes { + tw.u16(TagType::Anonymous, a.id)?; } - access_req.set_target_perms(a.access); - if !access_req.allow() { - return Err(IMStatusCode::UnsupportedAccess); - } - - c.write_attribute(attr, data) - } - - pub fn write_attribute_from_tlv( - &mut self, - attr_id: AttrId, - data: &TLVElement, - ) -> Result<(), IMStatusCode> { - let a = self.get_attribute_mut(attr_id)?; - if a.value != AttrValue::Custom { - let mut value = a.value.clone(); - value - .update_from_tlv(data) - .map_err(|_| IMStatusCode::Failure)?; - a.set_value(value) - .map(|_| { - self.cluster_changed(); - }) - .map_err(|_| IMStatusCode::UnsupportedWrite) - } else { - Err(IMStatusCode::UnsupportedAttribute) - } - } - - pub fn write_attribute_raw(&mut self, attr_id: AttrId, value: AttrValue) -> Result<(), Error> { - let a = self.get_attribute_mut(attr_id)?; - a.set_value(value).map(|_| { - self.cluster_changed(); - }) - } - - /// This method must be called for any changes to the data model - /// Currently this only increments the data version, but we can reuse the same - /// for raising events too - pub fn cluster_changed(&mut self) { - self.data_ver = self.data_ver.wrapping_add(1); + tw.end_container() } } -impl std::fmt::Display for Cluster { +impl<'a> core::fmt::Display for Cluster<'a> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "id:{}, ", self.id)?; write!(f, "attrs[")?; diff --git a/matter/src/data_model/objects/dataver.rs b/matter/src/data_model/objects/dataver.rs new file mode 100644 index 0000000..fc062be --- /dev/null +++ b/matter/src/data_model/objects/dataver.rs @@ -0,0 +1,55 @@ +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + use crate::utils::rand::Rand; + +pub struct Dataver { + ver: u32, + changed: bool, +} + +impl Dataver { + pub fn new(rand: Rand) -> Self { + let mut buf = [0; 4]; + rand(&mut buf); + + Self { + ver: u32::from_be_bytes(buf), + changed: false, + } + } + + pub fn get(&self) -> u32 { + self.ver + } + + pub fn changed(&mut self) -> u32 { + (self.ver, _) = self.ver.overflowing_add(1); + self.changed = true; + + self.get() + } + + pub fn consume_change(&mut self, change: T) -> Option { + if self.changed { + self.changed = false; + Some(change) + } else { + None + } + } +} diff --git a/matter/src/data_model/objects/encoder.rs b/matter/src/data_model/objects/encoder.rs index d565316..39d2ba6 100644 --- a/matter/src/data_model/objects/encoder.rs +++ b/matter/src/data_model/objects/encoder.rs @@ -15,17 +15,26 @@ * limitations under the License. */ -use std::fmt::{Debug, Formatter}; +use core::fmt::{Debug, Formatter}; +use core::marker::PhantomData; +use core::ops::{Deref, DerefMut}; +use crate::interaction_model::core::{IMStatusCode, Transaction}; +use crate::interaction_model::messages::ib::{ + AttrPath, AttrResp, AttrStatus, CmdDataTag, CmdPath, CmdStatus, InvResp, InvRespTag, +}; +use crate::tlv::UtfStr; use crate::{ error::Error, - interaction_model::core::IMStatusCode, + interaction_model::messages::ib::{AttrDataTag, AttrRespTag}, tlv::{FromTLV, TLVElement, TLVWriter, TagType, ToTLV}, }; use log::error; +use super::{AttrDetails, CmdDetails, Handler}; + // TODO: Should this return an IMStatusCode Error? But if yes, the higher layer -// may have already started encoding the 'success' headers, we might not to manage +// may have already started encoding the 'success' headers, we might not want to manage // the tw.rewind() in that case, if we add this support pub type EncodeValueGen<'a> = &'a dyn Fn(TagType, &mut TLVWriter); @@ -78,7 +87,7 @@ impl<'a> PartialEq for EncodeValue<'a> { } impl<'a> Debug for EncodeValue<'a> { - fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> { + fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), core::fmt::Error> { match *self { EncodeValue::Closure(_) => write!(f, "Contains closure"), EncodeValue::Tlv(t) => write!(f, "{:?}", t), @@ -107,17 +116,454 @@ impl<'a> FromTLV<'a> for EncodeValue<'a> { } } -/// An object that can encode EncodeValue into the necessary hierarchical structure -/// as expected by the Interaction Model -pub trait Encoder { - /// Encode a given value - fn encode(&mut self, value: EncodeValue); - /// Encode a status report - fn encode_status(&mut self, status: IMStatusCode, cluster_status: u16); +pub struct AttrDataEncoder<'a, 'b, 'c> { + dataver_filter: Option, + path: AttrPath, + tw: &'a mut TLVWriter<'b, 'c>, } -#[derive(ToTLV, Copy, Clone)] -pub struct DeviceType { - pub dtype: u16, - pub drev: u16, +impl<'a, 'b, 'c> AttrDataEncoder<'a, 'b, 'c> { + pub fn handle_read( + item: Result, + handler: &T, + tw: &mut TLVWriter, + ) -> Result<(), Error> { + let status = match item { + Ok(attr) => { + let encoder = AttrDataEncoder::new(&attr, tw); + + match handler.read(&attr, encoder) { + Ok(()) => None, + Err(error) => attr.status(error.into())?, + } + } + Err(status) => Some(status), + }; + + if let Some(status) = status { + AttrResp::Status(status).to_tlv(tw, TagType::Anonymous)?; + } + + Ok(()) + } + + pub fn handle_write( + item: Result<(AttrDetails, TLVElement), AttrStatus>, + handler: &mut T, + tw: &mut TLVWriter, + ) -> Result<(), Error> { + let status = match item { + Ok((attr, data)) => match handler.write(&attr, AttrData::new(attr.dataver, &data)) { + Ok(()) => attr.status(IMStatusCode::Success)?, + Err(error) => attr.status(error.into())?, + }, + Err(status) => Some(status), + }; + + if let Some(status) = status { + status.to_tlv(tw, TagType::Anonymous)?; + } + + Ok(()) + } + + #[cfg(feature = "nightly")] + pub async fn handle_read_async( + item: Result, AttrStatus>, + handler: &T, + tw: &mut TLVWriter<'_, '_>, + ) -> Result<(), Error> { + let status = match item { + Ok(attr) => { + let encoder = AttrDataEncoder::new(&attr, tw); + + match handler.read(&attr, encoder).await { + Ok(()) => None, + Err(error) => attr.status(error.into())?, + } + } + Err(status) => Some(status), + }; + + if let Some(status) = status { + AttrResp::Status(status).to_tlv(tw, TagType::Anonymous)?; + } + + Ok(()) + } + + #[cfg(feature = "nightly")] + pub async fn handle_write_async( + item: Result<(AttrDetails<'_>, TLVElement<'_>), AttrStatus>, + handler: &mut T, + tw: &mut TLVWriter<'_, '_>, + ) -> Result<(), Error> { + let status = match item { + Ok((attr, data)) => match handler + .write(&attr, AttrData::new(attr.dataver, &data)) + .await + { + Ok(()) => attr.status(IMStatusCode::Success)?, + Err(error) => attr.status(error.into())?, + }, + Err(status) => Some(status), + }; + + if let Some(status) = status { + AttrResp::Status(status).to_tlv(tw, TagType::Anonymous)?; + } + + Ok(()) + } + + pub fn new(attr: &AttrDetails, tw: &'a mut TLVWriter<'b, 'c>) -> Self { + Self { + dataver_filter: attr.dataver, + path: attr.path(), + tw, + } + } + + pub fn with_dataver(self, dataver: u32) -> Result>, Error> { + if self + .dataver_filter + .map(|dataver_filter| dataver_filter != dataver) + .unwrap_or(true) + { + let mut writer = AttrDataWriter::new(self.tw); + + writer.start_struct(TagType::Anonymous)?; + writer.start_struct(TagType::Context(AttrRespTag::Data as _))?; + writer.u32(TagType::Context(AttrDataTag::DataVer as _), dataver)?; + self.path + .to_tlv(&mut writer, TagType::Context(AttrDataTag::Path as _))?; + + Ok(Some(writer)) + } else { + Ok(None) + } + } +} + +pub struct AttrDataWriter<'a, 'b, 'c> { + tw: &'a mut TLVWriter<'b, 'c>, + anchor: usize, + completed: bool, +} + +impl<'a, 'b, 'c> AttrDataWriter<'a, 'b, 'c> { + pub const TAG: TagType = TagType::Context(AttrDataTag::Data as _); + + fn new(tw: &'a mut TLVWriter<'b, 'c>) -> Self { + let anchor = tw.get_tail(); + + Self { + tw, + anchor, + completed: false, + } + } + + pub fn set(self, value: T) -> Result<(), Error> { + value.to_tlv(self.tw, Self::TAG)?; + self.complete() + } + + pub fn complete(mut self) -> Result<(), Error> { + self.tw.end_container()?; + self.tw.end_container()?; + + self.completed = true; + + Ok(()) + } + + fn reset(&mut self) { + self.tw.rewind_to(self.anchor); + } +} + +impl<'a, 'b, 'c> Drop for AttrDataWriter<'a, 'b, 'c> { + fn drop(&mut self) { + if !self.completed { + self.reset(); + } + } +} + +impl<'a, 'b, 'c> Deref for AttrDataWriter<'a, 'b, 'c> { + type Target = TLVWriter<'b, 'c>; + + fn deref(&self) -> &Self::Target { + self.tw + } +} + +impl<'a, 'b, 'c> DerefMut for AttrDataWriter<'a, 'b, 'c> { + fn deref_mut(&mut self) -> &mut Self::Target { + self.tw + } +} + +pub struct AttrData<'a> { + for_dataver: Option, + data: &'a TLVElement<'a>, +} + +impl<'a> AttrData<'a> { + pub fn new(for_dataver: Option, data: &'a TLVElement<'a>) -> Self { + Self { for_dataver, data } + } + + pub fn with_dataver(self, dataver: u32) -> Result<&'a TLVElement<'a>, Error> { + if let Some(req_dataver) = self.for_dataver { + if req_dataver != dataver { + return Err(Error::DataVersionMismatch); + } + } + + Ok(self.data) + } +} + +#[derive(Default)] +pub struct CmdDataTracker { + skip_status: bool, +} + +impl CmdDataTracker { + pub const fn new() -> Self { + Self { skip_status: false } + } + + pub(crate) fn complete(&mut self) { + self.skip_status = true; + } + + pub fn needs_status(&self) -> bool { + !self.skip_status + } +} + +pub struct CmdDataEncoder<'a, 'b, 'c> { + tracker: &'a mut CmdDataTracker, + path: CmdPath, + tw: &'a mut TLVWriter<'b, 'c>, +} + +impl<'a, 'b, 'c> CmdDataEncoder<'a, 'b, 'c> { + pub fn handle( + item: Result<(CmdDetails, TLVElement), CmdStatus>, + handler: &mut T, + transaction: &mut Transaction, + tw: &mut TLVWriter, + ) -> Result<(), Error> { + let status = match item { + Ok((cmd, data)) => { + let mut tracker = CmdDataTracker::new(); + let encoder = CmdDataEncoder::new(&cmd, &mut tracker, tw); + + match handler.invoke(transaction, &cmd, &data, encoder) { + Ok(()) => cmd.success(&tracker), + Err(error) => cmd.status(error.into()), + } + } + Err(status) => Some(status), + }; + + if let Some(status) = status { + InvResp::Status(status).to_tlv(tw, TagType::Anonymous)?; + } + + Ok(()) + } + + #[cfg(feature = "nightly")] + pub async fn handle_async( + item: Result<(CmdDetails<'_>, TLVElement<'_>), CmdStatus>, + handler: &mut T, + transaction: &mut Transaction<'_, '_>, + tw: &mut TLVWriter<'_, '_>, + ) -> Result<(), Error> { + let status = match item { + Ok((cmd, data)) => { + let mut tracker = CmdDataTracker::new(); + let encoder = CmdDataEncoder::new(&cmd, &mut tracker, tw); + + match handler.invoke(transaction, &cmd, &data, encoder).await { + Ok(()) => cmd.success(&tracker), + Err(error) => cmd.status(error.into()), + } + } + Err(status) => Some(status), + }; + + if let Some(status) = status { + InvResp::Status(status).to_tlv(tw, TagType::Anonymous)?; + } + + Ok(()) + } + + pub fn new( + cmd: &CmdDetails, + tracker: &'a mut CmdDataTracker, + tw: &'a mut TLVWriter<'b, 'c>, + ) -> Self { + Self { + tracker, + path: cmd.path(), + tw, + } + } + + pub fn with_command(mut self, cmd: u16) -> Result, Error> { + let mut writer = CmdDataWriter::new(self.tracker, self.tw); + + writer.start_struct(TagType::Anonymous)?; + writer.start_struct(TagType::Context(InvRespTag::Cmd as _))?; + + self.path.path.leaf = Some(cmd as _); + self.path + .to_tlv(&mut writer, TagType::Context(CmdDataTag::Path as _))?; + + Ok(writer) + } +} + +pub struct CmdDataWriter<'a, 'b, 'c> { + tracker: &'a mut CmdDataTracker, + tw: &'a mut TLVWriter<'b, 'c>, + anchor: usize, + completed: bool, +} + +impl<'a, 'b, 'c> CmdDataWriter<'a, 'b, 'c> { + pub const TAG: TagType = TagType::Context(CmdDataTag::Data as _); + + fn new(tracker: &'a mut CmdDataTracker, tw: &'a mut TLVWriter<'b, 'c>) -> Self { + let anchor = tw.get_tail(); + + Self { + tracker, + tw, + anchor, + completed: false, + } + } + + pub fn set(self, value: T) -> Result<(), Error> { + value.to_tlv(self.tw, Self::TAG)?; + self.complete() + } + + pub fn complete(mut self) -> Result<(), Error> { + self.tw.end_container()?; + self.tw.end_container()?; + + self.completed = true; + self.tracker.complete(); + + Ok(()) + } + + fn reset(&mut self) { + self.tw.rewind_to(self.anchor); + } +} + +impl<'a, 'b, 'c> Drop for CmdDataWriter<'a, 'b, 'c> { + fn drop(&mut self) { + if !self.completed { + self.reset(); + } + } +} + +impl<'a, 'b, 'c> Deref for CmdDataWriter<'a, 'b, 'c> { + type Target = TLVWriter<'b, 'c>; + + fn deref(&self) -> &Self::Target { + self.tw + } +} + +impl<'a, 'b, 'c> DerefMut for CmdDataWriter<'a, 'b, 'c> { + fn deref_mut(&mut self) -> &mut Self::Target { + self.tw + } +} + +#[derive(Copy, Clone, Debug)] +pub struct AttrType(PhantomData T>); + +impl AttrType { + pub const fn new() -> Self { + Self(PhantomData) + } + + pub fn encode(&self, writer: AttrDataWriter, value: T) -> Result<(), Error> + where + T: ToTLV, + { + writer.set(value) + } + + pub fn decode<'a>(&self, data: &'a TLVElement) -> Result + where + T: FromTLV<'a>, + { + T::from_tlv(data) + } +} + +impl Default for AttrType { + fn default() -> Self { + Self::new() + } +} + +#[derive(Copy, Clone, Debug, Default)] +pub struct AttrUtfType; + +impl AttrUtfType { + pub const fn new() -> Self { + Self + } + + pub fn encode(&self, writer: AttrDataWriter, value: &str) -> Result<(), Error> { + writer.set(UtfStr::new(value.as_bytes())) + } + + pub fn decode<'a>(&self, data: &'a TLVElement) -> Result<&'a str, IMStatusCode> { + data.str().map_err(|_| IMStatusCode::InvalidDataType) + } +} + +#[allow(unused_macros)] +#[macro_export] +macro_rules! attribute_enum { + ($en:ty) => { + impl core::convert::TryFrom<$crate::data_model::objects::AttrId> for $en { + type Error = $crate::error::Error; + + fn try_from(id: $crate::data_model::objects::AttrId) -> Result { + <$en>::from_repr(id).ok_or($crate::error::Error::AttributeNotFound) + } + } + }; +} + +#[allow(unused_macros)] +#[macro_export] +macro_rules! command_enum { + ($en:ty) => { + impl core::convert::TryFrom<$crate::data_model::objects::CmdId> for $en { + type Error = $crate::error::Error; + + fn try_from(id: $crate::data_model::objects::CmdId) -> Result { + <$en>::from_repr(id).ok_or($crate::error::Error::CommandNotFound) + } + } + }; } diff --git a/matter/src/data_model/objects/endpoint.rs b/matter/src/data_model/objects/endpoint.rs index 466e7a6..d0a4fdd 100644 --- a/matter/src/data_model/objects/endpoint.rs +++ b/matter/src/data_model/objects/endpoint.rs @@ -15,104 +15,91 @@ * limitations under the License. */ -use crate::{data_model::objects::ClusterType, error::*, interaction_model::core::IMStatusCode}; +use crate::{acl::Accessor, interaction_model::core::IMStatusCode}; -use std::fmt; +use core::fmt; -use super::{ClusterId, DeviceType}; +use super::{AttrId, Cluster, ClusterId, CmdId, DeviceType, EndptId}; -pub const CLUSTERS_PER_ENDPT: usize = 9; - -pub struct Endpoint { - dev_type: DeviceType, - clusters: Vec>, +#[derive(Debug, Clone)] +pub struct Endpoint<'a> { + pub id: EndptId, + pub device_type: DeviceType, + pub clusters: &'a [Cluster<'a>], } -pub type BoxedClusters = [Box]; - -impl Endpoint { - pub fn new(dev_type: DeviceType) -> Result, Error> { - Ok(Box::new(Endpoint { - dev_type, - clusters: Vec::with_capacity(CLUSTERS_PER_ENDPT), - })) +impl<'a> Endpoint<'a> { + pub(crate) fn match_attributes<'m>( + &'m self, + accessor: &'m Accessor<'m>, + cl: Option, + attr: Option, + write: bool, + ) -> impl Iterator + 'm { + self.match_clusters(cl).flat_map(move |cluster| { + cluster + .match_attributes(accessor, self.id, attr, write) + .map(move |attr| (cluster.id, attr)) + }) } - pub fn add_cluster(&mut self, cluster: Box) -> Result<(), Error> { - if self.clusters.len() < self.clusters.capacity() { - self.clusters.push(cluster); - Ok(()) - } else { - Err(Error::NoSpace) - } + pub(crate) fn match_commands<'m>( + &'m self, + accessor: &'m Accessor<'m>, + cl: Option, + cmd: Option, + ) -> impl Iterator + 'm { + self.match_clusters(cl).flat_map(move |cluster| { + cluster + .match_commands(accessor, self.id, cmd) + .map(move |cmd| (cluster.id, cmd)) + }) } - pub fn get_dev_type(&self) -> &DeviceType { - &self.dev_type - } - - fn get_cluster_index(&self, cluster_id: ClusterId) -> Option { - self.clusters.iter().position(|c| c.base().id == cluster_id) - } - - pub fn get_cluster(&self, cluster_id: ClusterId) -> Result<&dyn ClusterType, Error> { - let index = self - .get_cluster_index(cluster_id) - .ok_or(Error::ClusterNotFound)?; - Ok(self.clusters[index].as_ref()) - } - - pub fn get_cluster_mut( - &mut self, - cluster_id: ClusterId, - ) -> Result<&mut dyn ClusterType, Error> { - let index = self - .get_cluster_index(cluster_id) - .ok_or(Error::ClusterNotFound)?; - Ok(self.clusters[index].as_mut()) - } - - // Returns a slice of clusters, with either a single cluster or all (wildcard) - pub fn get_wildcard_clusters( + pub(crate) fn check_attribute( &self, - cluster: Option, - ) -> Result<(&BoxedClusters, bool), IMStatusCode> { - if let Some(c) = cluster { - if let Some(i) = self.get_cluster_index(c) { - Ok((&self.clusters[i..i + 1], false)) - } else { - Err(IMStatusCode::UnsupportedCluster) - } - } else { - Ok((self.clusters.as_slice(), true)) - } + accessor: &Accessor, + cl: ClusterId, + attr: AttrId, + write: bool, + ) -> Result<(), IMStatusCode> { + self.check_cluster(cl) + .and_then(|cluster| cluster.check_attribute(accessor, self.id, attr, write)) } - // Returns a slice of clusters, with either a single cluster or all (wildcard) - pub fn get_wildcard_clusters_mut( - &mut self, - cluster: Option, - ) -> Result<(&mut BoxedClusters, bool), IMStatusCode> { - if let Some(c) = cluster { - if let Some(i) = self.get_cluster_index(c) { - Ok((&mut self.clusters[i..i + 1], false)) - } else { - Err(IMStatusCode::UnsupportedCluster) - } - } else { - Ok((&mut self.clusters[..], true)) - } + pub(crate) fn check_command( + &self, + accessor: &Accessor, + cl: ClusterId, + cmd: CmdId, + ) -> Result<(), IMStatusCode> { + self.check_cluster(cl) + .and_then(|cluster| cluster.check_command(accessor, self.id, cmd)) + } + + fn match_clusters(&self, cl: Option) -> impl Iterator + '_ { + self.clusters + .iter() + .filter(move |cluster| cl.map(|id| id == cluster.id).unwrap_or(true)) + } + + fn check_cluster(&self, cl: ClusterId) -> Result<&Cluster, IMStatusCode> { + self.clusters + .iter() + .find(|cluster| cluster.id == cl) + .ok_or(IMStatusCode::UnsupportedCluster) } } -impl std::fmt::Display for Endpoint { +impl<'a> core::fmt::Display for Endpoint<'a> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "clusters:[")?; let mut comma = ""; - for element in self.clusters.iter() { - write!(f, "{} {{ {} }}", comma, element.base())?; + for cluster in self.clusters { + write!(f, "{} {{ {} }}", comma, cluster)?; comma = ", "; } + write!(f, "]") } } diff --git a/matter/src/data_model/objects/handler.rs b/matter/src/data_model/objects/handler.rs new file mode 100644 index 0000000..052d690 --- /dev/null +++ b/matter/src/data_model/objects/handler.rs @@ -0,0 +1,350 @@ +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +use crate::{error::Error, interaction_model::core::Transaction, tlv::TLVElement}; + +use super::{AttrData, AttrDataEncoder, AttrDetails, CmdDataEncoder, CmdDetails}; + +pub trait ChangeNotifier { + fn consume_change(&mut self) -> Option; +} + +pub trait Handler { + fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error>; + + fn write(&mut self, _attr: &AttrDetails, _data: AttrData) -> Result<(), Error> { + Err(Error::AttributeNotFound) + } + + fn invoke( + &mut self, + _transaction: &mut Transaction, + _cmd: &CmdDetails, + _data: &TLVElement, + _encoder: CmdDataEncoder, + ) -> Result<(), Error> { + Err(Error::CommandNotFound) + } +} + +impl Handler for &mut T +where + T: Handler, +{ + fn read<'a>(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> { + (**self).read(attr, encoder) + } + + fn write(&mut self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { + (**self).write(attr, data) + } + + fn invoke( + &mut self, + transaction: &mut Transaction, + cmd: &CmdDetails, + data: &TLVElement, + encoder: CmdDataEncoder, + ) -> Result<(), Error> { + (**self).invoke(transaction, cmd, data, encoder) + } +} + +pub trait NonBlockingHandler: Handler {} + +impl NonBlockingHandler for &mut T where T: NonBlockingHandler {} + +pub struct EmptyHandler; + +impl EmptyHandler { + pub const fn chain( + self, + handler_endpoint: u16, + handler_cluster: u32, + handler: H, + ) -> ChainedHandler { + ChainedHandler { + handler_endpoint, + handler_cluster, + handler, + next: self, + } + } +} + +impl Handler for EmptyHandler { + fn read(&self, _attr: &AttrDetails, _encoder: AttrDataEncoder) -> Result<(), Error> { + Err(Error::AttributeNotFound) + } +} + +impl NonBlockingHandler for EmptyHandler {} + +impl ChangeNotifier<(u16, u32)> for EmptyHandler { + fn consume_change(&mut self) -> Option<(u16, u32)> { + None + } +} + +pub struct ChainedHandler { + pub handler_endpoint: u16, + pub handler_cluster: u32, + pub handler: H, + pub next: T, +} + +impl ChainedHandler { + pub const fn chain

( + self, + handler_endpoint: u16, + handler_cluster: u32, + handler: H2, + ) -> ChainedHandler { + ChainedHandler { + handler_endpoint, + handler_cluster, + handler, + next: self, + } + } +} + +impl Handler for ChainedHandler +where + H: Handler, + T: Handler, +{ + fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> { + if self.handler_endpoint == attr.endpoint_id && self.handler_cluster == attr.cluster_id { + self.handler.read(attr, encoder) + } else { + self.next.read(attr, encoder) + } + } + + fn write(&mut self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { + if self.handler_endpoint == attr.endpoint_id && self.handler_cluster == attr.cluster_id { + self.handler.write(attr, data) + } else { + self.next.write(attr, data) + } + } + + fn invoke( + &mut self, + transaction: &mut Transaction, + cmd: &CmdDetails, + data: &TLVElement, + encoder: CmdDataEncoder, + ) -> Result<(), Error> { + if self.handler_endpoint == cmd.endpoint_id && self.handler_cluster == cmd.cluster_id { + self.handler.invoke(transaction, cmd, data, encoder) + } else { + self.next.invoke(transaction, cmd, data, encoder) + } + } +} + +impl NonBlockingHandler for ChainedHandler +where + H: NonBlockingHandler, + T: NonBlockingHandler, +{ +} + +impl ChangeNotifier<(u16, u32)> for ChainedHandler +where + H: ChangeNotifier<()>, + T: ChangeNotifier<(u16, u32)>, +{ + fn consume_change(&mut self) -> Option<(u16, u32)> { + if self.handler.consume_change().is_some() { + Some((self.handler_endpoint, self.handler_cluster)) + } else { + self.next.consume_change() + } + } +} + +#[allow(unused_macros)] +#[macro_export] +macro_rules! handler_chain_type { + ($h:ty) => { + $crate::data_model::objects::ChainedHandler<$h, $crate::data_model::objects::EmptyHandler> + }; + ($h1:ty, $($rest:ty),+) => { + $crate::data_model::objects::ChainedHandler<$h1, handler_chain_type!($($rest),+)> + }; +} + +#[cfg(feature = "nightly")] +pub mod asynch { + use crate::{ + data_model::objects::{AttrData, AttrDataEncoder, AttrDetails, CmdDataEncoder, CmdDetails}, + error::Error, + interaction_model::core::Transaction, + tlv::TLVElement, + }; + + use super::{ChainedHandler, EmptyHandler, Handler, NonBlockingHandler}; + + pub trait AsyncHandler { + async fn read<'a>( + &'a self, + attr: &'a AttrDetails<'_>, + encoder: AttrDataEncoder<'a, '_, '_>, + ) -> Result<(), Error>; + + async fn write<'a>( + &'a mut self, + _attr: &'a AttrDetails<'_>, + _data: AttrData<'a>, + ) -> Result<(), Error> { + Err(Error::AttributeNotFound) + } + + async fn invoke<'a>( + &'a mut self, + _transaction: &'a mut Transaction<'_, '_>, + _cmd: &'a CmdDetails<'_>, + _data: &'a TLVElement<'_>, + _encoder: CmdDataEncoder<'a, '_, '_>, + ) -> Result<(), Error> { + Err(Error::CommandNotFound) + } + } + + impl AsyncHandler for &mut T + where + T: AsyncHandler, + { + async fn read<'a>( + &'a self, + attr: &'a AttrDetails<'_>, + encoder: AttrDataEncoder<'a, '_, '_>, + ) -> Result<(), Error> { + (**self).read(attr, encoder).await + } + + async fn write<'a>( + &'a mut self, + attr: &'a AttrDetails<'_>, + data: AttrData<'a>, + ) -> Result<(), Error> { + (**self).write(attr, data).await + } + + async fn invoke<'a>( + &'a mut self, + transaction: &'a mut Transaction<'_, '_>, + cmd: &'a CmdDetails<'_>, + data: &'a TLVElement<'_>, + encoder: CmdDataEncoder<'a, '_, '_>, + ) -> Result<(), Error> { + (**self).invoke(transaction, cmd, data, encoder).await + } + } + + pub struct Asyncify(pub T); + + impl AsyncHandler for Asyncify + where + T: NonBlockingHandler, + { + async fn read<'a>( + &'a self, + attr: &'a AttrDetails<'_>, + encoder: AttrDataEncoder<'a, '_, '_>, + ) -> Result<(), Error> { + Handler::read(&self.0, attr, encoder) + } + + async fn write<'a>( + &'a mut self, + attr: &'a AttrDetails<'_>, + data: AttrData<'a>, + ) -> Result<(), Error> { + Handler::write(&mut self.0, attr, data) + } + + async fn invoke<'a>( + &'a mut self, + transaction: &'a mut Transaction<'_, '_>, + cmd: &'a CmdDetails<'_>, + data: &'a TLVElement<'_>, + encoder: CmdDataEncoder<'a, '_, '_>, + ) -> Result<(), Error> { + Handler::invoke(&mut self.0, transaction, cmd, data, encoder) + } + } + + impl AsyncHandler for EmptyHandler { + async fn read<'a>( + &'a self, + _attr: &'a AttrDetails<'_>, + _encoder: AttrDataEncoder<'a, '_, '_>, + ) -> Result<(), Error> { + Err(Error::AttributeNotFound) + } + } + + impl AsyncHandler for ChainedHandler + where + H: AsyncHandler, + T: AsyncHandler, + { + async fn read<'a>( + &'a self, + attr: &'a AttrDetails<'_>, + encoder: AttrDataEncoder<'a, '_, '_>, + ) -> Result<(), Error> { + if self.handler_endpoint == attr.endpoint_id && self.handler_cluster == attr.cluster_id + { + self.handler.read(attr, encoder).await + } else { + self.next.read(attr, encoder).await + } + } + + async fn write<'a>( + &'a mut self, + attr: &'a AttrDetails<'_>, + data: AttrData<'a>, + ) -> Result<(), Error> { + if self.handler_endpoint == attr.endpoint_id && self.handler_cluster == attr.cluster_id + { + self.handler.write(attr, data).await + } else { + self.next.write(attr, data).await + } + } + + async fn invoke<'a>( + &'a mut self, + transaction: &'a mut Transaction<'_, '_>, + cmd: &'a CmdDetails<'_>, + data: &'a TLVElement<'_>, + encoder: CmdDataEncoder<'a, '_, '_>, + ) -> Result<(), Error> { + if self.handler_endpoint == cmd.endpoint_id && self.handler_cluster == cmd.cluster_id { + self.handler.invoke(transaction, cmd, data, encoder).await + } else { + self.next.invoke(transaction, cmd, data, encoder).await + } + } + } +} diff --git a/matter/src/data_model/objects/mod.rs b/matter/src/data_model/objects/mod.rs index 2fb3aff..1bd326e 100644 --- a/matter/src/data_model/objects/mod.rs +++ b/matter/src/data_model/objects/mod.rs @@ -14,11 +14,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - -pub type EndptId = u16; -pub type ClusterId = u32; -pub type AttrId = u16; -pub type CmdId = u32; +use crate::error::Error; +use crate::tlv::{TLVWriter, TagType, ToTLV}; mod attribute; pub use attribute::*; @@ -37,3 +34,20 @@ pub use privilege::*; mod encoder; pub use encoder::*; + +mod handler; +pub use handler::*; + +mod dataver; +pub use dataver::*; + +pub type EndptId = u16; +pub type ClusterId = u32; +pub type AttrId = u16; +pub type CmdId = u32; + +#[derive(Debug, ToTLV, Copy, Clone)] +pub struct DeviceType { + pub dtype: u16, + pub drev: u16, +} diff --git a/matter/src/data_model/objects/node.rs b/matter/src/data_model/objects/node.rs index ba2f0b2..2eb1175 100644 --- a/matter/src/data_model/objects/node.rs +++ b/matter/src/data_model/objects/node.rs @@ -16,283 +16,379 @@ */ use crate::{ - data_model::objects::{ClusterType, Endpoint}, - error::*, - interaction_model::{core::IMStatusCode, messages::GenericPath}, + acl::Accessor, + data_model::objects::Endpoint, + interaction_model::{ + core::IMStatusCode, + messages::{ + ib::{AttrStatus, CmdStatus}, + msg::{InvReq, ReadReq, WriteReq}, + GenericPath, + }, + }, // TODO: This layer shouldn't really depend on the TLV layer, should create an abstraction layer + tlv::TLVElement, +}; +use core::{ + fmt, + iter::{once, Once}, }; -use std::fmt; -use super::{ClusterId, DeviceType, EndptId}; +use super::{AttrDetails, AttrId, ClusterId, CmdDetails, CmdId, EndptId}; -pub trait ChangeConsumer { - fn endpoint_added(&self, id: EndptId, endpoint: &mut Endpoint) -> Result<(), Error>; +enum WildcardIter { + None, + Single(Once), + Wildcard(T), } -pub const ENDPTS_PER_ACC: usize = 3; +impl Iterator for WildcardIter +where + T: Iterator, +{ + type Item = E; -pub type BoxedEndpoints = [Option>]; - -#[derive(Default)] -pub struct Node { - endpoints: [Option>; ENDPTS_PER_ACC], - changes_cb: Option>, + fn next(&mut self) -> Option { + match self { + Self::None => None, + Self::Single(iter) => iter.next(), + Self::Wildcard(iter) => iter.next(), + } + } } -impl std::fmt::Display for Node { +#[derive(Debug, Clone)] +pub struct Node<'a> { + pub id: u16, + pub endpoints: &'a [Endpoint<'a>], +} + +impl<'a> Node<'a> { + pub fn read<'s, 'm>( + &'s self, + req: &'m ReadReq, + accessor: &'m Accessor<'m>, + ) -> impl Iterator> + 'm + where + 's: 'm, + { + if let Some(attr_requests) = req.attr_requests.as_ref() { + WildcardIter::Wildcard(attr_requests.iter().flat_map( + move |path| match self.expand_attr(accessor, path.to_gp(), false) { + Ok(iter) => { + let wildcard = matches!(iter, WildcardIter::Wildcard(_)); + + WildcardIter::Wildcard(iter.map(move |(ep, cl, attr)| { + let dataver_filter = req + .dataver_filters + .as_ref() + .iter() + .flat_map(|array| array.iter()) + .find_map(|filter| { + (filter.path.endpoint == ep && filter.path.cluster == cl) + .then_some(filter.data_ver) + }); + + Ok(AttrDetails { + node: self, + endpoint_id: ep, + cluster_id: cl, + attr_id: attr, + list_index: path.list_index, + fab_idx: accessor.fab_idx, + fab_filter: req.fabric_filtered, + dataver: dataver_filter, + wildcard, + }) + })) + } + Err(err) => { + WildcardIter::Single(once(Err(AttrStatus::new(&path.to_gp(), err, 0)))) + } + }, + )) + } else { + WildcardIter::None + } + } + + pub fn write<'m>( + &'m self, + req: &'m WriteReq, + accessor: &'m Accessor<'m>, + ) -> impl Iterator), AttrStatus>> + 'm { + req.write_requests.iter().flat_map(move |attr_data| { + if attr_data.path.cluster.is_none() { + WildcardIter::Single(once(Err(AttrStatus::new( + &attr_data.path.to_gp(), + IMStatusCode::UnsupportedCluster, + 0, + )))) + } else if attr_data.path.attr.is_none() { + WildcardIter::Single(once(Err(AttrStatus::new( + &attr_data.path.to_gp(), + IMStatusCode::UnsupportedAttribute, + 0, + )))) + } else { + match self.expand_attr(accessor, attr_data.path.to_gp(), true) { + Ok(iter) => { + let wildcard = matches!(iter, WildcardIter::Wildcard(_)); + + WildcardIter::Wildcard(iter.map(move |(ep, cl, attr)| { + Ok(( + AttrDetails { + node: self, + endpoint_id: ep, + cluster_id: cl, + attr_id: attr, + list_index: attr_data.path.list_index, + fab_idx: accessor.fab_idx, + fab_filter: false, + dataver: attr_data.data_ver, + wildcard, + }, + attr_data.data.unwrap_tlv().unwrap(), + )) + })) + } + Err(err) => WildcardIter::Single(once(Err(AttrStatus::new( + &attr_data.path.to_gp(), + err, + 0, + )))), + } + } + }) + } + + pub fn invoke<'m>( + &'m self, + req: &'m InvReq, + accessor: &'m Accessor<'m>, + ) -> impl Iterator), CmdStatus>> + 'm { + if let Some(inv_requests) = req.inv_requests.as_ref() { + WildcardIter::Wildcard(inv_requests.iter().flat_map(move |cmd_data| { + match self.expand_cmd(accessor, cmd_data.path.path) { + Ok(iter) => { + let wildcard = matches!(iter, WildcardIter::Wildcard(_)); + + WildcardIter::Wildcard(iter.map(move |(ep, cl, cmd)| { + Ok(( + CmdDetails { + node: self, + endpoint_id: ep, + cluster_id: cl, + cmd_id: cmd, + wildcard, + }, + cmd_data.data.unwrap_tlv().unwrap(), + )) + })) + } + Err(err) => { + WildcardIter::Single(once(Err(CmdStatus::new(cmd_data.path, err, 0)))) + } + } + })) + } else { + WildcardIter::None + } + } + + fn expand_attr<'m>( + &'m self, + accessor: &'m Accessor<'m>, + path: GenericPath, + write: bool, + ) -> Result< + WildcardIter< + impl Iterator + 'm, + (EndptId, ClusterId, AttrId), + >, + IMStatusCode, + > { + if path.is_wildcard() { + Ok(WildcardIter::Wildcard(self.match_attributes( + accessor, + path.endpoint, + path.cluster, + path.leaf.map(|leaf| leaf as u16), + write, + ))) + } else { + self.check_attribute( + accessor, + path.endpoint.unwrap(), + path.cluster.unwrap(), + path.leaf.unwrap() as _, + write, + )?; + + Ok(WildcardIter::Single(once(( + path.endpoint.unwrap(), + path.cluster.unwrap(), + path.leaf.unwrap() as _, + )))) + } + } + + fn expand_cmd<'m>( + &'m self, + accessor: &'m Accessor<'m>, + path: GenericPath, + ) -> Result< + WildcardIter< + impl Iterator + 'm, + (EndptId, ClusterId, CmdId), + >, + IMStatusCode, + > { + if path.is_wildcard() { + Ok(WildcardIter::Wildcard(self.match_commands( + accessor, + path.endpoint, + path.cluster, + path.leaf, + ))) + } else { + self.check_command( + accessor, + path.endpoint.unwrap(), + path.cluster.unwrap(), + path.leaf.unwrap(), + )?; + + Ok(WildcardIter::Single(once(( + path.endpoint.unwrap(), + path.cluster.unwrap(), + path.leaf.unwrap(), + )))) + } + } + + fn match_attributes<'m>( + &'m self, + accessor: &'m Accessor<'m>, + ep: Option, + cl: Option, + attr: Option, + write: bool, + ) -> impl Iterator + 'm { + self.match_endpoints(ep).flat_map(move |endpoint| { + endpoint + .match_attributes(accessor, cl, attr, write) + .map(move |(cl, attr)| (endpoint.id, cl, attr)) + }) + } + + fn match_commands<'m>( + &'m self, + accessor: &'m Accessor<'m>, + ep: Option, + cl: Option, + cmd: Option, + ) -> impl Iterator + 'm { + self.match_endpoints(ep).flat_map(move |endpoint| { + endpoint + .match_commands(accessor, cl, cmd) + .map(move |(cl, cmd)| (endpoint.id, cl, cmd)) + }) + } + + fn check_attribute( + &self, + accessor: &Accessor, + ep: EndptId, + cl: ClusterId, + attr: AttrId, + write: bool, + ) -> Result<(), IMStatusCode> { + self.check_endpoint(ep) + .and_then(|endpoint| endpoint.check_attribute(accessor, cl, attr, write)) + } + + fn check_command( + &self, + accessor: &Accessor, + ep: EndptId, + cl: ClusterId, + cmd: CmdId, + ) -> Result<(), IMStatusCode> { + self.check_endpoint(ep) + .and_then(|endpoint| endpoint.check_command(accessor, cl, cmd)) + } + + fn match_endpoints(&self, ep: Option) -> impl Iterator + '_ { + self.endpoints + .iter() + .filter(move |endpoint| ep.map(|id| id == endpoint.id).unwrap_or(true)) + } + + fn check_endpoint(&self, ep: EndptId) -> Result<&Endpoint, IMStatusCode> { + self.endpoints + .iter() + .find(|endpoint| endpoint.id == ep) + .ok_or(IMStatusCode::UnsupportedEndpoint) + } +} + +impl<'a> core::fmt::Display for Node<'a> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { writeln!(f, "node:")?; - for (i, element) in self.endpoints.iter().enumerate() { - if let Some(e) = element { - writeln!(f, "endpoint {}: {}", i, e)?; - } + for (index, endpoint) in self.endpoints.iter().enumerate() { + writeln!(f, "endpoint {}: {}", index, endpoint)?; } + write!(f, "") } } -impl Node { - pub fn new() -> Result, Error> { - let node = Box::default(); - Ok(node) +pub struct DynamicNode<'a, const N: usize> { + id: u16, + endpoints: heapless::Vec, N>, +} + +impl<'a, const N: usize> DynamicNode<'a, N> { + pub const fn new(id: u16) -> Self { + Self { + id, + endpoints: heapless::Vec::new(), + } } - pub fn set_changes_cb(&mut self, consumer: Box) { - self.changes_cb = Some(consumer); + pub fn node(&self) -> Node<'_> { + Node { + id: self.id, + endpoints: &self.endpoints, + } } - pub fn add_endpoint(&mut self, dev_type: DeviceType) -> Result { + pub fn add(&mut self, endpoint: Endpoint<'a>) -> Result<(), Endpoint<'a>> { + if !self.endpoints.iter().any(|ep| ep.id == endpoint.id) { + self.endpoints.push(endpoint) + } else { + Err(endpoint) + } + } + + pub fn remove(&mut self, endpoint_id: u16) -> Option> { let index = self .endpoints .iter() - .position(|x| x.is_none()) - .ok_or(Error::NoSpace)?; - let mut endpoint = Endpoint::new(dev_type)?; - if let Some(cb) = &self.changes_cb { - cb.endpoint_added(index as EndptId, &mut endpoint)?; - } - self.endpoints[index] = Some(endpoint); - Ok(index as EndptId) - } + .enumerate() + .find_map(|(index, ep)| (ep.id == endpoint_id).then_some(index)); - pub fn get_endpoint(&self, endpoint_id: EndptId) -> Result<&Endpoint, Error> { - if (endpoint_id as usize) < ENDPTS_PER_ACC { - let endpoint = self.endpoints[endpoint_id as usize] - .as_ref() - .ok_or(Error::EndpointNotFound)?; - Ok(endpoint) + if let Some(index) = index { + Some(self.endpoints.swap_remove(index)) } else { - Err(Error::EndpointNotFound) + None } } - - pub fn get_endpoint_mut(&mut self, endpoint_id: EndptId) -> Result<&mut Endpoint, Error> { - if (endpoint_id as usize) < ENDPTS_PER_ACC { - let endpoint = self.endpoints[endpoint_id as usize] - .as_mut() - .ok_or(Error::EndpointNotFound)?; - Ok(endpoint) - } else { - Err(Error::EndpointNotFound) - } - } - - pub fn get_cluster_mut( - &mut self, - e: EndptId, - c: ClusterId, - ) -> Result<&mut dyn ClusterType, Error> { - self.get_endpoint_mut(e)?.get_cluster_mut(c) - } - - pub fn get_cluster(&self, e: EndptId, c: ClusterId) -> Result<&dyn ClusterType, Error> { - self.get_endpoint(e)?.get_cluster(c) - } - - pub fn add_cluster( - &mut self, - endpoint_id: EndptId, - cluster: Box, - ) -> Result<(), Error> { - let endpoint_id = endpoint_id as usize; - if endpoint_id < ENDPTS_PER_ACC { - self.endpoints[endpoint_id] - .as_mut() - .ok_or(Error::NoEndpoint)? - .add_cluster(cluster) - } else { - Err(Error::Invalid) - } - } - - // Returns a slice of endpoints, with either a single endpoint or all (wildcard) - pub fn get_wildcard_endpoints( - &self, - endpoint: Option, - ) -> Result<(&BoxedEndpoints, usize, bool), IMStatusCode> { - if let Some(e) = endpoint { - let e = e as usize; - if self.endpoints.len() <= e || self.endpoints[e].is_none() { - Err(IMStatusCode::UnsupportedEndpoint) - } else { - Ok((&self.endpoints[e..e + 1], e, false)) - } - } else { - Ok((&self.endpoints[..], 0, true)) - } - } - - pub fn get_wildcard_endpoints_mut( - &mut self, - endpoint: Option, - ) -> Result<(&mut BoxedEndpoints, usize, bool), IMStatusCode> { - if let Some(e) = endpoint { - let e = e as usize; - if self.endpoints.len() <= e || self.endpoints[e].is_none() { - Err(IMStatusCode::UnsupportedEndpoint) - } else { - Ok((&mut self.endpoints[e..e + 1], e, false)) - } - } else { - Ok((&mut self.endpoints[..], 0, true)) - } - } - - /// Run a closure for all endpoints as specified in the path - /// - /// Note that the path is a GenericPath and hence can be a wildcard path. The behaviour - /// of this function is to only capture the successful invocations and ignore the erroneous - /// ones. This is inline with the expected behaviour for wildcard, where it implies that - /// 'please run this operation on this wildcard path "wherever possible"' - /// - /// It is expected that if the closure that you pass here returns an error it may not reach - /// out to the caller, in case there was a wildcard path specified - pub fn for_each_endpoint(&self, path: &GenericPath, mut f: T) -> Result<(), IMStatusCode> - where - T: FnMut(&GenericPath, &Endpoint) -> Result<(), IMStatusCode>, - { - let mut current_path = *path; - let (endpoints, mut endpoint_id, wildcard) = self.get_wildcard_endpoints(path.endpoint)?; - for e in endpoints.iter() { - if let Some(e) = e { - current_path.endpoint = Some(endpoint_id as EndptId); - f(¤t_path, e.as_ref()) - .or_else(|e| if !wildcard { Err(e) } else { Ok(()) })?; - } - endpoint_id += 1; - } - Ok(()) - } - - /// Run a closure for all endpoints (mutable) as specified in the path - /// - /// Note that the path is a GenericPath and hence can be a wildcard path. The behaviour - /// of this function is to only capture the successful invocations and ignore the erroneous - /// ones. This is inline with the expected behaviour for wildcard, where it implies that - /// 'please run this operation on this wildcard path "wherever possible"' - /// - /// It is expected that if the closure that you pass here returns an error it may not reach - /// out to the caller, in case there was a wildcard path specified - pub fn for_each_endpoint_mut( - &mut self, - path: &GenericPath, - mut f: T, - ) -> Result<(), IMStatusCode> - where - T: FnMut(&GenericPath, &mut Endpoint) -> Result<(), IMStatusCode>, - { - let mut current_path = *path; - let (endpoints, mut endpoint_id, wildcard) = - self.get_wildcard_endpoints_mut(path.endpoint)?; - for e in endpoints.iter_mut() { - if let Some(e) = e { - current_path.endpoint = Some(endpoint_id as EndptId); - f(¤t_path, e.as_mut()) - .or_else(|e| if !wildcard { Err(e) } else { Ok(()) })?; - } - endpoint_id += 1; - } - Ok(()) - } - - /// Run a closure for all clusters as specified in the path - /// - /// Note that the path is a GenericPath and hence can be a wildcard path. The behaviour - /// of this function is to only capture the successful invocations and ignore the erroneous - /// ones. This is inline with the expected behaviour for wildcard, where it implies that - /// 'please run this operation on this wildcard path "wherever possible"' - /// - /// It is expected that if the closure that you pass here returns an error it may not reach - /// out to the caller, in case there was a wildcard path specified - pub fn for_each_cluster(&self, path: &GenericPath, mut f: T) -> Result<(), IMStatusCode> - where - T: FnMut(&GenericPath, &dyn ClusterType) -> Result<(), IMStatusCode>, - { - self.for_each_endpoint(path, |p, e| { - let mut current_path = *p; - let (clusters, wildcard) = e.get_wildcard_clusters(p.cluster)?; - for c in clusters.iter() { - current_path.cluster = Some(c.base().id); - f(¤t_path, c.as_ref()) - .or_else(|e| if !wildcard { Err(e) } else { Ok(()) })?; - } - Ok(()) - }) - } - - /// Run a closure for all clusters (mutable) as specified in the path - /// - /// Note that the path is a GenericPath and hence can be a wildcard path. The behaviour - /// of this function is to only capture the successful invocations and ignore the erroneous - /// ones. This is inline with the expected behaviour for wildcard, where it implies that - /// 'please run this operation on this wildcard path "wherever possible"' - /// - /// It is expected that if the closure that you pass here returns an error it may not reach - /// out to the caller, in case there was a wildcard path specified - pub fn for_each_cluster_mut( - &mut self, - path: &GenericPath, - mut f: T, - ) -> Result<(), IMStatusCode> - where - T: FnMut(&GenericPath, &mut dyn ClusterType) -> Result<(), IMStatusCode>, - { - self.for_each_endpoint_mut(path, |p, e| { - let mut current_path = *p; - let (clusters, wildcard) = e.get_wildcard_clusters_mut(p.cluster)?; - - for c in clusters.iter_mut() { - current_path.cluster = Some(c.base().id); - f(¤t_path, c.as_mut()) - .or_else(|e| if !wildcard { Err(e) } else { Ok(()) })?; - } - Ok(()) - }) - } - - /// Run a closure for all attributes as specified in the path - /// - /// Note that the path is a GenericPath and hence can be a wildcard path. The behaviour - /// of this function is to only capture the successful invocations and ignore the erroneous - /// ones. This is inline with the expected behaviour for wildcard, where it implies that - /// 'please run this operation on this wildcard path "wherever possible"' - /// - /// It is expected that if the closure that you pass here returns an error it may not reach - /// out to the caller, in case there was a wildcard path specified - pub fn for_each_attribute(&self, path: &GenericPath, mut f: T) -> Result<(), IMStatusCode> - where - T: FnMut(&GenericPath, &dyn ClusterType) -> Result<(), IMStatusCode>, - { - self.for_each_cluster(path, |current_path, c| { - let mut current_path = *current_path; - let (attributes, wildcard) = c - .base() - .get_wildcard_attribute(path.leaf.map(|at| at as u16))?; - for a in attributes.iter() { - current_path.leaf = Some(a.id as u32); - f(¤t_path, c).or_else(|e| if !wildcard { Err(e) } else { Ok(()) })?; - } - Ok(()) - }) - } +} + +impl<'a, const N: usize> core::fmt::Display for DynamicNode<'a, N> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.node().fmt(f) + } } diff --git a/matter/src/data_model/root_endpoint.rs b/matter/src/data_model/root_endpoint.rs new file mode 100644 index 0000000..44131b9 --- /dev/null +++ b/matter/src/data_model/root_endpoint.rs @@ -0,0 +1,108 @@ +use core::{borrow::Borrow, cell::RefCell}; + +use crate::{ + acl::AclMgr, + fabric::FabricMgr, + handler_chain_type, + mdns::MdnsMgr, + secure_channel::pake::PaseMgr, + utils::{epoch::Epoch, rand::Rand}, + Matter, +}; + +use super::{ + cluster_basic_information::{self, BasicInfoCluster, BasicInfoConfig}, + objects::{Cluster, EmptyHandler}, + sdm::{ + admin_commissioning::{self, AdminCommCluster}, + dev_att::DevAttDataFetcher, + failsafe::FailSafe, + general_commissioning::{self, GenCommCluster}, + noc::{self, NocCluster}, + nw_commissioning::{self, NwCommCluster}, + }, + system_model::access_control::{self, AccessControlCluster}, +}; + +pub type RootEndpointHandler<'a> = handler_chain_type!( + AccessControlCluster<'a>, + NocCluster<'a>, + AdminCommCluster<'a>, + NwCommCluster, + GenCommCluster, + BasicInfoCluster<'a> +); + +pub const CLUSTERS: [Cluster<'static>; 6] = [ + cluster_basic_information::CLUSTER, + general_commissioning::CLUSTER, + nw_commissioning::CLUSTER, + admin_commissioning::CLUSTER, + noc::CLUSTER, + access_control::CLUSTER, +]; + +pub fn handler<'a>( + endpoint_id: u16, + dev_att: &'a dyn DevAttDataFetcher, + matter: &'a Matter<'a>, +) -> RootEndpointHandler<'a> { + wrap( + endpoint_id, + matter.dev_det(), + dev_att, + matter.borrow(), + matter.borrow(), + matter.borrow(), + matter.borrow(), + matter.borrow(), + *matter.borrow(), + *matter.borrow(), + ) +} + +#[allow(clippy::too_many_arguments)] +pub fn wrap<'a>( + endpoint_id: u16, + basic_info: &'a BasicInfoConfig<'a>, + dev_att: &'a dyn DevAttDataFetcher, + pase: &'a RefCell, + fabric: &'a RefCell, + acl: &'a RefCell, + failsafe: &'a RefCell, + mdns_mgr: &'a RefCell>, + epoch: Epoch, + rand: Rand, +) -> RootEndpointHandler<'a> { + EmptyHandler + .chain( + endpoint_id, + cluster_basic_information::CLUSTER.id, + BasicInfoCluster::new(basic_info, rand), + ) + .chain( + endpoint_id, + general_commissioning::CLUSTER.id, + GenCommCluster::new(rand), + ) + .chain( + endpoint_id, + nw_commissioning::CLUSTER.id, + NwCommCluster::new(rand), + ) + .chain( + endpoint_id, + admin_commissioning::CLUSTER.id, + AdminCommCluster::new(pase, mdns_mgr, rand), + ) + .chain( + endpoint_id, + noc::CLUSTER.id, + NocCluster::new(dev_att, fabric, acl, failsafe, mdns_mgr, epoch, rand), + ) + .chain( + endpoint_id, + access_control::CLUSTER.id, + AccessControlCluster::new(acl, rand), + ) +} diff --git a/matter/src/data_model/sdm/admin_commissioning.rs b/matter/src/data_model/sdm/admin_commissioning.rs index fb31722..5497426 100644 --- a/matter/src/data_model/sdm/admin_commissioning.rs +++ b/matter/src/data_model/sdm/admin_commissioning.rs @@ -15,15 +15,21 @@ * limitations under the License. */ -use crate::cmd_enter; +use core::cell::RefCell; +use core::convert::TryInto; + use crate::data_model::objects::*; -use crate::interaction_model::core::IMStatusCode; +use crate::interaction_model::core::Transaction; +use crate::mdns::MdnsMgr; use crate::secure_channel::pake::PaseMgr; use crate::secure_channel::spake2p::VerifierData; use crate::tlv::{FromTLV, Nullable, OctetStr, TLVElement}; -use crate::{error::*, interaction_model::command::CommandReq}; -use log::{error, info}; +use crate::utils::rand::Rand; +use crate::{attribute_enum, cmd_enter}; +use crate::{command_enum, error::*}; +use log::info; use num_derive::FromPrimitive; +use strum::{EnumDiscriminants, FromRepr}; pub const ID: u32 = 0x003C; @@ -34,120 +40,54 @@ pub enum WindowStatus { BasicWindowOpen = 2, } -#[derive(FromPrimitive)] +#[derive(Copy, Clone, Debug, FromRepr, EnumDiscriminants)] +#[repr(u16)] pub enum Attributes { - WindowStatus = 0, - AdminFabricIndex = 1, - AdminVendorId = 2, + WindowStatus(AttrType) = 0, + AdminFabricIndex(AttrType>) = 1, + AdminVendorId(AttrType>) = 2, } -#[derive(FromPrimitive)] +attribute_enum!(Attributes); + +#[derive(FromRepr)] +#[repr(u32)] pub enum Commands { OpenCommWindow = 0x00, OpenBasicCommWindow = 0x01, RevokeComm = 0x02, } -fn attr_window_status_new() -> Attribute { - Attribute::new( - Attributes::WindowStatus as u16, - AttrValue::Custom, - Access::RV, - Quality::NONE, - ) -} +command_enum!(Commands); -fn attr_admin_fabid_new() -> Attribute { - Attribute::new( - Attributes::AdminFabricIndex as u16, - AttrValue::Custom, - Access::RV, - Quality::NULLABLE, - ) -} - -fn attr_admin_vid_new() -> Attribute { - Attribute::new( - Attributes::AdminVendorId as u16, - AttrValue::Custom, - Access::RV, - Quality::NULLABLE, - ) -} - -pub struct AdminCommCluster { - pase_mgr: PaseMgr, - base: Cluster, -} - -impl ClusterType for AdminCommCluster { - fn base(&self) -> &Cluster { - &self.base - } - fn base_mut(&mut self) -> &mut Cluster { - &mut self.base - } - - fn read_custom_attribute(&self, encoder: &mut dyn Encoder, attr: &AttrDetails) { - match num::FromPrimitive::from_u16(attr.attr_id) { - Some(Attributes::WindowStatus) => { - let status = 1_u8; - encoder.encode(EncodeValue::Value(&status)) - } - Some(Attributes::AdminVendorId) => { - let vid = Nullable::NotNull(1_u8); - - encoder.encode(EncodeValue::Value(&vid)) - } - Some(Attributes::AdminFabricIndex) => { - let vid = Nullable::NotNull(1_u8); - encoder.encode(EncodeValue::Value(&vid)) - } - _ => { - error!("Unsupported Attribute: this shouldn't happen"); - } - } - } - fn handle_command(&mut self, cmd_req: &mut CommandReq) -> Result<(), IMStatusCode> { - let cmd = cmd_req - .cmd - .path - .leaf - .map(num::FromPrimitive::from_u32) - .ok_or(IMStatusCode::UnsupportedCommand)? - .ok_or(IMStatusCode::UnsupportedCommand)?; - match cmd { - Commands::OpenCommWindow => self.handle_command_opencomm_win(cmd_req), - _ => Err(IMStatusCode::UnsupportedCommand), - } - } -} - -impl AdminCommCluster { - pub fn new(pase_mgr: PaseMgr) -> Result, Error> { - let mut c = Box::new(AdminCommCluster { - pase_mgr, - base: Cluster::new(ID)?, - }); - c.base.add_attribute(attr_window_status_new())?; - c.base.add_attribute(attr_admin_fabid_new())?; - c.base.add_attribute(attr_admin_vid_new())?; - Ok(c) - } - - fn handle_command_opencomm_win( - &mut self, - cmd_req: &mut CommandReq, - ) -> Result<(), IMStatusCode> { - cmd_enter!("Open Commissioning Window"); - let req = - OpenCommWindowReq::from_tlv(&cmd_req.data).map_err(|_| IMStatusCode::InvalidCommand)?; - let verifier = VerifierData::new(req.verifier.0, req.iterations, req.salt.0); - self.pase_mgr - .enable_pase_session(verifier, req.discriminator)?; - Err(IMStatusCode::Success) - } -} +pub const CLUSTER: Cluster<'static> = Cluster { + id: ID as _, + feature_map: 0, + attributes: &[ + FEATURE_MAP, + ATTRIBUTE_LIST, + Attribute::new( + AttributesDiscriminants::WindowStatus as u16, + Access::RV, + Quality::NONE, + ), + Attribute::new( + AttributesDiscriminants::AdminFabricIndex as u16, + Access::RV, + Quality::NULLABLE, + ), + Attribute::new( + AttributesDiscriminants::AdminVendorId as u16, + Access::RV, + Quality::NULLABLE, + ), + ], + commands: &[ + Commands::OpenCommWindow as _, + Commands::OpenBasicCommWindow as _, + Commands::RevokeComm as _, + ], +}; #[derive(FromTLV)] #[tlvargs(lifetime = "'a")] @@ -158,3 +98,94 @@ pub struct OpenCommWindowReq<'a> { iterations: u32, salt: OctetStr<'a>, } + +pub struct AdminCommCluster<'a> { + data_ver: Dataver, + pase_mgr: &'a RefCell, + mdns_mgr: &'a RefCell>, +} + +impl<'a> AdminCommCluster<'a> { + pub fn new( + pase_mgr: &'a RefCell, + mdns_mgr: &'a RefCell>, + rand: Rand, + ) -> Self { + Self { + data_ver: Dataver::new(rand), + pase_mgr, + mdns_mgr, + } + } + + pub fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> { + if let Some(writer) = encoder.with_dataver(self.data_ver.get())? { + if attr.is_system() { + CLUSTER.read(attr.attr_id, writer) + } else { + match attr.attr_id.try_into()? { + Attributes::WindowStatus(codec) => codec.encode(writer, 1), + Attributes::AdminVendorId(codec) => codec.encode(writer, Nullable::NotNull(1)), + Attributes::AdminFabricIndex(codec) => { + codec.encode(writer, Nullable::NotNull(1)) + } + } + } + } else { + Ok(()) + } + } + + pub fn invoke( + &mut self, + cmd: &CmdDetails, + data: &TLVElement, + _encoder: CmdDataEncoder, + ) -> Result<(), Error> { + match cmd.cmd_id.try_into()? { + Commands::OpenCommWindow => self.handle_command_opencomm_win(data)?, + _ => Err(Error::CommandNotFound)?, + } + + self.data_ver.changed(); + + Ok(()) + } + + fn handle_command_opencomm_win(&mut self, data: &TLVElement) -> Result<(), Error> { + cmd_enter!("Open Commissioning Window"); + let req = OpenCommWindowReq::from_tlv(data)?; + let verifier = VerifierData::new(req.verifier.0, req.iterations, req.salt.0); + self.pase_mgr.borrow_mut().enable_pase_session( + verifier, + req.discriminator, + &mut self.mdns_mgr.borrow_mut(), + )?; + + Ok(()) + } +} + +impl<'a> Handler for AdminCommCluster<'a> { + fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> { + AdminCommCluster::read(self, attr, encoder) + } + + fn invoke( + &mut self, + _transaction: &mut Transaction, + cmd: &CmdDetails, + data: &TLVElement, + encoder: CmdDataEncoder, + ) -> Result<(), Error> { + AdminCommCluster::invoke(self, cmd, data, encoder) + } +} + +impl<'a> NonBlockingHandler for AdminCommCluster<'a> {} + +impl<'a> ChangeNotifier<()> for AdminCommCluster<'a> { + fn consume_change(&mut self) -> Option<()> { + self.data_ver.consume_change(()) + } +} diff --git a/matter/src/data_model/sdm/failsafe.rs b/matter/src/data_model/sdm/failsafe.rs index cd3c2be..54b22e6 100644 --- a/matter/src/data_model/sdm/failsafe.rs +++ b/matter/src/data_model/sdm/failsafe.rs @@ -17,7 +17,6 @@ use crate::{error::Error, transport::session::SessionMode}; use log::error; -use std::sync::RwLock; #[derive(PartialEq)] #[allow(dead_code)] @@ -42,26 +41,19 @@ pub enum State { Armed(ArmedCtx), } -pub struct FailSafeInner { +pub struct FailSafe { state: State, } -pub struct FailSafe { - state: RwLock, -} - impl FailSafe { - pub fn new() -> Self { - Self { - state: RwLock::new(FailSafeInner { state: State::Idle }), - } + pub const fn new() -> Self { + Self { state: State::Idle } } - pub fn arm(&self, timeout: u8, session_mode: SessionMode) -> Result<(), Error> { - let mut inner = self.state.write()?; - match &mut inner.state { + pub fn arm(&mut self, timeout: u8, session_mode: SessionMode) -> Result<(), Error> { + match &mut self.state { State::Idle => { - inner.state = State::Armed(ArmedCtx { + self.state = State::Armed(ArmedCtx { session_mode, timeout, noc_state: NocState::NocNotRecvd, @@ -78,9 +70,8 @@ impl FailSafe { Ok(()) } - pub fn disarm(&self, session_mode: SessionMode) -> Result<(), Error> { - let mut inner = self.state.write()?; - match &mut inner.state { + pub fn disarm(&mut self, session_mode: SessionMode) -> Result<(), Error> { + match &mut self.state { State::Idle => { error!("Received Fail-Safe Disarm without it being armed"); return Err(Error::Invalid); @@ -102,19 +93,18 @@ impl FailSafe { } } } - inner.state = State::Idle; + self.state = State::Idle; } } Ok(()) } pub fn is_armed(&self) -> bool { - self.state.read().unwrap().state != State::Idle + self.state != State::Idle } - pub fn record_add_noc(&self, fabric_index: u8) -> Result<(), Error> { - let mut inner = self.state.write()?; - match &mut inner.state { + pub fn record_add_noc(&mut self, fabric_index: u8) -> Result<(), Error> { + match &mut self.state { State::Idle => Err(Error::Invalid), State::Armed(c) => { if c.noc_state == NocState::NocNotRecvd { @@ -128,8 +118,7 @@ impl FailSafe { } pub fn allow_noc_change(&self) -> Result { - let mut inner = self.state.write()?; - let allow = match &mut inner.state { + let allow = match &self.state { State::Idle => false, State::Armed(c) => c.noc_state == NocState::NocNotRecvd, }; diff --git a/matter/src/data_model/sdm/general_commissioning.rs b/matter/src/data_model/sdm/general_commissioning.rs index 0328b21..aea37c7 100644 --- a/matter/src/data_model/sdm/general_commissioning.rs +++ b/matter/src/data_model/sdm/general_commissioning.rs @@ -15,16 +15,19 @@ * limitations under the License. */ -use crate::cmd_enter; +use core::cell::RefCell; +use core::convert::TryInto; + use crate::data_model::objects::*; use crate::data_model::sdm::failsafe::FailSafe; -use crate::interaction_model::core::IMStatusCode; -use crate::interaction_model::messages::ib; -use crate::tlv::{FromTLV, TLVElement, TLVWriter, TagType, ToTLV}; -use crate::{error::*, interaction_model::command::CommandReq}; -use log::{error, info}; -use num_derive::FromPrimitive; -use std::sync::Arc; +use crate::interaction_model::core::Transaction; +use crate::tlv::{FromTLV, TLVElement, TLVWriter, TagType, ToTLV, UtfStr}; +use crate::transport::session::Session; +use crate::utils::rand::Rand; +use crate::{attribute_enum, cmd_enter}; +use crate::{command_enum, error::*}; +use log::info; +use strum::{EnumDiscriminants, FromRepr}; #[derive(Clone, Copy)] #[allow(dead_code)] @@ -38,65 +41,80 @@ enum CommissioningError { pub const ID: u32 = 0x0030; -#[derive(FromPrimitive)] +#[derive(FromRepr, EnumDiscriminants)] +#[repr(u16)] pub enum Attributes { - BreadCrumb = 0, - BasicCommissioningInfo = 1, - RegConfig = 2, - LocationCapability = 3, + BreadCrumb(AttrType) = 0, + BasicCommissioningInfo(()) = 1, + RegConfig(AttrType) = 2, + LocationCapability(AttrType) = 3, } -#[derive(FromPrimitive)] +attribute_enum!(Attributes); + +#[derive(FromRepr)] +#[repr(u32)] pub enum Commands { ArmFailsafe = 0x00, - ArmFailsafeResp = 0x01, SetRegulatoryConfig = 0x02, - SetRegulatoryConfigResp = 0x03, CommissioningComplete = 0x04, +} + +command_enum!(Commands); + +#[repr(u16)] +pub enum RespCommands { + ArmFailsafeResp = 0x01, + SetRegulatoryConfigResp = 0x03, CommissioningCompleteResp = 0x05, } +#[derive(FromTLV, ToTLV)] +#[tlvargs(lifetime = "'a")] +struct CommonResponse<'a> { + error_code: u8, + debug_txt: UtfStr<'a>, +} + pub enum RegLocationType { Indoor = 0, Outdoor = 1, IndoorOutdoor = 2, } -fn attr_bread_crumb_new(bread_crumb: u64) -> Attribute { - Attribute::new( - Attributes::BreadCrumb as u16, - AttrValue::Uint64(bread_crumb), - Access::READ | Access::WRITE | Access::NEED_ADMIN, - Quality::NONE, - ) -} - -fn attr_reg_config_new(reg_config: RegLocationType) -> Attribute { - Attribute::new( - Attributes::RegConfig as u16, - AttrValue::Uint8(reg_config as u8), - Access::RV, - Quality::NONE, - ) -} - -fn attr_location_capability_new(reg_config: RegLocationType) -> Attribute { - Attribute::new( - Attributes::LocationCapability as u16, - AttrValue::Uint8(reg_config as u8), - Access::RV, - Quality::FIXED, - ) -} - -fn attr_comm_info_new() -> Attribute { - Attribute::new( - Attributes::BasicCommissioningInfo as u16, - AttrValue::Custom, - Access::RV, - Quality::FIXED, - ) -} +pub const CLUSTER: Cluster<'static> = Cluster { + id: ID as _, + feature_map: 0, + attributes: &[ + FEATURE_MAP, + ATTRIBUTE_LIST, + Attribute::new( + AttributesDiscriminants::BreadCrumb as u16, + Access::READ.union(Access::WRITE).union(Access::NEED_ADMIN), + Quality::NONE, + ), + Attribute::new( + AttributesDiscriminants::RegConfig as u16, + Access::RV, + Quality::NONE, + ), + Attribute::new( + AttributesDiscriminants::LocationCapability as u16, + Access::RV, + Quality::FIXED, + ), + Attribute::new( + AttributesDiscriminants::BasicCommissioningInfo as u16, + Access::RV, + Quality::FIXED, + ), + ], + commands: &[ + Commands::ArmFailsafe as _, + Commands::SetRegulatoryConfig as _, + Commands::CommissioningComplete as _, + ], +}; #[derive(FromTLV, ToTLV)] struct FailSafeParams { @@ -105,143 +123,134 @@ struct FailSafeParams { } pub struct GenCommCluster { + data_ver: Dataver, expiry_len: u16, - failsafe: Arc, - base: Cluster, -} - -impl ClusterType for GenCommCluster { - fn base(&self) -> &Cluster { - &self.base - } - fn base_mut(&mut self) -> &mut Cluster { - &mut self.base - } - - fn read_custom_attribute(&self, encoder: &mut dyn Encoder, attr: &AttrDetails) { - match num::FromPrimitive::from_u16(attr.attr_id) { - Some(Attributes::BasicCommissioningInfo) => { - encoder.encode(EncodeValue::Closure(&|tag, tw| { - let _ = tw.start_struct(tag); - let _ = tw.u16(TagType::Context(0), self.expiry_len); - let _ = tw.end_container(); - })) - } - _ => { - error!("Unsupported Attribute: this shouldn't happen"); - } - } - } - - fn handle_command(&mut self, cmd_req: &mut CommandReq) -> Result<(), IMStatusCode> { - let cmd = cmd_req - .cmd - .path - .leaf - .map(num::FromPrimitive::from_u32) - .ok_or(IMStatusCode::UnsupportedCommand)? - .ok_or(IMStatusCode::UnsupportedCommand)?; - match cmd { - Commands::ArmFailsafe => self.handle_command_armfailsafe(cmd_req), - Commands::SetRegulatoryConfig => self.handle_command_setregulatoryconfig(cmd_req), - Commands::CommissioningComplete => self.handle_command_commissioningcomplete(cmd_req), - _ => Err(IMStatusCode::UnsupportedCommand), - } - } + failsafe: RefCell, } impl GenCommCluster { - pub fn new() -> Result, Error> { - let failsafe = Arc::new(FailSafe::new()); - - let mut c = Box::new(GenCommCluster { + pub fn new(rand: Rand) -> Self { + Self { + data_ver: Dataver::new(rand), + failsafe: RefCell::new(FailSafe::new()), // TODO: Arch-Specific expiry_len: 120, - failsafe, - base: Cluster::new(ID)?, - }); - c.base.add_attribute(attr_bread_crumb_new(0))?; - // TODO: Arch-Specific - c.base - .add_attribute(attr_reg_config_new(RegLocationType::IndoorOutdoor))?; - // TODO: Arch-Specific - c.base - .add_attribute(attr_location_capability_new(RegLocationType::IndoorOutdoor))?; - c.base.add_attribute(attr_comm_info_new())?; - - Ok(c) + } } - pub fn failsafe(&self) -> Arc { - self.failsafe.clone() + pub fn failsafe(&self) -> &RefCell { + &self.failsafe } - fn handle_command_armfailsafe(&mut self, cmd_req: &mut CommandReq) -> Result<(), IMStatusCode> { - cmd_enter!("ARM Fail Safe"); + pub fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> { + if let Some(mut writer) = encoder.with_dataver(self.data_ver.get())? { + if attr.is_system() { + CLUSTER.read(attr.attr_id, writer) + } else { + match attr.attr_id.try_into()? { + Attributes::BreadCrumb(codec) => codec.encode(writer, 0), + // TODO: Arch-Specific + Attributes::RegConfig(codec) => { + codec.encode(writer, RegLocationType::IndoorOutdoor as _) + } + // TODO: Arch-Specific + Attributes::LocationCapability(codec) => { + codec.encode(writer, RegLocationType::IndoorOutdoor as _) + } + Attributes::BasicCommissioningInfo(_) => { + writer.start_struct(AttrDataWriter::TAG)?; + writer.u16(TagType::Context(0), self.expiry_len)?; + writer.end_container()?; - let p = FailSafeParams::from_tlv(&cmd_req.data)?; - let mut status = CommissioningError::Ok as u8; + writer.complete() + } + } + } + } else { + Ok(()) + } + } - if self - .failsafe - .arm(p.expiry_len, cmd_req.trans.session.get_session_mode()) - .is_err() - { - status = CommissioningError::ErrBusyWithOtherAdmin as u8; + pub fn invoke( + &mut self, + session: &mut Session, + cmd: &CmdDetails, + data: &TLVElement, + encoder: CmdDataEncoder, + ) -> Result<(), Error> { + match cmd.cmd_id.try_into()? { + Commands::ArmFailsafe => self.handle_command_armfailsafe(session, data, encoder)?, + Commands::SetRegulatoryConfig => { + self.handle_command_setregulatoryconfig(data, encoder)? + } + Commands::CommissioningComplete => { + self.handle_command_commissioningcomplete(session, encoder)?; + } } - let cmd_data = CommonResponse { - error_code: status, - debug_txt: "".to_owned(), - }; - let resp = ib::InvResp::cmd_new( - 0, - ID, - Commands::ArmFailsafeResp as u16, - EncodeValue::Value(&cmd_data), - ); - let _ = resp.to_tlv(cmd_req.resp, TagType::Anonymous); - cmd_req.trans.complete(); + self.data_ver.changed(); + Ok(()) } + fn handle_command_armfailsafe( + &mut self, + session: &Session, + data: &TLVElement, + encoder: CmdDataEncoder, + ) -> Result<(), Error> { + cmd_enter!("ARM Fail Safe"); + + let p = FailSafeParams::from_tlv(data)?; + + self.failsafe + .borrow_mut() + .arm(p.expiry_len, session.get_session_mode()) + .map_err(|e| e.remap(|_| true, Error::Busy))?; + + let cmd_data = CommonResponse { + error_code: CommissioningError::ErrBusyWithOtherAdmin as u8, + debug_txt: UtfStr::new(b""), + }; + + encoder + .with_command(RespCommands::ArmFailsafeResp as _)? + .set(&cmd_data) + } + fn handle_command_setregulatoryconfig( &mut self, - cmd_req: &mut CommandReq, - ) -> Result<(), IMStatusCode> { + data: &TLVElement, + encoder: CmdDataEncoder, + ) -> Result<(), Error> { cmd_enter!("Set Regulatory Config"); - let country_code = cmd_req - .data + let country_code = data .find_tag(1) - .map_err(|_| IMStatusCode::InvalidCommand)? + .map_err(|_| Error::InvalidCommand)? .slice() - .map_err(|_| IMStatusCode::InvalidCommand)?; + .map_err(|_| Error::InvalidCommand)?; info!("Received country code: {:?}", country_code); let cmd_data = CommonResponse { error_code: 0, - debug_txt: "".to_owned(), + debug_txt: UtfStr::new(b""), }; - let resp = ib::InvResp::cmd_new( - 0, - ID, - Commands::SetRegulatoryConfigResp as u16, - EncodeValue::Value(&cmd_data), - ); - let _ = resp.to_tlv(cmd_req.resp, TagType::Anonymous); - cmd_req.trans.complete(); - Ok(()) + + encoder + .with_command(RespCommands::SetRegulatoryConfigResp as _)? + .set(&cmd_data) } fn handle_command_commissioningcomplete( &mut self, - cmd_req: &mut CommandReq, - ) -> Result<(), IMStatusCode> { + session: &Session, + encoder: CmdDataEncoder, + ) -> Result<(), Error> { cmd_enter!("Commissioning Complete"); let mut status: u8 = CommissioningError::Ok as u8; // Has to be a Case Session - if cmd_req.trans.session.get_local_fabric_idx().is_none() { + if session.get_local_fabric_idx().is_none() { status = CommissioningError::ErrInvalidAuth as u8; } @@ -249,7 +258,8 @@ impl GenCommCluster { // scope that is for this session if self .failsafe - .disarm(cmd_req.trans.session.get_session_mode()) + .borrow_mut() + .disarm(session.get_session_mode()) .is_err() { status = CommissioningError::ErrInvalidAuth as u8; @@ -257,22 +267,35 @@ impl GenCommCluster { let cmd_data = CommonResponse { error_code: status, - debug_txt: "".to_owned(), + debug_txt: UtfStr::new(b""), }; - let resp = ib::InvResp::cmd_new( - 0, - ID, - Commands::CommissioningCompleteResp as u16, - EncodeValue::Value(&cmd_data), - ); - let _ = resp.to_tlv(cmd_req.resp, TagType::Anonymous); - cmd_req.trans.complete(); - Ok(()) + + encoder + .with_command(RespCommands::CommissioningCompleteResp as _)? + .set(&cmd_data) } } -#[derive(FromTLV, ToTLV)] -struct CommonResponse { - error_code: u8, - debug_txt: String, +impl Handler for GenCommCluster { + fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> { + GenCommCluster::read(self, attr, encoder) + } + + fn invoke( + &mut self, + transaction: &mut Transaction, + cmd: &CmdDetails, + data: &TLVElement, + encoder: CmdDataEncoder, + ) -> Result<(), Error> { + GenCommCluster::invoke(self, transaction.session_mut(), cmd, data, encoder) + } +} + +impl NonBlockingHandler for GenCommCluster {} + +impl ChangeNotifier<()> for GenCommCluster { + fn consume_change(&mut self) -> Option<()> { + self.data_ver.consume_change(()) + } } diff --git a/matter/src/data_model/sdm/noc.rs b/matter/src/data_model/sdm/noc.rs index 7c85f59..0258f3a 100644 --- a/matter/src/data_model/sdm/noc.rs +++ b/matter/src/data_model/sdm/noc.rs @@ -15,24 +15,25 @@ * limitations under the License. */ -use std::sync::Arc; -use std::time::{SystemTime, UNIX_EPOCH}; +use core::cell::RefCell; +use core::convert::TryInto; use crate::acl::{AclEntry, AclMgr, AuthMode}; use crate::cert::Cert; -use crate::crypto::{self, CryptoKeyPair, KeyPair}; +use crate::crypto::{self, KeyPair}; use crate::data_model::objects::*; use crate::data_model::sdm::dev_att; use crate::fabric::{Fabric, FabricMgr, MAX_SUPPORTED_FABRICS}; -use crate::interaction_model::command::CommandReq; -use crate::interaction_model::core::IMStatusCode; -use crate::interaction_model::messages::ib; +use crate::interaction_model::core::Transaction; +use crate::mdns::MdnsMgr; use crate::tlv::{FromTLV, OctetStr, TLVElement, TLVWriter, TagType, ToTLV, UtfStr}; -use crate::transport::session::SessionMode; +use crate::transport::session::{Session, SessionMode}; +use crate::utils::epoch::Epoch; +use crate::utils::rand::Rand; use crate::utils::writebuf::WriteBuf; -use crate::{cmd_enter, error::*, secure_channel}; +use crate::{attribute_enum, cmd_enter, command_enum, error::*}; use log::{error, info}; -use num_derive::FromPrimitive; +use strum::{EnumDiscriminants, FromRepr}; use super::dev_att::{DataType, DevAttDataFetcher}; use super::failsafe::FailSafe; @@ -56,6 +57,23 @@ enum NocStatus { InvalidFabricIndex = 11, } +enum NocError { + Status(NocStatus), + Error(Error), +} + +impl From for NocError { + fn from(value: NocStatus) -> Self { + Self::Status(value) + } +} + +impl From for NocError { + fn from(value: Error) -> Self { + Self::Error(value) + } +} + // Some placeholder value for now const MAX_CERT_DECLARATION_LEN: usize = 600; // Some placeholder value for now @@ -65,39 +83,80 @@ const RESP_MAX: usize = 900; pub const ID: u32 = 0x003E; -#[derive(FromPrimitive)] +#[derive(FromRepr)] +#[repr(u32)] pub enum Commands { AttReq = 0x00, - AttReqResp = 0x01, CertChainReq = 0x02, - CertChainResp = 0x03, CSRReq = 0x04, - CSRResp = 0x05, AddNOC = 0x06, - NOCResp = 0x08, UpdateFabricLabel = 0x09, RemoveFabric = 0x0a, AddTrustedRootCert = 0x0b, } -#[derive(FromPrimitive)] -pub enum Attributes { - NOCs = 0, - Fabrics = 1, - SupportedFabrics = 2, - CommissionedFabrics = 3, - TrustedRootCerts = 4, - CurrentFabricIndex = 5, +command_enum!(Commands); + +#[repr(u16)] +pub enum RespCommands { + AttReqResp = 0x01, + CertChainResp = 0x03, + CSRResp = 0x05, + NOCResp = 0x08, } -pub struct NocCluster { - base: Cluster, - dev_att: Box, - fabric_mgr: Arc, - acl_mgr: Arc, - failsafe: Arc, +#[derive(FromRepr, EnumDiscriminants)] +#[repr(u16)] +pub enum Attributes { + NOCs = 0, + Fabrics(()) = 1, + SupportedFabrics(AttrType) = 2, + CommissionedFabrics(AttrType) = 3, + TrustedRootCerts = 4, + CurrentFabricIndex(AttrType) = 5, } -struct NocData { + +attribute_enum!(Attributes); + +pub const CLUSTER: Cluster<'static> = Cluster { + id: ID as _, + feature_map: 0, + attributes: &[ + FEATURE_MAP, + ATTRIBUTE_LIST, + Attribute::new( + AttributesDiscriminants::CurrentFabricIndex as u16, + Access::RV, + Quality::NONE, + ), + Attribute::new( + AttributesDiscriminants::Fabrics as u16, + Access::RV.union(Access::FAB_SCOPED), + Quality::NONE, + ), + Attribute::new( + AttributesDiscriminants::SupportedFabrics as u16, + Access::RV, + Quality::FIXED, + ), + Attribute::new( + AttributesDiscriminants::CommissionedFabrics as u16, + Access::RV, + Quality::NONE, + ), + ], + commands: &[ + Commands::AttReq as _, + Commands::CertChainReq as _, + Commands::CSRReq as _, + Commands::AddNOC as _, + Commands::UpdateFabricLabel as _, + Commands::RemoveFabric as _, + Commands::AddTrustedRootCert as _, + ], +}; + +pub struct NocData { pub key_pair: KeyPair, pub root_ca: Cert, } @@ -111,459 +170,16 @@ impl NocData { } } -impl NocCluster { - pub fn new( - dev_att: Box, - fabric_mgr: Arc, - acl_mgr: Arc, - failsafe: Arc, - ) -> Result, Error> { - let mut c = Box::new(Self { - dev_att, - fabric_mgr, - acl_mgr, - failsafe, - base: Cluster::new(ID)?, - }); - let attrs = [ - Attribute::new( - Attributes::CurrentFabricIndex as u16, - AttrValue::Custom, - Access::RV, - Quality::NONE, - ), - Attribute::new( - Attributes::Fabrics as u16, - AttrValue::Custom, - Access::RV | Access::FAB_SCOPED, - Quality::NONE, - ), - Attribute::new( - Attributes::SupportedFabrics as u16, - AttrValue::Uint8(MAX_SUPPORTED_FABRICS as u8), - Access::RV, - Quality::FIXED, - ), - Attribute::new( - Attributes::CommissionedFabrics as u16, - AttrValue::Custom, - Access::RV, - Quality::NONE, - ), - ]; - c.base.add_attributes(&attrs[..])?; - Ok(c) - } - - fn add_acl(&self, fab_idx: u8, admin_subject: u64) -> Result<(), Error> { - let mut acl = AclEntry::new(fab_idx, Privilege::ADMIN, AuthMode::Case); - acl.add_subject(admin_subject)?; - self.acl_mgr.add(acl) - } - - fn _handle_command_addnoc(&mut self, cmd_req: &mut CommandReq) -> Result<(), NocStatus> { - let noc_data = cmd_req - .trans - .session - .take_data::() - .ok_or(NocStatus::MissingCsr)?; - - if !self - .failsafe - .allow_noc_change() - .map_err(|_| NocStatus::InsufficientPrivlege)? - { - error!("AddNOC not allowed by Fail Safe"); - return Err(NocStatus::InsufficientPrivlege); - } - - // This command's processing may take longer, send a stand alone ACK to the peer to avoid any retranmissions - let ack_send = secure_channel::common::send_mrp_standalone_ack( - cmd_req.trans.exch, - cmd_req.trans.session, - ); - if ack_send.is_err() { - error!("Error sending Standalone ACK, falling back to piggybacked ACK"); - } - - let r = AddNocReq::from_tlv(&cmd_req.data).map_err(|_| NocStatus::InvalidNOC)?; - - let noc_value = Cert::new(r.noc_value.0).map_err(|_| NocStatus::InvalidNOC)?; - info!("Received NOC as: {}", noc_value); - let icac_value = if !r.icac_value.0.is_empty() { - let cert = Cert::new(r.icac_value.0).map_err(|_| NocStatus::InvalidNOC)?; - info!("Received ICAC as: {}", cert); - Some(cert) - } else { - None - }; - let fabric = Fabric::new( - noc_data.key_pair, - noc_data.root_ca, - icac_value, - noc_value, - r.ipk_value.0, - r.vendor_id, - ) - .map_err(|_| NocStatus::TableFull)?; - let fab_idx = self - .fabric_mgr - .add(fabric) - .map_err(|_| NocStatus::TableFull)?; - - if self.add_acl(fab_idx, r.case_admin_subject).is_err() { - error!("Failed to add ACL, what to do?"); - } - - if self.failsafe.record_add_noc(fab_idx).is_err() { - error!("Failed to record NoC in the FailSafe, what to do?"); - } - NocCluster::create_nocresponse(cmd_req.resp, NocStatus::Ok, fab_idx, "".to_owned()); - cmd_req.trans.complete(); - Ok(()) - } - - fn create_nocresponse( - tw: &mut TLVWriter, - status_code: NocStatus, - fab_idx: u8, - debug_txt: String, - ) { - let cmd_data = NocResp { - status_code: status_code as u8, - fab_idx, - debug_txt, - }; - let invoke_resp = ib::InvResp::cmd_new( - 0, - ID, - Commands::NOCResp as u16, - EncodeValue::Value(&cmd_data), - ); - let _ = invoke_resp.to_tlv(tw, TagType::Anonymous); - } - - fn handle_command_updatefablabel( - &mut self, - cmd_req: &mut CommandReq, - ) -> Result<(), IMStatusCode> { - cmd_enter!("Update Fabric Label"); - let req = UpdateFabricLabelReq::from_tlv(&cmd_req.data) - .map_err(|_| IMStatusCode::InvalidDataType)?; - let label = req - .label - .to_string() - .map_err(|_| IMStatusCode::InvalidDataType)?; - - let (result, fab_idx) = - if let SessionMode::Case(c) = cmd_req.trans.session.get_session_mode() { - if self.fabric_mgr.set_label(c.fab_idx, label).is_err() { - (NocStatus::LabelConflict, c.fab_idx) - } else { - (NocStatus::Ok, c.fab_idx) - } - } else { - // Update Fabric Label not allowed - (NocStatus::InvalidFabricIndex, 0) - }; - NocCluster::create_nocresponse(cmd_req.resp, result, fab_idx, "".to_string()); - cmd_req.trans.complete(); - Ok(()) - } - - fn handle_command_rmfabric(&mut self, cmd_req: &mut CommandReq) -> Result<(), IMStatusCode> { - cmd_enter!("Remove Fabric"); - let req = - RemoveFabricReq::from_tlv(&cmd_req.data).map_err(|_| IMStatusCode::InvalidCommand)?; - if self.fabric_mgr.remove(req.fab_idx).is_ok() { - let _ = self.acl_mgr.delete_for_fabric(req.fab_idx); - cmd_req.trans.terminate(); - } else { - NocCluster::create_nocresponse( - cmd_req.resp, - NocStatus::InvalidFabricIndex, - req.fab_idx, - "".to_string(), - ); - } - Ok(()) - } - - fn handle_command_addnoc(&mut self, cmd_req: &mut CommandReq) -> Result<(), IMStatusCode> { - cmd_enter!("AddNOC"); - if let Err(e) = self._handle_command_addnoc(cmd_req) { - //TODO: Fab-idx 0? - NocCluster::create_nocresponse(cmd_req.resp, e, 0, "".to_owned()); - cmd_req.trans.complete(); - } - Ok(()) - } - - fn handle_command_attrequest(&mut self, cmd_req: &mut CommandReq) -> Result<(), IMStatusCode> { - cmd_enter!("AttestationRequest"); - - let req = CommonReq::from_tlv(&cmd_req.data).map_err(|_| IMStatusCode::InvalidCommand)?; - info!("Received Attestation Nonce:{:?}", req.str); - - let mut attest_challenge = [0u8; crypto::SYMM_KEY_LEN_BYTES]; - attest_challenge.copy_from_slice(cmd_req.trans.session.get_att_challenge()); - - let cmd_data = |tag: TagType, t: &mut TLVWriter| { - let mut buf: [u8; RESP_MAX] = [0; RESP_MAX]; - let mut attest_element = WriteBuf::new(&mut buf, RESP_MAX); - let _ = t.start_struct(tag); - let _ = - add_attestation_element(self.dev_att.as_ref(), req.str.0, &mut attest_element, t); - let _ = add_attestation_signature( - self.dev_att.as_ref(), - &mut attest_element, - &attest_challenge, - t, - ); - let _ = t.end_container(); - }; - let resp = ib::InvResp::cmd_new( - 0, - ID, - Commands::AttReqResp as u16, - EncodeValue::Closure(&cmd_data), - ); - let _ = resp.to_tlv(cmd_req.resp, TagType::Anonymous); - cmd_req.trans.complete(); - Ok(()) - } - - fn handle_command_certchainrequest( - &mut self, - cmd_req: &mut CommandReq, - ) -> Result<(), IMStatusCode> { - cmd_enter!("CertChainRequest"); - - info!("Received data: {}", cmd_req.data); - let cert_type = - get_certchainrequest_params(&cmd_req.data).map_err(|_| IMStatusCode::InvalidCommand)?; - - let mut buf: [u8; RESP_MAX] = [0; RESP_MAX]; - let len = self - .dev_att - .get_devatt_data(cert_type, &mut buf) - .map_err(|_| IMStatusCode::Failure)?; - let buf = &buf[0..len]; - - let cmd_data = CertChainResp { - cert: OctetStr::new(buf), - }; - let resp = ib::InvResp::cmd_new( - 0, - ID, - Commands::CertChainResp as u16, - EncodeValue::Value(&cmd_data), - ); - let _ = resp.to_tlv(cmd_req.resp, TagType::Anonymous); - cmd_req.trans.complete(); - Ok(()) - } - - fn handle_command_csrrequest(&mut self, cmd_req: &mut CommandReq) -> Result<(), IMStatusCode> { - cmd_enter!("CSRRequest"); - - let req = CommonReq::from_tlv(&cmd_req.data).map_err(|_| IMStatusCode::InvalidCommand)?; - info!("Received CSR Nonce:{:?}", req.str); - - if !self.failsafe.is_armed() { - return Err(IMStatusCode::UnsupportedAccess); - } - - let noc_keypair = KeyPair::new().map_err(|_| IMStatusCode::Failure)?; - let mut attest_challenge = [0u8; crypto::SYMM_KEY_LEN_BYTES]; - attest_challenge.copy_from_slice(cmd_req.trans.session.get_att_challenge()); - - let cmd_data = |tag: TagType, t: &mut TLVWriter| { - let mut buf: [u8; RESP_MAX] = [0; RESP_MAX]; - let mut nocsr_element = WriteBuf::new(&mut buf, RESP_MAX); - let _ = t.start_struct(tag); - let _ = add_nocsrelement(&noc_keypair, req.str.0, &mut nocsr_element, t); - let _ = add_attestation_signature( - self.dev_att.as_ref(), - &mut nocsr_element, - &attest_challenge, - t, - ); - let _ = t.end_container(); - }; - let resp = ib::InvResp::cmd_new( - 0, - ID, - Commands::CSRResp as u16, - EncodeValue::Closure(&cmd_data), - ); - - let _ = resp.to_tlv(cmd_req.resp, TagType::Anonymous); - let noc_data = Box::new(NocData::new(noc_keypair)); - // Store this in the session data instead of cluster data, so it gets cleared - // if the session goes away for some reason - cmd_req.trans.session.set_data(noc_data); - cmd_req.trans.complete(); - Ok(()) - } - - fn handle_command_addtrustedrootcert( - &mut self, - cmd_req: &mut CommandReq, - ) -> Result<(), IMStatusCode> { - cmd_enter!("AddTrustedRootCert"); - if !self.failsafe.is_armed() { - return Err(IMStatusCode::UnsupportedAccess); - } - - // This may happen on CASE or PASE. For PASE, the existence of NOC Data is necessary - match cmd_req.trans.session.get_session_mode() { - SessionMode::Case(_) => error!("CASE: AddTrustedRootCert handling pending"), // For a CASE Session, we just return success for now, - SessionMode::Pase => { - let noc_data = cmd_req - .trans - .session - .get_data::() - .ok_or(IMStatusCode::Failure)?; - - let req = - CommonReq::from_tlv(&cmd_req.data).map_err(|_| IMStatusCode::InvalidCommand)?; - info!("Received Trusted Cert:{:x?}", req.str); - - noc_data.root_ca = Cert::new(req.str.0).map_err(|_| IMStatusCode::Failure)?; - } - _ => (), - } - cmd_req.trans.complete(); - - Err(IMStatusCode::Success) - } -} - -impl ClusterType for NocCluster { - fn base(&self) -> &Cluster { - &self.base - } - fn base_mut(&mut self) -> &mut Cluster { - &mut self.base - } - - fn handle_command(&mut self, cmd_req: &mut CommandReq) -> Result<(), IMStatusCode> { - let cmd = cmd_req - .cmd - .path - .leaf - .map(num::FromPrimitive::from_u32) - .ok_or(IMStatusCode::UnsupportedCommand)? - .ok_or(IMStatusCode::UnsupportedCommand)?; - match cmd { - Commands::AddNOC => self.handle_command_addnoc(cmd_req), - Commands::CSRReq => self.handle_command_csrrequest(cmd_req), - Commands::AddTrustedRootCert => self.handle_command_addtrustedrootcert(cmd_req), - Commands::AttReq => self.handle_command_attrequest(cmd_req), - Commands::CertChainReq => self.handle_command_certchainrequest(cmd_req), - Commands::UpdateFabricLabel => self.handle_command_updatefablabel(cmd_req), - Commands::RemoveFabric => self.handle_command_rmfabric(cmd_req), - _ => Err(IMStatusCode::UnsupportedCommand), - } - } - - fn read_custom_attribute(&self, encoder: &mut dyn Encoder, attr: &AttrDetails) { - match num::FromPrimitive::from_u16(attr.attr_id) { - Some(Attributes::CurrentFabricIndex) => { - encoder.encode(EncodeValue::Value(&attr.fab_idx)) - } - Some(Attributes::Fabrics) => encoder.encode(EncodeValue::Closure(&|tag, tw| { - let _ = tw.start_array(tag); - let _ = self.fabric_mgr.for_each(|entry, fab_idx| { - if !attr.fab_filter || attr.fab_idx == fab_idx { - let _ = entry - .get_fabric_desc(fab_idx) - .to_tlv(tw, TagType::Anonymous); - } - }); - let _ = tw.end_container(); - })), - Some(Attributes::CommissionedFabrics) => { - let count = self.fabric_mgr.used_count() as u8; - encoder.encode(EncodeValue::Value(&count)) - } - _ => { - error!("Attribute not supported: this shouldn't happen"); - } - } - } -} - -fn add_attestation_element( - dev_att: &dyn DevAttDataFetcher, - att_nonce: &[u8], - write_buf: &mut WriteBuf, - t: &mut TLVWriter, -) -> Result<(), Error> { - let mut cert_dec: [u8; MAX_CERT_DECLARATION_LEN] = [0; MAX_CERT_DECLARATION_LEN]; - let len = dev_att.get_devatt_data(dev_att::DataType::CertDeclaration, &mut cert_dec)?; - let cert_dec = &cert_dec[0..len]; - - let epoch = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs() as u32; - let mut writer = TLVWriter::new(write_buf); - writer.start_struct(TagType::Anonymous)?; - writer.str16(TagType::Context(1), cert_dec)?; - writer.str8(TagType::Context(2), att_nonce)?; - writer.u32(TagType::Context(3), epoch)?; - writer.end_container()?; - - t.str16(TagType::Context(0), write_buf.as_borrow_slice())?; - Ok(()) -} - -fn add_attestation_signature( - dev_att: &dyn DevAttDataFetcher, - attest_element: &mut WriteBuf, - attest_challenge: &[u8], - resp: &mut TLVWriter, -) -> Result<(), Error> { - let dac_key = { - let mut pubkey = [0_u8; crypto::EC_POINT_LEN_BYTES]; - let mut privkey = [0_u8; crypto::BIGNUM_LEN_BYTES]; - dev_att.get_devatt_data(dev_att::DataType::DACPubKey, &mut pubkey)?; - dev_att.get_devatt_data(dev_att::DataType::DACPrivKey, &mut privkey)?; - KeyPair::new_from_components(&pubkey, &privkey) - }?; - attest_element.copy_from_slice(attest_challenge)?; - let mut signature = [0u8; crypto::EC_SIGNATURE_LEN_BYTES]; - dac_key.sign_msg(attest_element.as_borrow_slice(), &mut signature)?; - resp.str8(TagType::Context(1), &signature) -} - -fn add_nocsrelement( - noc_keypair: &KeyPair, - csr_nonce: &[u8], - write_buf: &mut WriteBuf, - resp: &mut TLVWriter, -) -> Result<(), Error> { - let mut csr: [u8; MAX_CSR_LEN] = [0; MAX_CSR_LEN]; - let csr = noc_keypair.get_csr(&mut csr)?; - let mut writer = TLVWriter::new(write_buf); - writer.start_struct(TagType::Anonymous)?; - writer.str8(TagType::Context(1), csr)?; - writer.str8(TagType::Context(2), csr_nonce)?; - writer.end_container()?; - - resp.str8(TagType::Context(0), write_buf.as_borrow_slice())?; - Ok(()) -} - #[derive(ToTLV)] struct CertChainResp<'a> { cert: OctetStr<'a>, } #[derive(ToTLV)] -struct NocResp { +struct NocResp<'a> { status_code: u8, fab_idx: u8, - debug_txt: String, + debug_txt: UtfStr<'a>, } #[derive(FromTLV)] @@ -598,6 +214,475 @@ struct RemoveFabricReq { fab_idx: u8, } +pub struct NocCluster<'a> { + data_ver: Dataver, + epoch: Epoch, + dev_att: &'a dyn DevAttDataFetcher, + fabric_mgr: &'a RefCell, + acl_mgr: &'a RefCell, + failsafe: &'a RefCell, + mdns_mgr: &'a RefCell>, +} + +impl<'a> NocCluster<'a> { + pub fn new( + dev_att: &'a dyn DevAttDataFetcher, + fabric_mgr: &'a RefCell, + acl_mgr: &'a RefCell, + failsafe: &'a RefCell, + mdns_mgr: &'a RefCell>, + epoch: Epoch, + rand: Rand, + ) -> Self { + Self { + data_ver: Dataver::new(rand), + epoch, + dev_att, + fabric_mgr, + acl_mgr, + failsafe, + mdns_mgr, + } + } + + pub fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> { + if let Some(mut writer) = encoder.with_dataver(self.data_ver.get())? { + if attr.is_system() { + CLUSTER.read(attr.attr_id, writer) + } else { + match attr.attr_id.try_into()? { + Attributes::SupportedFabrics(codec) => { + codec.encode(writer, MAX_SUPPORTED_FABRICS as _) + } + Attributes::CurrentFabricIndex(codec) => codec.encode(writer, attr.fab_idx), + Attributes::Fabrics(_) => { + writer.start_array(AttrDataWriter::TAG)?; + self.fabric_mgr.borrow().for_each(|entry, fab_idx| { + if !attr.fab_filter || attr.fab_idx == fab_idx { + entry + .get_fabric_desc(fab_idx) + .to_tlv(&mut writer, TagType::Anonymous)?; + } + + Ok(()) + })?; + writer.end_container()?; + + writer.complete() + } + Attributes::CommissionedFabrics(codec) => { + codec.encode(writer, self.fabric_mgr.borrow().used_count() as _) + } + _ => { + error!("Attribute not supported: this shouldn't happen"); + Err(Error::AttributeNotFound) + } + } + } + } else { + Ok(()) + } + } + + pub fn invoke( + &mut self, + transaction: &mut Transaction, + cmd: &CmdDetails, + data: &TLVElement, + encoder: CmdDataEncoder, + ) -> Result<(), Error> { + match cmd.cmd_id.try_into()? { + Commands::AddNOC => { + self.handle_command_addnoc(transaction.session_mut(), data, encoder)? + } + Commands::CSRReq => { + self.handle_command_csrrequest(transaction.session_mut(), data, encoder)? + } + Commands::AddTrustedRootCert => { + self.handle_command_addtrustedrootcert(transaction.session_mut(), data)? + } + Commands::AttReq => { + self.handle_command_attrequest(transaction.session(), data, encoder)? + } + Commands::CertChainReq => self.handle_command_certchainrequest(data, encoder)?, + Commands::UpdateFabricLabel => { + self.handle_command_updatefablabel(transaction.session(), data, encoder)?; + } + Commands::RemoveFabric => self.handle_command_rmfabric(transaction, data, encoder)?, + } + + self.data_ver.changed(); + + Ok(()) + } + + fn add_acl(&self, fab_idx: u8, admin_subject: u64) -> Result<(), Error> { + let mut acl = AclEntry::new(fab_idx, Privilege::ADMIN, AuthMode::Case); + acl.add_subject(admin_subject)?; + self.acl_mgr.borrow_mut().add(acl) + } + + fn _handle_command_addnoc( + &mut self, + session: &mut Session, + data: &TLVElement, + ) -> Result { + let noc_data = session.take_noc_data().ok_or(NocStatus::MissingCsr)?; + + if !self + .failsafe + .borrow_mut() + .allow_noc_change() + .map_err(|_| NocStatus::InsufficientPrivlege)? + { + error!("AddNOC not allowed by Fail Safe"); + Err(NocStatus::InsufficientPrivlege)?; + } + + // TODO + // // This command's processing may take longer, send a stand alone ACK to the peer to avoid any retranmissions + // let ack_send = secure_channel::common::send_mrp_standalone_ack( + // trans.exch, + // trans.session, + // ); + // if ack_send.is_err() { + // error!("Error sending Standalone ACK, falling back to piggybacked ACK"); + // } + + let r = AddNocReq::from_tlv(data).map_err(|_| NocStatus::InvalidNOC)?; + + let noc_value = Cert::new(r.noc_value.0).map_err(|_| NocStatus::InvalidNOC)?; + info!("Received NOC as: {}", noc_value); + let icac_value = if !r.icac_value.0.is_empty() { + let cert = Cert::new(r.icac_value.0).map_err(|_| NocStatus::InvalidNOC)?; + info!("Received ICAC as: {}", cert); + Some(cert) + } else { + None + }; + + let fabric = Fabric::new( + noc_data.key_pair, + noc_data.root_ca, + icac_value, + noc_value, + r.ipk_value.0, + r.vendor_id, + "", + ) + .map_err(|_| NocStatus::TableFull)?; + let fab_idx = self + .fabric_mgr + .borrow_mut() + .add(fabric, &mut self.mdns_mgr.borrow_mut()) + .map_err(|_| NocStatus::TableFull)?; + + self.add_acl(fab_idx, r.case_admin_subject)?; + + self.failsafe.borrow_mut().record_add_noc(fab_idx)?; + + Ok(fab_idx) + } + + fn create_nocresponse( + encoder: CmdDataEncoder, + status_code: NocStatus, + fab_idx: u8, + debug_txt: &str, + ) -> Result<(), Error> { + let cmd_data = NocResp { + status_code: status_code as u8, + fab_idx, + debug_txt: UtfStr::new(debug_txt.as_bytes()), + }; + + encoder + .with_command(RespCommands::NOCResp as _)? + .set(&cmd_data) + } + + fn handle_command_updatefablabel( + &mut self, + session: &Session, + data: &TLVElement, + encoder: CmdDataEncoder, + ) -> Result<(), Error> { + cmd_enter!("Update Fabric Label"); + let req = UpdateFabricLabelReq::from_tlv(data).map_err(Error::map_invalid_data_type)?; + let (result, fab_idx) = if let SessionMode::Case(c) = session.get_session_mode() { + if self + .fabric_mgr + .borrow_mut() + .set_label( + c.fab_idx, + req.label.as_str().map_err(Error::map_invalid_data_type)?, + ) + .is_err() + { + (NocStatus::LabelConflict, c.fab_idx) + } else { + (NocStatus::Ok, c.fab_idx) + } + } else { + // Update Fabric Label not allowed + (NocStatus::InvalidFabricIndex, 0) + }; + + Self::create_nocresponse(encoder, result, fab_idx, "") + } + + fn handle_command_rmfabric( + &mut self, + transaction: &mut Transaction, + data: &TLVElement, + encoder: CmdDataEncoder, + ) -> Result<(), Error> { + cmd_enter!("Remove Fabric"); + let req = RemoveFabricReq::from_tlv(data).map_err(Error::map_invalid_data_type)?; + if self + .fabric_mgr + .borrow_mut() + .remove(req.fab_idx, &mut self.mdns_mgr.borrow_mut()) + .is_ok() + { + let _ = self.acl_mgr.borrow_mut().delete_for_fabric(req.fab_idx); + transaction.terminate(); + Ok(()) + } else { + Self::create_nocresponse(encoder, NocStatus::InvalidFabricIndex, req.fab_idx, "") + } + } + + fn handle_command_addnoc( + &mut self, + session: &mut Session, + data: &TLVElement, + encoder: CmdDataEncoder, + ) -> Result<(), Error> { + cmd_enter!("AddNOC"); + + let (status, fab_idx) = match self._handle_command_addnoc(session, data) { + Ok(fab_idx) => (NocStatus::Ok, fab_idx), + Err(NocError::Status(status)) => (status, 0), + Err(NocError::Error(error)) => Err(error)?, + }; + + Self::create_nocresponse(encoder, status, fab_idx, "") + } + + fn handle_command_attrequest( + &mut self, + session: &Session, + data: &TLVElement, + encoder: CmdDataEncoder, + ) -> Result<(), Error> { + cmd_enter!("AttestationRequest"); + + let req = CommonReq::from_tlv(data).map_err(Error::map_invalid_command)?; + info!("Received Attestation Nonce:{:?}", req.str); + + let mut attest_challenge = [0u8; crypto::SYMM_KEY_LEN_BYTES]; + attest_challenge.copy_from_slice(session.get_att_challenge()); + + let mut writer = encoder.with_command(RespCommands::AttReqResp as _)?; + + let mut buf: [u8; RESP_MAX] = [0; RESP_MAX]; + let mut attest_element = WriteBuf::new(&mut buf); + writer.start_struct(CmdDataWriter::TAG)?; + add_attestation_element( + self.epoch, + self.dev_att, + req.str.0, + &mut attest_element, + &mut writer, + )?; + add_attestation_signature( + self.dev_att, + &mut attest_element, + &attest_challenge, + &mut writer, + )?; + writer.end_container()?; + + writer.complete() + } + + fn handle_command_certchainrequest( + &mut self, + data: &TLVElement, + encoder: CmdDataEncoder, + ) -> Result<(), Error> { + cmd_enter!("CertChainRequest"); + + info!("Received data: {}", data); + let cert_type = get_certchainrequest_params(data).map_err(Error::map_invalid_command)?; + + let mut buf: [u8; RESP_MAX] = [0; RESP_MAX]; + let len = self.dev_att.get_devatt_data(cert_type, &mut buf)?; + let buf = &buf[0..len]; + + let cmd_data = CertChainResp { + cert: OctetStr::new(buf), + }; + + encoder + .with_command(RespCommands::CertChainResp as _)? + .set(&cmd_data) + } + + fn handle_command_csrrequest( + &mut self, + session: &mut Session, + data: &TLVElement, + encoder: CmdDataEncoder, + ) -> Result<(), Error> { + cmd_enter!("CSRRequest"); + + let req = CommonReq::from_tlv(data).map_err(Error::map_invalid_command)?; + info!("Received CSR Nonce:{:?}", req.str); + + if !self.failsafe.borrow().is_armed() { + return Err(Error::UnsupportedAccess); + } + + let noc_keypair = KeyPair::new()?; + let mut attest_challenge = [0u8; crypto::SYMM_KEY_LEN_BYTES]; + attest_challenge.copy_from_slice(session.get_att_challenge()); + + let mut writer = encoder.with_command(RespCommands::CSRResp as _)?; + + let mut buf: [u8; RESP_MAX] = [0; RESP_MAX]; + let mut nocsr_element = WriteBuf::new(&mut buf); + writer.start_struct(CmdDataWriter::TAG)?; + add_nocsrelement(&noc_keypair, req.str.0, &mut nocsr_element, &mut writer)?; + add_attestation_signature( + self.dev_att, + &mut nocsr_element, + &attest_challenge, + &mut writer, + )?; + writer.end_container()?; + + writer.complete()?; + + let noc_data = NocData::new(noc_keypair); + // Store this in the session data instead of cluster data, so it gets cleared + // if the session goes away for some reason + session.set_noc_data(noc_data); + + Ok(()) + } + + fn handle_command_addtrustedrootcert( + &mut self, + session: &mut Session, + data: &TLVElement, + ) -> Result<(), Error> { + cmd_enter!("AddTrustedRootCert"); + if !self.failsafe.borrow().is_armed() { + return Err(Error::UnsupportedAccess); + } + + // This may happen on CASE or PASE. For PASE, the existence of NOC Data is necessary + match session.get_session_mode() { + SessionMode::Case(_) => error!("CASE: AddTrustedRootCert handling pending"), // For a CASE Session, we just return success for now, + SessionMode::Pase => { + let noc_data = session.get_noc_data::().ok_or(Error::NoSession)?; + + let req = CommonReq::from_tlv(data).map_err(Error::map_invalid_command)?; + info!("Received Trusted Cert:{:x?}", req.str); + + noc_data.root_ca = Cert::new(req.str.0)?; + } + _ => (), + } + + Ok(()) + } +} + +impl<'a> Handler for NocCluster<'a> { + fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> { + NocCluster::read(self, attr, encoder) + } + + fn invoke( + &mut self, + transaction: &mut Transaction, + cmd: &CmdDetails, + data: &TLVElement, + encoder: CmdDataEncoder, + ) -> Result<(), Error> { + NocCluster::invoke(self, transaction, cmd, data, encoder) + } +} + +impl<'a> NonBlockingHandler for NocCluster<'a> {} + +impl<'a> ChangeNotifier<()> for NocCluster<'a> { + fn consume_change(&mut self) -> Option<()> { + self.data_ver.consume_change(()) + } +} + +fn add_attestation_element( + epoch: Epoch, + dev_att: &dyn DevAttDataFetcher, + att_nonce: &[u8], + write_buf: &mut WriteBuf, + t: &mut TLVWriter, +) -> Result<(), Error> { + let mut cert_dec: [u8; MAX_CERT_DECLARATION_LEN] = [0; MAX_CERT_DECLARATION_LEN]; + let len = dev_att.get_devatt_data(dev_att::DataType::CertDeclaration, &mut cert_dec)?; + let cert_dec = &cert_dec[0..len]; + + let epoch = epoch().as_secs() as u32; + let mut writer = TLVWriter::new(write_buf); + writer.start_struct(TagType::Anonymous)?; + writer.str16(TagType::Context(1), cert_dec)?; + writer.str8(TagType::Context(2), att_nonce)?; + writer.u32(TagType::Context(3), epoch)?; + writer.end_container()?; + + t.str16(TagType::Context(0), write_buf.as_slice()) +} + +fn add_attestation_signature( + dev_att: &dyn DevAttDataFetcher, + attest_element: &mut WriteBuf, + attest_challenge: &[u8], + resp: &mut TLVWriter, +) -> Result<(), Error> { + let dac_key = { + let mut pubkey = [0_u8; crypto::EC_POINT_LEN_BYTES]; + let mut privkey = [0_u8; crypto::BIGNUM_LEN_BYTES]; + dev_att.get_devatt_data(dev_att::DataType::DACPubKey, &mut pubkey)?; + dev_att.get_devatt_data(dev_att::DataType::DACPrivKey, &mut privkey)?; + KeyPair::new_from_components(&pubkey, &privkey) + }?; + attest_element.copy_from_slice(attest_challenge)?; + let mut signature = [0u8; crypto::EC_SIGNATURE_LEN_BYTES]; + dac_key.sign_msg(attest_element.as_slice(), &mut signature)?; + resp.str8(TagType::Context(1), &signature) +} + +fn add_nocsrelement( + noc_keypair: &KeyPair, + csr_nonce: &[u8], + write_buf: &mut WriteBuf, + resp: &mut TLVWriter, +) -> Result<(), Error> { + let mut csr: [u8; MAX_CSR_LEN] = [0; MAX_CSR_LEN]; + let csr = noc_keypair.get_csr(&mut csr)?; + let mut writer = TLVWriter::new(write_buf); + writer.start_struct(TagType::Anonymous)?; + writer.str8(TagType::Context(1), csr)?; + writer.str8(TagType::Context(2), csr_nonce)?; + writer.end_container()?; + + resp.str8(TagType::Context(0), write_buf.as_slice()) +} + fn get_certchainrequest_params(data: &TLVElement) -> Result { let cert_type = CertChainReq::from_tlv(data)?.cert_type; diff --git a/matter/src/data_model/sdm/nw_commissioning.rs b/matter/src/data_model/sdm/nw_commissioning.rs index 753347d..7afff7a 100644 --- a/matter/src/data_model/sdm/nw_commissioning.rs +++ b/matter/src/data_model/sdm/nw_commissioning.rs @@ -16,38 +16,51 @@ */ use crate::{ - data_model::objects::{Cluster, ClusterType}, + data_model::objects::{ + AttrDataEncoder, AttrDetails, ChangeNotifier, Cluster, Dataver, Handler, + NonBlockingHandler, ATTRIBUTE_LIST, FEATURE_MAP, + }, error::Error, + utils::rand::Rand, }; pub const ID: u32 = 0x0031; -pub struct NwCommCluster { - base: Cluster, -} - -impl ClusterType for NwCommCluster { - fn base(&self) -> &Cluster { - &self.base - } - fn base_mut(&mut self) -> &mut Cluster { - &mut self.base - } -} - enum FeatureMap { _Wifi = 0x01, _Thread = 0x02, Ethernet = 0x04, } +pub const CLUSTER: Cluster<'static> = Cluster { + id: ID as _, + feature_map: FeatureMap::Ethernet as _, + attributes: &[FEATURE_MAP, ATTRIBUTE_LIST], + commands: &[], +}; + +pub struct NwCommCluster { + data_ver: Dataver, +} + impl NwCommCluster { - pub fn new() -> Result, Error> { - let mut c = Box::new(Self { - base: Cluster::new(ID)?, - }); - // TODO: Arch-Specific - c.base.set_feature_map(FeatureMap::Ethernet as u32)?; - Ok(c) + pub fn new(rand: Rand) -> Self { + Self { + data_ver: Dataver::new(rand), + } + } +} + +impl Handler for NwCommCluster { + fn read(&self, _attr: &AttrDetails, _encoder: AttrDataEncoder) -> Result<(), Error> { + Err(Error::AttributeNotFound) + } +} + +impl NonBlockingHandler for NwCommCluster {} + +impl ChangeNotifier<()> for NwCommCluster { + fn consume_change(&mut self) -> Option<()> { + self.data_ver.consume_change(()) } } diff --git a/matter/src/data_model/system_model/access_control.rs b/matter/src/data_model/system_model/access_control.rs index df1297b..3980a43 100644 --- a/matter/src/data_model/system_model/access_control.rs +++ b/matter/src/data_model/system_model/access_control.rs @@ -15,46 +15,130 @@ * limitations under the License. */ -use std::sync::Arc; +use core::cell::RefCell; +use core::convert::TryInto; -use num_derive::FromPrimitive; +use strum::{EnumDiscriminants, FromRepr}; -use crate::acl::{self, AclEntry, AclMgr}; +use crate::acl::{AclEntry, AclMgr}; use crate::data_model::objects::*; -use crate::error::*; -use crate::interaction_model::core::IMStatusCode; use crate::interaction_model::messages::ib::{attr_list_write, ListOperation}; use crate::tlv::{FromTLV, TLVElement, TagType, ToTLV}; +use crate::utils::rand::Rand; +use crate::{attribute_enum, error::*}; use log::{error, info}; pub const ID: u32 = 0x001F; -#[derive(FromPrimitive)] +#[derive(FromRepr, EnumDiscriminants)] +#[repr(u16)] pub enum Attributes { - Acl = 0, - Extension = 1, - SubjectsPerEntry = 2, - TargetsPerEntry = 3, - EntriesPerFabric = 4, + Acl(()) = 0, + Extension(()) = 1, + SubjectsPerEntry(AttrType) = 2, + TargetsPerEntry(AttrType) = 3, + EntriesPerFabric(AttrType) = 4, } -pub struct AccessControlCluster { - base: Cluster, - acl_mgr: Arc, +attribute_enum!(Attributes); + +pub const CLUSTER: Cluster<'static> = Cluster { + id: ID, + feature_map: 0, + attributes: &[ + FEATURE_MAP, + ATTRIBUTE_LIST, + Attribute::new( + AttributesDiscriminants::Acl as u16, + Access::RWFA, + Quality::NONE, + ), + Attribute::new( + AttributesDiscriminants::Extension as u16, + Access::RWFA, + Quality::NONE, + ), + Attribute::new( + AttributesDiscriminants::SubjectsPerEntry as u16, + Access::RV, + Quality::FIXED, + ), + Attribute::new( + AttributesDiscriminants::TargetsPerEntry as u16, + Access::RV, + Quality::FIXED, + ), + Attribute::new( + AttributesDiscriminants::EntriesPerFabric as u16, + Access::RV, + Quality::FIXED, + ), + ], + commands: &[], +}; + +pub struct AccessControlCluster<'a> { + data_ver: Dataver, + acl_mgr: &'a RefCell, } -impl AccessControlCluster { - pub fn new(acl_mgr: Arc) -> Result, Error> { - let mut c = Box::new(AccessControlCluster { - base: Cluster::new(ID)?, +impl<'a> AccessControlCluster<'a> { + pub fn new(acl_mgr: &'a RefCell, rand: Rand) -> Self { + Self { + data_ver: Dataver::new(rand), acl_mgr, - }); - c.base.add_attribute(attr_acl_new())?; - c.base.add_attribute(attr_extension_new())?; - c.base.add_attribute(attr_subjects_per_entry_new())?; - c.base.add_attribute(attr_targets_per_entry_new())?; - c.base.add_attribute(attr_entries_per_fabric_new())?; - Ok(c) + } + } + + pub fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> { + if let Some(mut writer) = encoder.with_dataver(self.data_ver.get())? { + if attr.is_system() { + CLUSTER.read(attr.attr_id, writer) + } else { + match attr.attr_id.try_into()? { + Attributes::Acl(_) => { + writer.start_array(AttrDataWriter::TAG)?; + self.acl_mgr.borrow().for_each_acl(|entry| { + if !attr.fab_filter || Some(attr.fab_idx) == entry.fab_idx { + entry.to_tlv(&mut writer, TagType::Anonymous)?; + } + + Ok(()) + })?; + writer.end_container()?; + + writer.complete() + } + Attributes::Extension(_) => { + // Empty for now + writer.start_array(AttrDataWriter::TAG)?; + writer.end_container()?; + + writer.complete() + } + _ => { + error!("Attribute not yet supported: this shouldn't happen"); + Err(Error::AttributeNotFound) + } + } + } + } else { + Ok(()) + } + } + + pub fn write(&mut self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { + match attr.attr_id.try_into()? { + Attributes::Acl(_) => { + attr_list_write(attr, data.with_dataver(self.data_ver.get())?, |op, data| { + self.write_acl_attr(&op, data, attr.fab_idx) + }) + } + _ => { + error!("Attribute not yet supported: this shouldn't happen"); + Err(Error::AttributeNotFound) + } + } } /// Write the ACL Attribute @@ -66,141 +150,59 @@ impl AccessControlCluster { op: &ListOperation, data: &TLVElement, fab_idx: u8, - ) -> Result<(), IMStatusCode> { + ) -> Result<(), Error> { info!("Performing ACL operation {:?}", op); - let result = match op { + match op { ListOperation::AddItem | ListOperation::EditItem(_) => { - let mut acl_entry = - AclEntry::from_tlv(data).map_err(|_| IMStatusCode::ConstraintError)?; + let mut acl_entry = AclEntry::from_tlv(data)?; info!("ACL {:?}", acl_entry); // Overwrite the fabric index with our accessing fabric index acl_entry.fab_idx = Some(fab_idx); if let ListOperation::EditItem(index) = op { - self.acl_mgr.edit(*index as u8, fab_idx, acl_entry) + self.acl_mgr + .borrow_mut() + .edit(*index as u8, fab_idx, acl_entry) } else { - self.acl_mgr.add(acl_entry) + self.acl_mgr.borrow_mut().add(acl_entry) } } - ListOperation::DeleteItem(index) => self.acl_mgr.delete(*index as u8, fab_idx), - ListOperation::DeleteList => self.acl_mgr.delete_for_fabric(fab_idx), - }; - match result { - Ok(_) => Ok(()), - Err(Error::NoSpace) => Err(IMStatusCode::ResourceExhausted), - _ => Err(IMStatusCode::ConstraintError), - } - } -} - -impl ClusterType for AccessControlCluster { - fn base(&self) -> &Cluster { - &self.base - } - fn base_mut(&mut self) -> &mut Cluster { - &mut self.base - } - - fn read_custom_attribute(&self, encoder: &mut dyn Encoder, attr: &AttrDetails) { - match num::FromPrimitive::from_u16(attr.attr_id) { - Some(Attributes::Acl) => encoder.encode(EncodeValue::Closure(&|tag, tw| { - let _ = tw.start_array(tag); - let _ = self.acl_mgr.for_each_acl(|entry| { - if !attr.fab_filter || Some(attr.fab_idx) == entry.fab_idx { - let _ = entry.to_tlv(tw, TagType::Anonymous); - } - }); - let _ = tw.end_container(); - })), - Some(Attributes::Extension) => encoder.encode(EncodeValue::Closure(&|tag, tw| { - // Empty for now - let _ = tw.start_array(tag); - let _ = tw.end_container(); - })), - _ => { - error!("Attribute not yet supported: this shouldn't happen"); + ListOperation::DeleteItem(index) => { + self.acl_mgr.borrow_mut().delete(*index as u8, fab_idx) } + ListOperation::DeleteList => self.acl_mgr.borrow_mut().delete_for_fabric(fab_idx), } } +} - fn write_attribute( - &mut self, - attr: &AttrDetails, - data: &TLVElement, - ) -> Result<(), IMStatusCode> { - let result = if let Some(Attributes::Acl) = num::FromPrimitive::from_u16(attr.attr_id) { - attr_list_write(attr, data, |op, data| { - self.write_acl_attr(&op, data, attr.fab_idx) - }) - } else { - error!("Attribute not yet supported: this shouldn't happen"); - Err(IMStatusCode::NotFound) - }; - if result.is_ok() { - self.base.cluster_changed(); - } - result +impl<'a> Handler for AccessControlCluster<'a> { + fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> { + AccessControlCluster::read(self, attr, encoder) + } + + fn write(&mut self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { + AccessControlCluster::write(self, attr, data) } } -fn attr_acl_new() -> Attribute { - Attribute::new( - Attributes::Acl as u16, - AttrValue::Custom, - Access::RWFA, - Quality::NONE, - ) -} +impl<'a> NonBlockingHandler for AccessControlCluster<'a> {} -fn attr_extension_new() -> Attribute { - Attribute::new( - Attributes::Extension as u16, - AttrValue::Custom, - Access::RWFA, - Quality::NONE, - ) -} - -fn attr_subjects_per_entry_new() -> Attribute { - Attribute::new( - Attributes::SubjectsPerEntry as u16, - AttrValue::Uint16(acl::SUBJECTS_PER_ENTRY as u16), - Access::RV, - Quality::FIXED, - ) -} - -fn attr_targets_per_entry_new() -> Attribute { - Attribute::new( - Attributes::TargetsPerEntry as u16, - AttrValue::Uint16(acl::TARGETS_PER_ENTRY as u16), - Access::RV, - Quality::FIXED, - ) -} - -fn attr_entries_per_fabric_new() -> Attribute { - Attribute::new( - Attributes::EntriesPerFabric as u16, - AttrValue::Uint16(acl::ENTRIES_PER_FABRIC as u16), - Access::RV, - Quality::FIXED, - ) +impl<'a> ChangeNotifier<()> for AccessControlCluster<'a> { + fn consume_change(&mut self) -> Option<()> { + self.data_ver.consume_change(()) + } } #[cfg(test)] mod tests { - use std::sync::Arc; + use core::cell::RefCell; use crate::{ acl::{AclEntry, AclMgr, AuthMode}, - data_model::{ - core::read::AttrReadEncoder, - objects::{AttrDetails, ClusterType, Privilege}, - }, + data_model::objects::{AttrDataEncoder, AttrDetails, Node, Privilege}, interaction_model::messages::ib::ListOperation, tlv::{get_root_node_struct, ElementType, TLVElement, TLVWriter, TagType, ToTLV}, - utils::writebuf::WriteBuf, + utils::{rand::dummy_rand, writebuf::WriteBuf}, }; use super::AccessControlCluster; @@ -209,16 +211,15 @@ mod tests { /// Add an ACL entry fn acl_cluster_add() { let mut buf: [u8; 100] = [0; 100]; - let buf_len = buf.len(); - let mut writebuf = WriteBuf::new(&mut buf, buf_len); + let mut writebuf = WriteBuf::new(&mut buf); let mut tw = TLVWriter::new(&mut writebuf); - let acl_mgr = Arc::new(AclMgr::new_with(false).unwrap()); - let mut acl = AccessControlCluster::new(acl_mgr.clone()).unwrap(); + let acl_mgr = RefCell::new(AclMgr::new()); + let mut acl = AccessControlCluster::new(&acl_mgr, dummy_rand); let new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); new.to_tlv(&mut tw, TagType::Anonymous).unwrap(); - let data = get_root_node_struct(writebuf.as_borrow_slice()).unwrap(); + let data = get_root_node_struct(writebuf.as_slice()).unwrap(); // Test, ACL has fabric index 2, but the accessing fabric is 1 // the fabric index in the TLV should be ignored and the ACL should be created with entry 1 @@ -227,8 +228,10 @@ mod tests { let verifier = AclEntry::new(1, Privilege::VIEW, AuthMode::Case); acl_mgr + .borrow() .for_each_acl(|a| { assert_eq!(*a, verifier); + Ok(()) }) .unwrap(); } @@ -237,25 +240,24 @@ mod tests { /// - The listindex used for edit should be relative to the current fabric fn acl_cluster_edit() { let mut buf: [u8; 100] = [0; 100]; - let buf_len = buf.len(); - let mut writebuf = WriteBuf::new(&mut buf, buf_len); + let mut writebuf = WriteBuf::new(&mut buf); let mut tw = TLVWriter::new(&mut writebuf); // Add 3 ACLs, belonging to fabric index 2, 1 and 2, in that order - let acl_mgr = Arc::new(AclMgr::new_with(false).unwrap()); + let acl_mgr = RefCell::new(AclMgr::new()); let mut verifier = [ AclEntry::new(2, Privilege::VIEW, AuthMode::Case), AclEntry::new(1, Privilege::VIEW, AuthMode::Case), AclEntry::new(2, Privilege::ADMIN, AuthMode::Case), ]; for i in verifier { - acl_mgr.add(i).unwrap(); + acl_mgr.borrow_mut().add(i).unwrap(); } - let mut acl = AccessControlCluster::new(acl_mgr.clone()).unwrap(); + let mut acl = AccessControlCluster::new(&acl_mgr, dummy_rand); let new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); new.to_tlv(&mut tw, TagType::Anonymous).unwrap(); - let data = get_root_node_struct(writebuf.as_borrow_slice()).unwrap(); + let data = get_root_node_struct(writebuf.as_slice()).unwrap(); // Test, Edit Fabric 2's index 1 - with accessing fabring as 2 - allow let result = acl.write_acl_attr(&ListOperation::EditItem(1), &data, 2); @@ -266,9 +268,11 @@ mod tests { // Also validate in the acl_mgr that the entries are in the right order let mut index = 0; acl_mgr + .borrow() .for_each_acl(|a| { assert_eq!(*a, verifier[index]); index += 1; + Ok(()) }) .unwrap(); } @@ -277,16 +281,16 @@ mod tests { /// - The listindex used for delete should be relative to the current fabric fn acl_cluster_delete() { // Add 3 ACLs, belonging to fabric index 2, 1 and 2, in that order - let acl_mgr = Arc::new(AclMgr::new_with(false).unwrap()); + let acl_mgr = RefCell::new(AclMgr::new()); let input = [ AclEntry::new(2, Privilege::VIEW, AuthMode::Case), AclEntry::new(1, Privilege::VIEW, AuthMode::Case), AclEntry::new(2, Privilege::ADMIN, AuthMode::Case), ]; for i in input { - acl_mgr.add(i).unwrap(); + acl_mgr.borrow_mut().add(i).unwrap(); } - let mut acl = AccessControlCluster::new(acl_mgr.clone()).unwrap(); + let mut acl = AccessControlCluster::new(&acl_mgr, dummy_rand); // data is don't-care actually let data = TLVElement::new(TagType::Anonymous, ElementType::True); @@ -298,9 +302,11 @@ mod tests { // Also validate in the acl_mgr that the entries are in the right order let mut index = 0; acl_mgr + .borrow() .for_each_acl(|a| { assert_eq!(*a, verifier[index]); index += 1; + Ok(()) }) .unwrap(); } @@ -309,84 +315,126 @@ mod tests { /// - acl read with and without fabric filtering fn acl_cluster_read() { let mut buf: [u8; 100] = [0; 100]; - let buf_len = buf.len(); - let mut writebuf = WriteBuf::new(&mut buf, buf_len); + let mut writebuf = WriteBuf::new(&mut buf); // Add 3 ACLs, belonging to fabric index 2, 1 and 2, in that order - let acl_mgr = Arc::new(AclMgr::new_with(false).unwrap()); + let acl_mgr = RefCell::new(AclMgr::new()); let input = [ AclEntry::new(2, Privilege::VIEW, AuthMode::Case), AclEntry::new(1, Privilege::VIEW, AuthMode::Case), AclEntry::new(2, Privilege::ADMIN, AuthMode::Case), ]; for i in input { - acl_mgr.add(i).unwrap(); + acl_mgr.borrow_mut().add(i).unwrap(); } - let acl = AccessControlCluster::new(acl_mgr).unwrap(); + let acl = AccessControlCluster::new(&acl_mgr, dummy_rand); // Test 1, all 3 entries are read in the response without fabric filtering { - let mut tw = TLVWriter::new(&mut writebuf); - let mut encoder = AttrReadEncoder::new(&mut tw); - let attr_details = AttrDetails { + let attr = AttrDetails { + node: &Node { + id: 0, + endpoints: &[], + }, + endpoint_id: 0, + cluster_id: 0, attr_id: 0, list_index: None, fab_idx: 1, fab_filter: false, + dataver: None, + wildcard: false, }; - acl.read_custom_attribute(&mut encoder, &attr_details); + + let mut tw = TLVWriter::new(&mut writebuf); + let encoder = AttrDataEncoder::new(&attr, &mut tw); + + acl.read(&attr, encoder).unwrap(); assert_eq!( + // &[ + // 21, 53, 1, 36, 0, 0, 55, 1, 24, 54, 2, 21, 36, 1, 1, 36, 2, 2, 54, 3, 24, 54, + // 4, 24, 36, 254, 2, 24, 21, 36, 1, 1, 36, 2, 2, 54, 3, 24, 54, 4, 24, 36, 254, + // 1, 24, 21, 36, 1, 5, 36, 2, 2, 54, 3, 24, 54, 4, 24, 36, 254, 2, 24, 24, 24, + // 24 + // ], &[ - 21, 53, 1, 36, 0, 0, 55, 1, 24, 54, 2, 21, 36, 1, 1, 36, 2, 2, 54, 3, 24, 54, - 4, 24, 36, 254, 2, 24, 21, 36, 1, 1, 36, 2, 2, 54, 3, 24, 54, 4, 24, 36, 254, - 1, 24, 21, 36, 1, 5, 36, 2, 2, 54, 3, 24, 54, 4, 24, 36, 254, 2, 24, 24, 24, - 24 + 21, 53, 1, 36, 0, 0, 55, 1, 36, 2, 0, 36, 3, 0, 36, 4, 0, 24, 54, 2, 21, 36, 1, + 1, 36, 2, 2, 54, 3, 24, 54, 4, 24, 36, 254, 2, 24, 21, 36, 1, 1, 36, 2, 2, 54, + 3, 24, 54, 4, 24, 36, 254, 1, 24, 21, 36, 1, 5, 36, 2, 2, 54, 3, 24, 54, 4, 24, + 36, 254, 2, 24, 24, 24, 24 ], - writebuf.as_borrow_slice() + writebuf.as_slice() ); } writebuf.reset(0); // Test 2, only single entry is read in the response with fabric filtering and fabric idx 1 { - let mut tw = TLVWriter::new(&mut writebuf); - let mut encoder = AttrReadEncoder::new(&mut tw); - - let attr_details = AttrDetails { + let attr = AttrDetails { + node: &Node { + id: 0, + endpoints: &[], + }, + endpoint_id: 0, + cluster_id: 0, attr_id: 0, list_index: None, fab_idx: 1, fab_filter: true, + dataver: None, + wildcard: false, }; - acl.read_custom_attribute(&mut encoder, &attr_details); + + let mut tw = TLVWriter::new(&mut writebuf); + let encoder = AttrDataEncoder::new(&attr, &mut tw); + + acl.read(&attr, encoder).unwrap(); assert_eq!( + // &[ + // 21, 53, 1, 36, 0, 0, 55, 1, 24, 54, 2, 21, 36, 1, 1, 36, 2, 2, 54, 3, 24, 54, + // 4, 24, 36, 254, 1, 24, 24, 24, 24 + // ], &[ - 21, 53, 1, 36, 0, 0, 55, 1, 24, 54, 2, 21, 36, 1, 1, 36, 2, 2, 54, 3, 24, 54, - 4, 24, 36, 254, 1, 24, 24, 24, 24 + 21, 53, 1, 36, 0, 0, 55, 1, 36, 2, 0, 36, 3, 0, 36, 4, 0, 24, 54, 2, 21, 36, 1, + 1, 36, 2, 2, 54, 3, 24, 54, 4, 24, 36, 254, 1, 24, 24, 24, 24 ], - writebuf.as_borrow_slice() + writebuf.as_slice() ); } writebuf.reset(0); // Test 3, only single entry is read in the response with fabric filtering and fabric idx 2 { - let mut tw = TLVWriter::new(&mut writebuf); - let mut encoder = AttrReadEncoder::new(&mut tw); - - let attr_details = AttrDetails { + let attr = AttrDetails { + node: &Node { + id: 0, + endpoints: &[], + }, + endpoint_id: 0, + cluster_id: 0, attr_id: 0, list_index: None, fab_idx: 2, fab_filter: true, + dataver: None, + wildcard: false, }; - acl.read_custom_attribute(&mut encoder, &attr_details); + + let mut tw = TLVWriter::new(&mut writebuf); + let encoder = AttrDataEncoder::new(&attr, &mut tw); + + acl.read(&attr, encoder).unwrap(); assert_eq!( + // &[ + // 21, 53, 1, 36, 0, 0, 55, 1, 24, 54, 2, 21, 36, 1, 1, 36, 2, 2, 54, 3, 24, 54, + // 4, 24, 36, 254, 2, 24, 21, 36, 1, 5, 36, 2, 2, 54, 3, 24, 54, 4, 24, 36, 254, + // 2, 24, 24, 24, 24 + // ], &[ - 21, 53, 1, 36, 0, 0, 55, 1, 24, 54, 2, 21, 36, 1, 1, 36, 2, 2, 54, 3, 24, 54, - 4, 24, 36, 254, 2, 24, 21, 36, 1, 5, 36, 2, 2, 54, 3, 24, 54, 4, 24, 36, 254, - 2, 24, 24, 24, 24 + 21, 53, 1, 36, 0, 0, 55, 1, 36, 2, 0, 36, 3, 0, 36, 4, 0, 24, 54, 2, 21, 36, 1, + 1, 36, 2, 2, 54, 3, 24, 54, 4, 24, 36, 254, 2, 24, 21, 36, 1, 5, 36, 2, 2, 54, + 3, 24, 54, 4, 24, 36, 254, 2, 24, 24, 24, 24 ], - writebuf.as_borrow_slice() + writebuf.as_slice() ); } } diff --git a/matter/src/data_model/system_model/descriptor.rs b/matter/src/data_model/system_model/descriptor.rs index 4fba0fa..2df17f5 100644 --- a/matter/src/data_model/system_model/descriptor.rs +++ b/matter/src/data_model/system_model/descriptor.rs @@ -15,18 +15,20 @@ * limitations under the License. */ -use num_derive::FromPrimitive; +use core::convert::TryInto; -use crate::data_model::core::DataModel; +use strum::FromRepr; + +use crate::attribute_enum; use crate::data_model::objects::*; -use crate::error::*; -use crate::interaction_model::messages::GenericPath; +use crate::error::Error; use crate::tlv::{TLVWriter, TagType, ToTLV}; -use log::error; +use crate::utils::rand::Rand; pub const ID: u32 = 0x001D; -#[derive(FromPrimitive)] +#[derive(FromRepr)] +#[repr(u16)] #[allow(clippy::enum_variant_names)] pub enum Attributes { DeviceTypeList = 0, @@ -35,134 +37,155 @@ pub enum Attributes { PartsList = 3, } +attribute_enum!(Attributes); + +pub const CLUSTER: Cluster<'static> = Cluster { + id: ID as _, + feature_map: 0, + attributes: &[ + FEATURE_MAP, + ATTRIBUTE_LIST, + Attribute::new(Attributes::DeviceTypeList as u16, Access::RV, Quality::NONE), + Attribute::new(Attributes::ServerList as u16, Access::RV, Quality::NONE), + Attribute::new(Attributes::PartsList as u16, Access::RV, Quality::NONE), + Attribute::new(Attributes::ClientList as u16, Access::RV, Quality::NONE), + ], + commands: &[], +}; + pub struct DescriptorCluster { - base: Cluster, - endpoint_id: EndptId, - data_model: DataModel, + data_ver: Dataver, } impl DescriptorCluster { - pub fn new(endpoint_id: EndptId, data_model: DataModel) -> Result, Error> { - let mut c = Box::new(DescriptorCluster { - endpoint_id, - data_model, - base: Cluster::new(ID)?, - }); - let attrs = [ - Attribute::new( - Attributes::DeviceTypeList as u16, - AttrValue::Custom, - Access::RV, - Quality::NONE, - ), - Attribute::new( - Attributes::ServerList as u16, - AttrValue::Custom, - Access::RV, - Quality::NONE, - ), - Attribute::new( - Attributes::PartsList as u16, - AttrValue::Custom, - Access::RV, - Quality::NONE, - ), - Attribute::new( - Attributes::ClientList as u16, - AttrValue::Custom, - Access::RV, - Quality::NONE, - ), - ]; - c.base.add_attributes(&attrs[..])?; - Ok(c) + pub fn new(rand: Rand) -> Self { + Self { + data_ver: Dataver::new(rand), + } } - fn encode_devtype_list(&self, tag: TagType, tw: &mut TLVWriter) { - let path = GenericPath { - endpoint: Some(self.endpoint_id), - cluster: None, - leaf: None, - }; - let _ = tw.start_array(tag); - let dm = self.data_model.node.read().unwrap(); - let _ = dm.for_each_endpoint(&path, |_, e| { - let dev_type = e.get_dev_type(); - let _ = dev_type.to_tlv(tw, TagType::Anonymous); - Ok(()) - }); - let _ = tw.end_container(); - } - - fn encode_server_list(&self, tag: TagType, tw: &mut TLVWriter) { - let path = GenericPath { - endpoint: Some(self.endpoint_id), - cluster: None, - leaf: None, - }; - let _ = tw.start_array(tag); - let dm = self.data_model.node.read().unwrap(); - let _ = dm.for_each_cluster(&path, |_current_path, c| { - let _ = tw.u32(TagType::Anonymous, c.base().id()); - Ok(()) - }); - let _ = tw.end_container(); - } - - fn encode_parts_list(&self, tag: TagType, tw: &mut TLVWriter) { - let path = GenericPath { - endpoint: None, - cluster: None, - leaf: None, - }; - let _ = tw.start_array(tag); - if self.endpoint_id == 0 { - // TODO: If endpoint is another than 0, need to figure out what to do - let dm = self.data_model.node.read().unwrap(); - let _ = dm.for_each_endpoint(&path, |current_path, _| { - if let Some(endpoint_id) = current_path.endpoint { - if endpoint_id != 0 { - let _ = tw.u16(TagType::Anonymous, endpoint_id); + pub fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> { + if let Some(mut writer) = encoder.with_dataver(self.data_ver.get())? { + if attr.is_system() { + CLUSTER.read(attr.attr_id, writer) + } else { + match attr.attr_id.try_into()? { + Attributes::DeviceTypeList => { + self.encode_devtype_list(attr.node, AttrDataWriter::TAG, &mut writer)?; + writer.complete() + } + Attributes::ServerList => { + self.encode_server_list( + attr.node, + attr.endpoint_id, + AttrDataWriter::TAG, + &mut writer, + )?; + writer.complete() + } + Attributes::PartsList => { + self.encode_parts_list( + attr.node, + attr.endpoint_id, + AttrDataWriter::TAG, + &mut writer, + )?; + writer.complete() + } + Attributes::ClientList => { + self.encode_client_list( + attr.node, + attr.endpoint_id, + AttrDataWriter::TAG, + &mut writer, + )?; + writer.complete() } } - Ok(()) - }); + } + } else { + Ok(()) } - let _ = tw.end_container(); } - fn encode_client_list(&self, tag: TagType, tw: &mut TLVWriter) { - // No Clients supported - let _ = tw.start_array(tag); - let _ = tw.end_container(); - } -} + fn encode_devtype_list( + &self, + node: &Node, + tag: TagType, + tw: &mut TLVWriter, + ) -> Result<(), Error> { + tw.start_array(tag)?; + for endpoint in node.endpoints { + let dev_type = endpoint.device_type; + dev_type.to_tlv(tw, TagType::Anonymous)?; + } -impl ClusterType for DescriptorCluster { - fn base(&self) -> &Cluster { - &self.base - } - fn base_mut(&mut self) -> &mut Cluster { - &mut self.base + tw.end_container() } - fn read_custom_attribute(&self, encoder: &mut dyn Encoder, attr: &AttrDetails) { - match num::FromPrimitive::from_u16(attr.attr_id) { - Some(Attributes::DeviceTypeList) => encoder.encode(EncodeValue::Closure(&|tag, tw| { - self.encode_devtype_list(tag, tw) - })), - Some(Attributes::ServerList) => encoder.encode(EncodeValue::Closure(&|tag, tw| { - self.encode_server_list(tag, tw) - })), - Some(Attributes::PartsList) => encoder.encode(EncodeValue::Closure(&|tag, tw| { - self.encode_parts_list(tag, tw) - })), - Some(Attributes::ClientList) => encoder.encode(EncodeValue::Closure(&|tag, tw| { - self.encode_client_list(tag, tw) - })), - _ => { - error!("Attribute not supported: this shouldn't happen"); + fn encode_server_list( + &self, + node: &Node, + endpoint_id: u16, + tag: TagType, + tw: &mut TLVWriter, + ) -> Result<(), Error> { + tw.start_array(tag)?; + for endpoint in node.endpoints { + if endpoint.id == endpoint_id { + for cluster in endpoint.clusters { + tw.u32(TagType::Anonymous, cluster.id as _)?; + } } } + + tw.end_container() + } + + fn encode_parts_list( + &self, + node: &Node, + endpoint_id: u16, + tag: TagType, + tw: &mut TLVWriter, + ) -> Result<(), Error> { + tw.start_array(tag)?; + + if endpoint_id == 0 { + // TODO: If endpoint is another than 0, need to figure out what to do + for endpoint in node.endpoints { + if endpoint.id != 0 { + tw.u16(TagType::Anonymous, endpoint.id)?; + } + } + } + + tw.end_container() + } + + fn encode_client_list( + &self, + _node: &Node, + _endpoint_id: u16, + tag: TagType, + tw: &mut TLVWriter, + ) -> Result<(), Error> { + // No Clients supported + tw.start_array(tag)?; + tw.end_container() + } +} + +impl Handler for DescriptorCluster { + fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> { + DescriptorCluster::read(self, attr, encoder) + } +} + +impl NonBlockingHandler for DescriptorCluster {} + +impl ChangeNotifier<()> for DescriptorCluster { + fn consume_change(&mut self) -> Option<()> { + self.data_ver.consume_change(()) } } diff --git a/matter/src/error.rs b/matter/src/error.rs index 07cd681..3a54b2c 100644 --- a/matter/src/error.rs +++ b/matter/src/error.rs @@ -15,13 +15,14 @@ * limitations under the License. */ -use std::{ - array::TryFromSliceError, fmt, string::FromUtf8Error, sync::PoisonError, time::SystemTimeError, -}; +use alloc::string::FromUtf8Error; +use core::{array::TryFromSliceError, fmt}; use async_channel::{SendError, TryRecvError}; use log::error; +extern crate alloc; + #[derive(Debug, PartialEq, Clone, Copy)] pub enum Error { AttributeNotFound, @@ -31,6 +32,13 @@ pub enum Error { CommandNotFound, Duplicate, EndpointNotFound, + InvalidAction, + InvalidCommand, + InvalidDataType, + UnsupportedAccess, + ResourceExhausted, + Busy, + DataVersionMismatch, Crypto, TLSStack, MdnsError, @@ -71,6 +79,36 @@ pub enum Error { Utf8Fail, } +impl Error { + pub fn remap(self, matcher: F, to: Self) -> Self + where + F: FnOnce(&Self) -> bool, + { + if matcher(&self) { + to + } else { + self + } + } + + pub fn map_invalid(self, to: Self) -> Self { + self.remap(|e| matches!(e, Self::Invalid | Self::InvalidData), to) + } + + pub fn map_invalid_command(self) -> Self { + self.map_invalid(Error::InvalidCommand) + } + + pub fn map_invalid_action(self) -> Self { + self.map_invalid(Error::InvalidAction) + } + + pub fn map_invalid_data_type(self) -> Self { + self.map_invalid(Error::InvalidDataType) + } +} + +#[cfg(feature = "std")] impl From for Error { fn from(_e: std::io::Error) -> Self { // Keep things simple for now @@ -78,8 +116,9 @@ impl From for Error { } } -impl From> for Error { - fn from(_e: PoisonError) -> Self { +#[cfg(feature = "std")] +impl From> for Error { + fn from(_e: std::sync::PoisonError) -> Self { Self::RwLock } } @@ -107,8 +146,9 @@ impl From for Error { } } -impl From for Error { - fn from(_e: SystemTimeError) -> Self { +#[cfg(feature = "std")] +impl From for Error { + fn from(_e: std::time::SystemTimeError) -> Self { Self::SysTimeFail } } diff --git a/matter/src/fabric.rs b/matter/src/fabric.rs index 1715db7..42c55fd 100644 --- a/matter/src/fabric.rs +++ b/matter/src/fabric.rs @@ -15,29 +15,33 @@ * limitations under the License. */ -use std::sync::{Arc, Mutex, MutexGuard, RwLock}; +use core::fmt::Write; use byteorder::{BigEndian, ByteOrder, LittleEndian}; use log::{error, info}; -use owning_ref::RwLockReadGuardRef; use crate::{ cert::Cert, - crypto::{self, crypto_dummy::KeyPairDummy, hkdf_sha256, CryptoKeyPair, HmacSha256, KeyPair}, + crypto::{self, hkdf_sha256, HmacSha256, KeyPair}, error::Error, group_keys::KeySet, - mdns::{self, Mdns}, - sys::{Psm, SysMdnsService}, + mdns::{MdnsMgr, ServiceMode}, + persist::Psm, tlv::{OctetStr, TLVWriter, TagType, ToTLV, UtfStr}, }; -const MAX_CERT_TLV_LEN: usize = 350; +const MAX_CERT_TLV_LEN: usize = 300; const COMPRESSED_FABRIC_ID_LEN: usize = 8; macro_rules! fb_key { - ($index:ident, $key:ident) => { - &format!("fb{}{}", $index, $key) - }; + ($index:ident, $key:ident, $buf:expr) => {{ + use core::fmt::Write; + + $buf = "".into(); + write!(&mut $buf, "fb{}{}", $index, $key).unwrap(); + + &$buf + }}; } const ST_VID: &str = "vid"; @@ -50,20 +54,6 @@ const ST_PBKEY: &str = "pubkey"; const ST_PRKEY: &str = "privkey"; #[allow(dead_code)] -pub struct Fabric { - node_id: u64, - fabric_id: u64, - vendor_id: u16, - key_pair: Box, - pub root_ca: Cert, - pub icac: Option, - pub noc: Cert, - pub ipk: KeySet, - label: String, - compressed_id: [u8; COMPRESSED_FABRIC_ID_LEN], - mdns_service: Option, -} - #[derive(ToTLV)] #[tlvargs(lifetime = "'a", start = 1)] pub struct FabricDescriptor<'a> { @@ -77,6 +67,19 @@ pub struct FabricDescriptor<'a> { pub fab_idx: Option, } +pub struct Fabric { + node_id: u64, + fabric_id: u64, + vendor_id: u16, + key_pair: KeyPair, + pub root_ca: Cert, + pub icac: Option, + pub noc: Cert, + pub ipk: KeySet, + label: heapless::String<32>, + mdns_service_name: heapless::String<33>, +} + impl Fabric { pub fn new( key_pair: KeyPair, @@ -85,56 +88,43 @@ impl Fabric { noc: Cert, ipk: &[u8], vendor_id: u16, + label: &str, ) -> Result { let node_id = noc.get_node_id()?; let fabric_id = noc.get_fabric_id()?; - let mut f = Self { - node_id, - fabric_id, - vendor_id, - key_pair: Box::new(key_pair), - root_ca, - icac, - noc, - ipk: KeySet::default(), - compressed_id: [0; COMPRESSED_FABRIC_ID_LEN], - label: "".into(), - mdns_service: None, - }; - Fabric::get_compressed_id(f.root_ca.get_pubkey(), fabric_id, &mut f.compressed_id)?; - f.ipk = KeySet::new(ipk, &f.compressed_id)?; + let mut compressed_id = [0_u8; COMPRESSED_FABRIC_ID_LEN]; - let mut mdns_service_name = String::with_capacity(33); - for c in f.compressed_id { - mdns_service_name.push_str(&format!("{:02X}", c)); + Fabric::get_compressed_id(root_ca.get_pubkey(), fabric_id, &mut compressed_id)?; + let ipk = KeySet::new(ipk, &compressed_id)?; + + let mut mdns_service_name = heapless::String::<33>::new(); + for c in compressed_id { + let mut hex = heapless::String::<4>::new(); + write!(&mut hex, "{:02X}", c).unwrap(); + mdns_service_name.push_str(&hex).unwrap(); } - mdns_service_name.push('-'); + mdns_service_name.push('-').unwrap(); let mut node_id_be: [u8; 8] = [0; 8]; BigEndian::write_u64(&mut node_id_be, node_id); for c in node_id_be { - mdns_service_name.push_str(&format!("{:02X}", c)); + let mut hex = heapless::String::<4>::new(); + write!(&mut hex, "{:02X}", c).unwrap(); + mdns_service_name.push_str(&hex).unwrap(); } info!("MDNS Service Name: {}", mdns_service_name); - f.mdns_service = Some( - Mdns::get()?.publish_service(&mdns_service_name, mdns::ServiceMode::Commissioned)?, - ); - Ok(f) - } - pub fn dummy() -> Result { Ok(Self { - node_id: 0, - fabric_id: 0, - vendor_id: 0, - key_pair: Box::new(KeyPairDummy::new()?), - root_ca: Cert::default(), - icac: Some(Cert::default()), - noc: Cert::default(), - ipk: KeySet::default(), - label: "".into(), - compressed_id: [0; COMPRESSED_FABRIC_ID_LEN], - mdns_service: None, + node_id, + fabric_id, + vendor_id, + key_pair, + root_ca, + icac, + noc, + ipk, + label: label.into(), + mdns_service_name, }) } @@ -195,164 +185,362 @@ impl Fabric { } } - fn rm_store(&self, index: usize, psm: &MutexGuard) { - psm.rm(fb_key!(index, ST_RCA)); - psm.rm(fb_key!(index, ST_ICA)); - psm.rm(fb_key!(index, ST_NOC)); - psm.rm(fb_key!(index, ST_IPK)); - psm.rm(fb_key!(index, ST_LBL)); - psm.rm(fb_key!(index, ST_PBKEY)); - psm.rm(fb_key!(index, ST_PRKEY)); - psm.rm(fb_key!(index, ST_VID)); - } + fn store(&self, index: usize, mut psm: T) -> Result<(), Error> + where + T: Psm, + { + let mut _kb = heapless::String::<32>::new(); - fn store(&self, index: usize, psm: &MutexGuard) -> Result<(), Error> { - let mut key = [0u8; MAX_CERT_TLV_LEN]; - let len = self.root_ca.as_tlv(&mut key)?; - psm.set_kv_slice(fb_key!(index, ST_RCA), &key[..len])?; + let mut buf = [0u8; MAX_CERT_TLV_LEN]; + let len = self.root_ca.as_tlv(&mut buf)?; + psm.set_kv_slice(fb_key!(index, ST_RCA, _kb), &buf[..len])?; let len = if let Some(icac) = &self.icac { - icac.as_tlv(&mut key)? + icac.as_tlv(&mut buf)? } else { 0 }; - psm.set_kv_slice(fb_key!(index, ST_ICA), &key[..len])?; + psm.set_kv_slice(fb_key!(index, ST_ICA, _kb), &buf[..len])?; - let len = self.noc.as_tlv(&mut key)?; - psm.set_kv_slice(fb_key!(index, ST_NOC), &key[..len])?; - psm.set_kv_slice(fb_key!(index, ST_IPK), self.ipk.epoch_key())?; - psm.set_kv_slice(fb_key!(index, ST_LBL), self.label.as_bytes())?; + let len = self.noc.as_tlv(&mut buf)?; + psm.set_kv_slice(fb_key!(index, ST_NOC, _kb), &buf[..len])?; + psm.set_kv_slice(fb_key!(index, ST_IPK, _kb), self.ipk.epoch_key())?; + psm.set_kv_slice(fb_key!(index, ST_LBL, _kb), self.label.as_bytes())?; - let mut key = [0_u8; crypto::EC_POINT_LEN_BYTES]; - let len = self.key_pair.get_public_key(&mut key)?; - let key = &key[..len]; - psm.set_kv_slice(fb_key!(index, ST_PBKEY), key)?; + let mut buf = [0_u8; crypto::EC_POINT_LEN_BYTES]; + let len = self.key_pair.get_public_key(&mut buf)?; + let key = &buf[..len]; + psm.set_kv_slice(fb_key!(index, ST_PBKEY, _kb), key)?; - let mut key = [0_u8; crypto::BIGNUM_LEN_BYTES]; - let len = self.key_pair.get_private_key(&mut key)?; - let key = &key[..len]; - psm.set_kv_slice(fb_key!(index, ST_PRKEY), key)?; + let mut buf = [0_u8; crypto::BIGNUM_LEN_BYTES]; + let len = self.key_pair.get_private_key(&mut buf)?; + let key = &buf[..len]; + psm.set_kv_slice(fb_key!(index, ST_PRKEY, _kb), key)?; - psm.set_kv_u64(fb_key!(index, ST_VID), self.vendor_id.into())?; + psm.set_kv_u64(fb_key!(index, ST_VID, _kb), self.vendor_id.into())?; Ok(()) } - fn load(index: usize, psm: &MutexGuard) -> Result { - let mut root_ca = Vec::new(); - psm.get_kv_slice(fb_key!(index, ST_RCA), &mut root_ca)?; - let root_ca = Cert::new(root_ca.as_slice())?; + fn load(index: usize, psm: T) -> Result + where + T: Psm, + { + let mut _kb = heapless::String::<32>::new(); - let mut icac = Vec::new(); - psm.get_kv_slice(fb_key!(index, ST_ICA), &mut icac)?; + let mut buf = [0u8; MAX_CERT_TLV_LEN]; + let root_ca = psm.get_kv_slice(fb_key!(index, ST_RCA, _kb), &mut buf)?; + let root_ca = Cert::new(root_ca)?; + + let icac = psm.get_kv_slice(fb_key!(index, ST_ICA, _kb), &mut buf)?; let icac = if !icac.is_empty() { - Some(Cert::new(icac.as_slice())?) + Some(Cert::new(icac)?) } else { None }; - let mut noc = Vec::new(); - psm.get_kv_slice(fb_key!(index, ST_NOC), &mut noc)?; - let noc = Cert::new(noc.as_slice())?; + let noc = psm.get_kv_slice(fb_key!(index, ST_NOC, _kb), &mut buf)?; + let noc = Cert::new(noc)?; - let mut ipk = Vec::new(); - psm.get_kv_slice(fb_key!(index, ST_IPK), &mut ipk)?; + let label = psm.get_kv_slice(fb_key!(index, ST_LBL, _kb), &mut buf)?; + let label: heapless::String<32> = core::str::from_utf8(label) + .map_err(|_| { + error!("Couldn't read label"); + Error::Invalid + })? + .into(); - let mut label = Vec::new(); - psm.get_kv_slice(fb_key!(index, ST_LBL), &mut label)?; - let label = String::from_utf8(label).map_err(|_| { - error!("Couldn't read label"); - Error::Invalid - })?; + let ipk = psm.get_kv_slice(fb_key!(index, ST_IPK, _kb), &mut buf)?; - let mut pub_key = Vec::new(); - psm.get_kv_slice(fb_key!(index, ST_PBKEY), &mut pub_key)?; - let mut priv_key = Vec::new(); - psm.get_kv_slice(fb_key!(index, ST_PRKEY), &mut priv_key)?; - let keypair = KeyPair::new_from_components(pub_key.as_slice(), priv_key.as_slice())?; + let mut buf = [0_u8; crypto::EC_POINT_LEN_BYTES]; + let pub_key = psm.get_kv_slice(fb_key!(index, ST_PBKEY, _kb), &mut buf)?; - let mut vendor_id = 0; - psm.get_kv_u64(fb_key!(index, ST_VID), &mut vendor_id)?; + let mut buf = [0_u8; crypto::BIGNUM_LEN_BYTES]; + let priv_key = psm.get_kv_slice(fb_key!(index, ST_PRKEY, _kb), &mut buf)?; + let keypair = KeyPair::new_from_components(pub_key, priv_key)?; - let f = Fabric::new( - keypair, - root_ca, - icac, - noc, - ipk.as_slice(), - vendor_id as u16, - ); - f.map(|mut f| { - f.label = label; - f - }) + let vendor_id = psm.get_kv_u64(fb_key!(index, ST_VID, _kb))?; + + Fabric::new(keypair, root_ca, icac, noc, ipk, vendor_id as u16, &label) + } + + fn remove(index: usize, mut psm: T) -> Result<(), Error> + where + T: Psm, + { + let mut _kb = heapless::String::<32>::new(); + + psm.remove(fb_key!(index, ST_RCA, _kb))?; + psm.remove(fb_key!(index, ST_ICA, _kb))?; + + psm.remove(fb_key!(index, ST_NOC, _kb))?; + + psm.remove(fb_key!(index, ST_LBL, _kb))?; + + psm.remove(fb_key!(index, ST_IPK, _kb))?; + + psm.remove(fb_key!(index, ST_PBKEY, _kb))?; + psm.remove(fb_key!(index, ST_PRKEY, _kb))?; + + psm.remove(fb_key!(index, ST_VID, _kb))?; + + Ok(()) + } + + #[cfg(feature = "nightly")] + async fn store_async(&self, index: usize, mut psm: T) -> Result<(), Error> + where + T: crate::persist::asynch::AsyncPsm, + { + let mut _kb = heapless::String::<32>::new(); + + let mut buf = [0u8; MAX_CERT_TLV_LEN]; + let len = self.root_ca.as_tlv(&mut buf)?; + psm.set_kv_slice(fb_key!(index, ST_RCA, _kb), &buf[..len]) + .await?; + + let len = if let Some(icac) = &self.icac { + icac.as_tlv(&mut buf)? + } else { + 0 + }; + psm.set_kv_slice(fb_key!(index, ST_ICA, _kb), &buf[..len]) + .await?; + + let len = self.noc.as_tlv(&mut buf)?; + psm.set_kv_slice(fb_key!(index, ST_NOC, _kb), &buf[..len]) + .await?; + psm.set_kv_slice(fb_key!(index, ST_IPK, _kb), self.ipk.epoch_key()) + .await?; + psm.set_kv_slice(fb_key!(index, ST_LBL, _kb), self.label.as_bytes()) + .await?; + + let mut buf = [0_u8; crypto::EC_POINT_LEN_BYTES]; + let len = self.key_pair.get_public_key(&mut buf)?; + let key = &buf[..len]; + psm.set_kv_slice(fb_key!(index, ST_PBKEY, _kb), key).await?; + + let mut buf = [0_u8; crypto::BIGNUM_LEN_BYTES]; + let len = self.key_pair.get_private_key(&mut buf)?; + let key = &buf[..len]; + psm.set_kv_slice(fb_key!(index, ST_PRKEY, _kb), key).await?; + + psm.set_kv_u64(fb_key!(index, ST_VID, _kb), self.vendor_id.into()) + .await?; + Ok(()) + } + + #[cfg(feature = "nightly")] + async fn load_async(index: usize, psm: T) -> Result + where + T: crate::persist::asynch::AsyncPsm, + { + let mut _kb = heapless::String::<32>::new(); + + let mut buf = [0u8; MAX_CERT_TLV_LEN]; + let root_ca = psm + .get_kv_slice(fb_key!(index, ST_RCA, _kb), &mut buf) + .await?; + let root_ca = Cert::new(root_ca)?; + + let icac = psm + .get_kv_slice(fb_key!(index, ST_ICA, _kb), &mut buf) + .await?; + let icac = if !icac.is_empty() { + Some(Cert::new(icac)?) + } else { + None + }; + + let noc = psm + .get_kv_slice(fb_key!(index, ST_NOC, _kb), &mut buf) + .await?; + let noc = Cert::new(noc)?; + + let label = psm + .get_kv_slice(fb_key!(index, ST_LBL, _kb), &mut buf) + .await?; + let label: heapless::String<32> = core::str::from_utf8(label) + .map_err(|_| { + error!("Couldn't read label"); + Error::Invalid + })? + .into(); + + let ipk = psm + .get_kv_slice(fb_key!(index, ST_IPK, _kb), &mut buf) + .await?; + + let mut buf = [0_u8; crypto::EC_POINT_LEN_BYTES]; + let pub_key = psm + .get_kv_slice(fb_key!(index, ST_PBKEY, _kb), &mut buf) + .await?; + + let mut buf = [0_u8; crypto::BIGNUM_LEN_BYTES]; + let priv_key = psm + .get_kv_slice(fb_key!(index, ST_PRKEY, _kb), &mut buf) + .await?; + let keypair = KeyPair::new_from_components(pub_key, priv_key)?; + + let vendor_id = psm.get_kv_u64(fb_key!(index, ST_VID, _kb)).await?; + + Fabric::new(keypair, root_ca, icac, noc, ipk, vendor_id as u16, &label) + } + + #[cfg(feature = "nightly")] + async fn remove_async(index: usize, mut psm: T) -> Result<(), Error> + where + T: crate::persist::asynch::AsyncPsm, + { + let mut _kb = heapless::String::<32>::new(); + + psm.remove(fb_key!(index, ST_RCA, _kb)).await?; + psm.remove(fb_key!(index, ST_ICA, _kb)).await?; + + psm.remove(fb_key!(index, ST_NOC, _kb)).await?; + + psm.remove(fb_key!(index, ST_LBL, _kb)).await?; + + psm.remove(fb_key!(index, ST_IPK, _kb)).await?; + + psm.remove(fb_key!(index, ST_PBKEY, _kb)).await?; + psm.remove(fb_key!(index, ST_PRKEY, _kb)).await?; + + psm.remove(fb_key!(index, ST_VID, _kb)).await?; + + Ok(()) } } pub const MAX_SUPPORTED_FABRICS: usize = 3; -#[derive(Default)] -pub struct FabricMgrInner { - // The outside world expects Fabric Index to be one more than the actual one - // since 0 is not allowed. Need to handle this cleanly somehow - pub fabrics: [Option; MAX_SUPPORTED_FABRICS], -} pub struct FabricMgr { - inner: RwLock, - psm: Arc>, + // The outside world expects Fabric Index to be one more than the actual one + // since 0 is not allowed. Need to handle this cleanly somehow + fabrics: [Option; MAX_SUPPORTED_FABRICS], + changed: bool, } impl FabricMgr { - pub fn new() -> Result { - let dummy_fabric = Fabric::dummy()?; - let mut mgr = FabricMgrInner::default(); - mgr.fabrics[0] = Some(dummy_fabric); - let mut fm = Self { - inner: RwLock::new(mgr), - psm: Psm::get()?, - }; - fm.load()?; - Ok(fm) - } + pub const fn new() -> Self { + const INIT: Option = None; - fn store(&self, index: usize, fabric: &Fabric) -> Result<(), Error> { - let psm = self.psm.lock().unwrap(); - fabric.store(index, &psm) - } - - fn load(&mut self) -> Result<(), Error> { - let mut mgr = self.inner.write()?; - let psm = self.psm.lock().unwrap(); - for i in 0..MAX_SUPPORTED_FABRICS { - let result = Fabric::load(i, &psm); - if let Ok(fabric) = result { - info!("Adding new fabric at index {}", i); - mgr.fabrics[i] = Some(fabric); - } + Self { + fabrics: [INIT; MAX_SUPPORTED_FABRICS], + changed: false, } + } + + pub fn store(&mut self, mut psm: T) -> Result<(), Error> + where + T: Psm, + { + if self.changed { + for i in 1..MAX_SUPPORTED_FABRICS { + if let Some(fabric) = self.fabrics[i].as_mut() { + info!("Storing fabric at index {}", i); + fabric.store(i, &mut psm)?; + } else { + let _ = Fabric::remove(i, &mut psm); + } + } + + self.changed = false; + } + Ok(()) } - pub fn add(&self, f: Fabric) -> Result { - let mut mgr = self.inner.write()?; - let index = mgr + pub fn load(&mut self, mut psm: T, mdns_mgr: &mut MdnsMgr) -> Result<(), Error> + where + T: Psm, + { + for i in 1..MAX_SUPPORTED_FABRICS { + let result = Fabric::load(i, &mut psm); + if let Ok(fabric) = result { + info!("Adding new fabric at index {}", i); + self.fabrics[i] = Some(fabric); + mdns_mgr.publish_service( + &self.fabrics[i].as_ref().unwrap().mdns_service_name, + ServiceMode::Commissioned, + )?; + } else { + self.fabrics[i] = None; + } + } + + self.changed = false; + + Ok(()) + } + + #[cfg(feature = "nightly")] + pub async fn store_async(&mut self, mut psm: T) -> Result<(), Error> + where + T: crate::persist::asynch::AsyncPsm, + { + if self.changed { + for i in 1..MAX_SUPPORTED_FABRICS { + if let Some(fabric) = self.fabrics[i].as_mut() { + info!("Storing fabric at index {}", i); + fabric.store_async(i, &mut psm).await?; + } else { + let _ = Fabric::remove_async(i, &mut psm).await; + } + } + + self.changed = false; + } + + Ok(()) + } + + #[cfg(feature = "nightly")] + pub async fn load_async(&mut self, mut psm: T, mdns_mgr: &mut MdnsMgr) -> Result<(), Error> + where + T: crate::persist::asynch::AsyncPsm, + { + for i in 1..MAX_SUPPORTED_FABRICS { + let result = Fabric::load_async(i, &mut psm).await; + if let Ok(fabric) = result { + info!("Adding new fabric at index {}", i); + self.fabrics[i] = Some(fabric); + mdns_mgr.publish_service( + &self.fabrics[i].as_ref().unwrap().mdns_service_name, + ServiceMode::Commissioned, + )?; + } else { + self.fabrics[i] = None; + } + } + + self.changed = false; + + Ok(()) + } + + pub fn add(&mut self, f: Fabric, mdns_mgr: &mut MdnsMgr) -> Result { + let index = self .fabrics .iter() + .skip(1) .position(|f| f.is_none()) .ok_or(Error::NoSpace)?; - self.store(index, &f)?; + self.fabrics[index] = Some(f); + mdns_mgr.publish_service( + &self.fabrics[index].as_ref().unwrap().mdns_service_name, + ServiceMode::Commissioned, + )?; + + self.changed = true; - mgr.fabrics[index] = Some(f); Ok(index as u8) } - pub fn remove(&self, fab_idx: u8) -> Result<(), Error> { - let fab_idx = fab_idx as usize; - let mut mgr = self.inner.write().unwrap(); - let psm = self.psm.lock().unwrap(); - if let Some(f) = &mgr.fabrics[fab_idx] { - f.rm_store(fab_idx, &psm); - mgr.fabrics[fab_idx] = None; + pub fn remove(&mut self, fab_idx: u8, mdns_mgr: &mut MdnsMgr) -> Result<(), Error> { + if let Some(f) = self.fabrics[fab_idx as usize].take() { + mdns_mgr.unpublish_service(&f.mdns_service_name, ServiceMode::Commissioned)?; + self.changed = true; Ok(()) } else { Err(Error::NotFound) @@ -360,9 +548,8 @@ impl FabricMgr { } pub fn match_dest_id(&self, random: &[u8], target: &[u8]) -> Result { - let mgr = self.inner.read()?; - for i in 0..MAX_SUPPORTED_FABRICS { - if let Some(fabric) = &mgr.fabrics[i] { + for i in 1..MAX_SUPPORTED_FABRICS { + if let Some(fabric) = &self.fabrics[i] { if fabric.match_dest_id(random, target).is_ok() { return Ok(i); } @@ -371,17 +558,13 @@ impl FabricMgr { Err(Error::NotFound) } - pub fn get_fabric<'ret, 'me: 'ret>( - &'me self, - idx: usize, - ) -> Result>, Error> { - Ok(RwLockReadGuardRef::new(self.inner.read()?).map(|fm| &fm.fabrics[idx])) + pub fn get_fabric(&self, idx: usize) -> Result, Error> { + Ok(self.fabrics[idx].as_ref()) } pub fn is_empty(&self) -> bool { - let mgr = self.inner.read().unwrap(); for i in 1..MAX_SUPPORTED_FABRICS { - if mgr.fabrics[i].is_some() { + if self.fabrics[i].is_some() { return false; } } @@ -389,10 +572,9 @@ impl FabricMgr { } pub fn used_count(&self) -> usize { - let mgr = self.inner.read().unwrap(); let mut count = 0; for i in 1..MAX_SUPPORTED_FABRICS { - if mgr.fabrics[i].is_some() { + if self.fabrics[i].is_some() { count += 1; } } @@ -402,37 +584,30 @@ impl FabricMgr { // Parameters to T are the Fabric and its Fabric Index pub fn for_each(&self, mut f: T) -> Result<(), Error> where - T: FnMut(&Fabric, u8), + T: FnMut(&Fabric, u8) -> Result<(), Error>, { - let mgr = self.inner.read().unwrap(); for i in 1..MAX_SUPPORTED_FABRICS { - if let Some(fabric) = &mgr.fabrics[i] { - f(fabric, i as u8) + if let Some(fabric) = &self.fabrics[i] { + f(fabric, i as u8)?; } } Ok(()) } - pub fn set_label(&self, index: u8, label: String) -> Result<(), Error> { + pub fn set_label(&mut self, index: u8, label: &str) -> Result<(), Error> { let index = index as usize; - let mut mgr = self.inner.write()?; if !label.is_empty() { for i in 1..MAX_SUPPORTED_FABRICS { - if let Some(fabric) = &mgr.fabrics[i] { + if let Some(fabric) = &self.fabrics[i] { if fabric.label == label { return Err(Error::Invalid); } } } } - if let Some(fabric) = &mut mgr.fabrics[index] { - let old = fabric.label.clone(); - fabric.label = label; - let psm = self.psm.lock().unwrap(); - if fabric.store(index, &psm).is_err() { - fabric.label = old; - return Err(Error::StdIoError); - } + if let Some(fabric) = &mut self.fabrics[index] { + fabric.label = label.into(); + self.changed = true; } Ok(()) } diff --git a/matter/src/group_keys.rs b/matter/src/group_keys.rs index 73c40e5..c4dfaaf 100644 --- a/matter/src/group_keys.rs +++ b/matter/src/group_keys.rs @@ -15,10 +15,13 @@ * limitations under the License. */ -use std::sync::{Arc, Mutex, Once}; +use alloc::sync::Arc; +use std::sync::{Mutex, Once}; use crate::{crypto, error::Error}; +extern crate alloc; + // This is just makeshift implementation for now, not used anywhere pub struct GroupKeys {} diff --git a/matter/src/interaction_model/command.rs b/matter/src/interaction_model/command.rs deleted file mode 100644 index 323c093..0000000 --- a/matter/src/interaction_model/command.rs +++ /dev/null @@ -1,88 +0,0 @@ -/* - * - * Copyright (c) 2020-2022 Project CHIP Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -use super::core::IMStatusCode; -use super::core::OpCode; -use super::messages::ib; -use super::messages::msg; -use super::messages::msg::InvReq; -use super::InteractionModel; -use super::Transaction; -use crate::{ - error::*, - tlv::{get_root_node_struct, print_tlv_list, FromTLV, TLVElement, TLVWriter, TagType}, - transport::{packet::Packet, proto_demux::ResponseRequired}, -}; -use log::error; - -#[macro_export] -macro_rules! cmd_enter { - ($e:expr) => {{ - use colored::Colorize; - info! {"{} {}", "Handling Command".cyan(), $e.cyan()} - }}; -} - -pub struct CommandReq<'a, 'b, 'c, 'd, 'e> { - pub cmd: ib::CmdPath, - pub data: TLVElement<'a>, - pub resp: &'a mut TLVWriter<'b, 'c>, - pub trans: &'a mut Transaction<'d, 'e>, -} - -impl InteractionModel { - pub fn handle_invoke_req( - &mut self, - trans: &mut Transaction, - rx_buf: &[u8], - proto_tx: &mut Packet, - ) -> Result { - if InteractionModel::req_timeout_handled(trans, proto_tx)? { - return Ok(ResponseRequired::Yes); - } - - proto_tx.set_proto_opcode(OpCode::InvokeResponse as u8); - let mut tw = TLVWriter::new(proto_tx.get_writebuf()?); - let root = get_root_node_struct(rx_buf)?; - let inv_req = InvReq::from_tlv(&root)?; - - let timed_tx = trans.get_timeout().map(|_| true); - let timed_request = inv_req.timed_request.filter(|a| *a); - // Either both should be None, or both should be Some(true) - if timed_tx != timed_request { - InteractionModel::create_status_response(proto_tx, IMStatusCode::TimedRequestMisMatch)?; - return Ok(ResponseRequired::Yes); - } - - tw.start_struct(TagType::Anonymous)?; - // Suppress Response -> TODO: Need to revisit this for cases where we send a command back - tw.bool( - TagType::Context(msg::InvRespTag::SupressResponse as u8), - false, - )?; - - self.consumer - .consume_invoke_cmd(&inv_req, trans, &mut tw) - .map_err(|e| { - error!("Error in handling command: {:?}", e); - print_tlv_list(rx_buf); - e - })?; - tw.end_container()?; - Ok(ResponseRequired::Yes) - } -} diff --git a/matter/src/interaction_model/core.rs b/matter/src/interaction_model/core.rs index 1d548eb..8d8b4fb 100644 --- a/matter/src/interaction_model/core.rs +++ b/matter/src/interaction_model/core.rs @@ -15,212 +15,27 @@ * limitations under the License. */ -use std::time::{Duration, SystemTime}; +use core::time::Duration; use crate::{ + data_model::core::DataHandler, error::*, - interaction_model::messages::msg::StatusResp, - tlv::{self, get_root_node_struct, FromTLV, TLVElement, TLVWriter, TagType, ToTLV}, - transport::{ - exchange::Exchange, - packet::Packet, - proto_demux::{self, ProtoCtx, ResponseRequired}, - session::SessionHandle, - }, + tlv::{get_root_node_struct, print_tlv_list, FromTLV, TLVElement, TLVWriter, TagType, ToTLV}, + transport::{exchange::ExchangeCtx, packet::Packet, proto_ctx::ProtoCtx, session::Session}, }; use colored::Colorize; use log::{error, info}; use num; use num_derive::FromPrimitive; -use super::InteractionModel; -use super::Transaction; -use super::TransactionState; -use super::{messages::msg::TimedReq, InteractionConsumer}; +use super::messages::msg::{self, InvReq, ReadReq, StatusResp, TimedReq, WriteReq}; -/* Handle messages related to the Interation Model - */ - -/* Interaction Model ID as per the Matter Spec */ -const PROTO_ID_INTERACTION_MODEL: usize = 0x01; - -#[derive(FromPrimitive, Debug, Copy, Clone, PartialEq)] -pub enum OpCode { - Reserved = 0, - StatusResponse = 1, - ReadRequest = 2, - SubscribeRequest = 3, - SubscriptResponse = 4, - ReportData = 5, - WriteRequest = 6, - WriteResponse = 7, - InvokeRequest = 8, - InvokeResponse = 9, - TimedRequest = 10, -} - -impl<'a, 'b> Transaction<'a, 'b> { - pub fn new(session: &'a mut SessionHandle<'b>, exch: &'a mut Exchange) -> Self { - Self { - state: TransactionState::Ongoing, - session, - exch, - } - } - - /// Terminates the transaction, no communication (even ACKs) happens hence forth - pub fn terminate(&mut self) { - self.state = TransactionState::Terminate - } - - pub fn is_terminate(&self) -> bool { - self.state == TransactionState::Terminate - } - - /// Marks the transaction as completed from the application's perspective - pub fn complete(&mut self) { - self.state = TransactionState::Complete - } - - pub fn is_complete(&self) -> bool { - self.state == TransactionState::Complete - } - - pub fn set_timeout(&mut self, timeout: u64) { - self.exch - .set_data_time(SystemTime::now().checked_add(Duration::from_millis(timeout))); - } - - pub fn get_timeout(&mut self) -> Option { - self.exch.get_data_time() - } - - pub fn has_timed_out(&self) -> bool { - if let Some(timeout) = self.exch.get_data_time() { - if SystemTime::now() > timeout { - return true; - } - } - false - } -} - -impl InteractionModel { - pub fn new(consumer: Box) -> InteractionModel { - InteractionModel { consumer } - } - - pub fn handle_subscribe_req( - &mut self, - trans: &mut Transaction, - rx_buf: &[u8], - proto_tx: &mut Packet, - ) -> Result { - let mut tw = TLVWriter::new(proto_tx.get_writebuf()?); - let (opcode, resp) = self.consumer.consume_subscribe(rx_buf, trans, &mut tw)?; - proto_tx.set_proto_opcode(opcode as u8); - Ok(resp) - } - - pub fn handle_status_resp( - &mut self, - trans: &mut Transaction, - rx_buf: &[u8], - proto_tx: &mut Packet, - ) -> Result { - let mut tw = TLVWriter::new(proto_tx.get_writebuf()?); - let root = get_root_node_struct(rx_buf)?; - let req = StatusResp::from_tlv(&root)?; - let (opcode, resp) = self.consumer.consume_status_report(&req, trans, &mut tw)?; - proto_tx.set_proto_opcode(opcode as u8); - Ok(resp) - } - - pub fn handle_timed_req( - &mut self, - trans: &mut Transaction, - rx_buf: &[u8], - proto_tx: &mut Packet, - ) -> Result { - proto_tx.set_proto_opcode(OpCode::StatusResponse as u8); - - let root = get_root_node_struct(rx_buf)?; - let req = TimedReq::from_tlv(&root)?; - trans.set_timeout(req.timeout.into()); - - let status = StatusResp { - status: IMStatusCode::Success, - }; - let mut tw = TLVWriter::new(proto_tx.get_writebuf()?); - let _ = status.to_tlv(&mut tw, TagType::Anonymous); - Ok(ResponseRequired::Yes) - } - - /// Handle Request Timeouts - /// This API checks if a request was a timed request, and if so, and if the timeout has - /// expired, it will generate the appropriate response as expected - pub(super) fn req_timeout_handled( - trans: &mut Transaction, - proto_tx: &mut Packet, - ) -> Result { - if trans.has_timed_out() { - trans.complete(); - InteractionModel::create_status_response(proto_tx, IMStatusCode::Timeout)?; - Ok(true) - } else { - Ok(false) - } - } - - pub(super) fn create_status_response( - proto_tx: &mut Packet, - status: IMStatusCode, - ) -> Result<(), Error> { - proto_tx.set_proto_opcode(OpCode::StatusResponse as u8); - let mut tw = TLVWriter::new(proto_tx.get_writebuf()?); - let status = StatusResp { status }; - status.to_tlv(&mut tw, TagType::Anonymous) - } -} - -impl proto_demux::HandleProto for InteractionModel { - fn handle_proto_id(&mut self, ctx: &mut ProtoCtx) -> Result { - let mut trans = Transaction::new(&mut ctx.exch_ctx.sess, ctx.exch_ctx.exch); - let proto_opcode: OpCode = - num::FromPrimitive::from_u8(ctx.rx.get_proto_opcode()).ok_or(Error::Invalid)?; - ctx.tx.set_proto_id(PROTO_ID_INTERACTION_MODEL as u16); - - let buf = ctx.rx.as_borrow_slice(); - info!("{} {:?}", "Received command".cyan(), proto_opcode); - tlv::print_tlv_list(buf); - let result = match proto_opcode { - OpCode::InvokeRequest => self.handle_invoke_req(&mut trans, buf, &mut ctx.tx)?, - OpCode::ReadRequest => self.handle_read_req(&mut trans, buf, &mut ctx.tx)?, - OpCode::WriteRequest => self.handle_write_req(&mut trans, buf, &mut ctx.tx)?, - OpCode::TimedRequest => self.handle_timed_req(&mut trans, buf, &mut ctx.tx)?, - OpCode::SubscribeRequest => self.handle_subscribe_req(&mut trans, buf, &mut ctx.tx)?, - OpCode::StatusResponse => self.handle_status_resp(&mut trans, buf, &mut ctx.tx)?, - _ => { - error!("Opcode Not Handled: {:?}", proto_opcode); - return Err(Error::InvalidOpcode); - } - }; - - if result == ResponseRequired::Yes { - info!("Sending response"); - tlv::print_tlv_list(ctx.tx.as_borrow_slice()); - } - if trans.is_terminate() { - ctx.exch_ctx.exch.terminate(); - } else if trans.is_complete() { - ctx.exch_ctx.exch.close(); - } - Ok(result) - } - - fn get_proto_id(&self) -> usize { - PROTO_ID_INTERACTION_MODEL - } +#[macro_export] +macro_rules! cmd_enter { + ($e:expr) => {{ + use colored::Colorize; + info! {"{} {}", "Handling Command".cyan(), $e.cyan()} + }}; } #[derive(FromPrimitive, Debug, Clone, Copy, PartialEq)] @@ -260,6 +75,12 @@ impl From for IMStatusCode { Error::ClusterNotFound => IMStatusCode::UnsupportedCluster, Error::AttributeNotFound => IMStatusCode::UnsupportedAttribute, Error::CommandNotFound => IMStatusCode::UnsupportedCommand, + Error::InvalidAction => IMStatusCode::InvalidAction, + Error::InvalidCommand => IMStatusCode::InvalidCommand, + Error::UnsupportedAccess => IMStatusCode::UnsupportedAccess, + Error::Busy => IMStatusCode::Busy, + Error::DataVersionMismatch => IMStatusCode::DataVersionMismatch, + Error::ResourceExhausted => IMStatusCode::ResourceExhausted, _ => IMStatusCode::Failure, } } @@ -276,3 +97,418 @@ impl ToTLV for IMStatusCode { tw.u16(tag_type, *self as u16) } } + +#[derive(FromPrimitive, Debug, Copy, Clone, PartialEq)] +pub enum OpCode { + Reserved = 0, + StatusResponse = 1, + ReadRequest = 2, + SubscribeRequest = 3, + SubscriptResponse = 4, + ReportData = 5, + WriteRequest = 6, + WriteResponse = 7, + InvokeRequest = 8, + InvokeResponse = 9, + TimedRequest = 10, +} + +#[derive(PartialEq)] +pub enum TransactionState { + Ongoing, + Complete, + Terminate, +} +pub struct Transaction<'a, 'b> { + state: TransactionState, + ctx: &'a mut ExchangeCtx<'b>, +} + +impl<'a, 'b> Transaction<'a, 'b> { + pub fn new(ctx: &'a mut ExchangeCtx<'b>) -> Self { + Self { + state: TransactionState::Ongoing, + ctx, + } + } + + pub fn session(&self) -> &Session { + self.ctx.sess.session() + } + + pub fn session_mut(&mut self) -> &mut Session { + self.ctx.sess.session_mut() + } + + /// Terminates the transaction, no communication (even ACKs) happens hence forth + pub fn terminate(&mut self) { + self.state = TransactionState::Terminate + } + + pub fn is_terminate(&self) -> bool { + self.state == TransactionState::Terminate + } + /// Marks the transaction as completed from the application's perspective + pub fn complete(&mut self) { + self.state = TransactionState::Complete + } + + pub fn is_complete(&self) -> bool { + self.state == TransactionState::Complete + } + + pub fn set_timeout(&mut self, timeout: u64) { + let now = (self.ctx.epoch)(); + + self.ctx + .exch + .set_data_time(now.checked_add(Duration::from_millis(timeout))); + } + + pub fn get_timeout(&mut self) -> Option { + self.ctx.exch.get_data_time() + } + + pub fn has_timed_out(&self) -> bool { + if let Some(timeout) = self.ctx.exch.get_data_time() { + if (self.ctx.epoch)() > timeout { + return true; + } + } + false + } +} + +/* Interaction Model ID as per the Matter Spec */ +const PROTO_ID_INTERACTION_MODEL: usize = 0x01; + +pub enum Interaction<'a> { + Read(ReadReq<'a>), + Write(WriteReq<'a>), + Invoke(InvReq<'a>), + Timed(TimedReq), +} + +impl<'a> Interaction<'a> { + pub fn new(rx: &'a Packet) -> Result { + let opcode: OpCode = + num::FromPrimitive::from_u8(rx.get_proto_opcode()).ok_or(Error::Invalid)?; + + let rx_data = rx.as_slice(); + + info!("{} {:?}", "Received command".cyan(), opcode); + print_tlv_list(rx_data); + + match opcode { + OpCode::ReadRequest => Ok(Self::Read(ReadReq::from_tlv(&get_root_node_struct( + rx_data, + )?)?)), + OpCode::WriteRequest => Ok(Self::Write(WriteReq::from_tlv(&get_root_node_struct( + rx_data, + )?)?)), + OpCode::InvokeRequest => Ok(Self::Invoke(InvReq::from_tlv(&get_root_node_struct( + rx_data, + )?)?)), + OpCode::TimedRequest => Ok(Self::Timed(TimedReq::from_tlv(&get_root_node_struct( + rx_data, + )?)?)), + // TODO + // OpCode::SubscribeRequest => self.handle_subscribe_req(&mut trans, buf, &mut ctx.tx)?, + // OpCode::StatusResponse => self.handle_status_resp(&mut trans, buf, &mut ctx.tx)?, + _ => { + error!("Opcode Not Handled: {:?}", opcode); + Err(Error::InvalidOpcode) + } + } + } + + pub fn initiate_tx( + &self, + tx: &mut Packet, + transaction: &mut Transaction, + ) -> Result { + let reply = match self { + Self::Read(request) => { + tx.set_proto_id(PROTO_ID_INTERACTION_MODEL as u16); + tx.set_proto_opcode(OpCode::ReportData as u8); + + let mut tw = TLVWriter::new(tx.get_writebuf()?); + + tw.start_struct(TagType::Anonymous)?; + + if request.attr_requests.is_some() { + tw.start_array(TagType::Context(msg::ReportDataTag::AttributeReports as u8))?; + } + + false + } + Interaction::Write(_) => { + if transaction.has_timed_out() { + Self::create_status_response(tx, IMStatusCode::Timeout)?; + + transaction.complete(); + transaction.ctx.exch.close(); + + true + } else { + tx.set_proto_id(PROTO_ID_INTERACTION_MODEL as u16); + tx.set_proto_opcode(OpCode::WriteResponse as u8); + + let mut tw = TLVWriter::new(tx.get_writebuf()?); + + tw.start_struct(TagType::Anonymous)?; + tw.start_array(TagType::Context(msg::WriteRespTag::WriteResponses as u8))?; + + false + } + } + Interaction::Invoke(request) => { + if transaction.has_timed_out() { + Self::create_status_response(tx, IMStatusCode::Timeout)?; + + transaction.complete(); + transaction.ctx.exch.close(); + + true + } else { + let timed_tx = transaction.get_timeout().map(|_| true); + let timed_request = request.timed_request.filter(|a| *a); + + // Either both should be None, or both should be Some(true) + if timed_tx != timed_request { + Self::create_status_response(tx, IMStatusCode::TimedRequestMisMatch)?; + + true + } else { + tx.set_proto_id(PROTO_ID_INTERACTION_MODEL as u16); + tx.set_proto_opcode(OpCode::InvokeResponse as u8); + + let mut tw = TLVWriter::new(tx.get_writebuf()?); + + tw.start_struct(TagType::Anonymous)?; + + // Suppress Response -> TODO: Need to revisit this for cases where we send a command back + tw.bool( + TagType::Context(msg::InvRespTag::SupressResponse as u8), + false, + )?; + + if request.inv_requests.is_some() { + tw.start_array(TagType::Context( + msg::InvRespTag::InvokeResponses as u8, + ))?; + } + + false + } + } + } + Interaction::Timed(request) => { + tx.set_proto_id(PROTO_ID_INTERACTION_MODEL as u16); + tx.set_proto_opcode(OpCode::StatusResponse as u8); + + let mut tw = TLVWriter::new(tx.get_writebuf()?); + + transaction.set_timeout(request.timeout.into()); + + let status = StatusResp { + status: IMStatusCode::Success, + }; + + status.to_tlv(&mut tw, TagType::Anonymous)?; + + true + } + }; + + Ok(!reply) + } + + pub fn complete_tx( + &self, + tx: &mut Packet, + transaction: &mut Transaction, + ) -> Result { + let reply = match self { + Self::Read(request) => { + let mut tw = TLVWriter::new(tx.get_writebuf()?); + + if request.attr_requests.is_some() { + tw.end_container()?; + } + + // Suppress response always true for read interaction + tw.bool( + TagType::Context(msg::ReportDataTag::SupressResponse as u8), + true, + )?; + + tw.end_container()?; + + transaction.complete(); + + true + } + Self::Write(request) => { + let suppress = request.supress_response.unwrap_or_default(); + + let mut tw = TLVWriter::new(tx.get_writebuf()?); + + tw.end_container()?; + tw.end_container()?; + + transaction.complete(); + + if suppress { + error!("Supress response is set, is this the expected handling?"); + false + } else { + true + } + } + Self::Invoke(request) => { + let mut tw = TLVWriter::new(tx.get_writebuf()?); + + if request.inv_requests.is_some() { + tw.end_container()?; + } + + tw.end_container()?; + + true + } + Self::Timed(_) => false, + }; + + if reply { + info!("Sending response"); + print_tlv_list(tx.as_slice()); + } + + if transaction.is_terminate() { + transaction.ctx.exch.terminate(); + } else if transaction.is_complete() { + transaction.ctx.exch.close(); + } + + Ok(true) + } + + fn create_status_response(tx: &mut Packet, status: IMStatusCode) -> Result<(), Error> { + tx.set_proto_id(PROTO_ID_INTERACTION_MODEL as u16); + tx.set_proto_opcode(OpCode::StatusResponse as u8); + + let mut tw = TLVWriter::new(tx.get_writebuf()?); + + let status = StatusResp { status }; + status.to_tlv(&mut tw, TagType::Anonymous) + } +} + +pub trait InteractionHandler { + fn handle<'a>(&mut self, ctx: &'a mut ProtoCtx) -> Result, Error>; +} + +impl InteractionHandler for &mut T +where + T: InteractionHandler, +{ + fn handle<'a>(&mut self, ctx: &'a mut ProtoCtx) -> Result, Error> { + (**self).handle(ctx) + } +} + +pub struct InteractionModel(pub T); + +impl InteractionModel +where + T: DataHandler, +{ + pub fn handle<'a>(&mut self, ctx: &'a mut ProtoCtx) -> Result, Error> { + let interaction = Interaction::new(ctx.rx)?; + let mut transaction = Transaction::new(&mut ctx.exch_ctx); + + let reply = if interaction.initiate_tx(ctx.tx, &mut transaction)? { + self.0.handle(&interaction, ctx.tx, &mut transaction)?; + interaction.complete_tx(ctx.tx, &mut transaction)? + } else { + true + }; + + Ok(reply.then_some(ctx.tx.as_slice())) + } +} + +#[cfg(feature = "nightly")] +impl InteractionModel +where + T: crate::data_model::core::asynch::AsyncDataHandler, +{ + pub async fn handle_async<'a>( + &mut self, + ctx: &'a mut ProtoCtx<'_, '_>, + ) -> Result, Error> { + let interaction = Interaction::new(ctx.rx)?; + let mut transaction = Transaction::new(&mut ctx.exch_ctx); + + let reply = if interaction.initiate_tx(ctx.tx, &mut transaction)? { + self.0 + .handle(&interaction, ctx.tx, &mut transaction) + .await?; + interaction.complete_tx(ctx.tx, &mut transaction)? + } else { + true + }; + + Ok(reply.then_some(ctx.tx.as_slice())) + } +} + +impl InteractionHandler for InteractionModel +where + T: DataHandler, +{ + fn handle<'a>(&mut self, ctx: &'a mut ProtoCtx) -> Result, Error> { + InteractionModel::handle(self, ctx) + } +} + +#[cfg(feature = "nightly")] +pub mod asynch { + use crate::{ + data_model::core::asynch::AsyncDataHandler, error::Error, transport::proto_ctx::ProtoCtx, + }; + + use super::InteractionModel; + + pub trait AsyncInteractionHandler { + async fn handle<'a>( + &mut self, + ctx: &'a mut ProtoCtx<'_, '_>, + ) -> Result, Error>; + } + + impl AsyncInteractionHandler for &mut T + where + T: AsyncInteractionHandler, + { + async fn handle<'a>( + &mut self, + ctx: &'a mut ProtoCtx<'_, '_>, + ) -> Result, Error> { + (**self).handle(ctx).await + } + } + + impl AsyncInteractionHandler for InteractionModel + where + T: AsyncDataHandler, + { + async fn handle<'a>( + &mut self, + ctx: &'a mut ProtoCtx<'_, '_>, + ) -> Result, Error> { + InteractionModel::handle_async(self, ctx).await + } + } +} diff --git a/matter/src/interaction_model/messages.rs b/matter/src/interaction_model/messages.rs index aac30f7..19e29f1 100644 --- a/matter/src/interaction_model/messages.rs +++ b/matter/src/interaction_model/messages.rs @@ -160,13 +160,6 @@ pub mod msg { pub inv_requests: Option>>, } - // This enum is helpful when we are constructing the response - // step by step in incremental manner - pub enum InvRespTag { - SupressResponse = 0, - InvokeResponses = 1, - } - #[derive(FromTLV, ToTLV, Debug)] #[tlvargs(lifetime = "'a")] pub struct InvResp<'a> { @@ -174,7 +167,14 @@ pub mod msg { pub inv_responses: Option>>, } - #[derive(Default, ToTLV, FromTLV)] + // This enum is helpful when we are constructing the response + // step by step in incremental manner + pub enum InvRespTag { + SupressResponse = 0, + InvokeResponses = 1, + } + + #[derive(Default, ToTLV, FromTLV, Debug)] #[tlvargs(lifetime = "'a")] pub struct ReadReq<'a> { pub attr_requests: Option>, @@ -198,17 +198,17 @@ pub mod msg { } } - #[derive(ToTLV, FromTLV)] - #[tlvargs(lifetime = "'b")] - pub struct WriteReq<'a, 'b> { + #[derive(FromTLV, ToTLV, Debug)] + #[tlvargs(lifetime = "'a")] + pub struct WriteReq<'a> { pub supress_response: Option, timed_request: Option, - pub write_requests: TLVArray<'a, AttrData<'b>>, + pub write_requests: TLVArray<'a, AttrData<'a>>, more_chunked: Option, } - impl<'a, 'b> WriteReq<'a, 'b> { - pub fn new(supress_response: bool, write_requests: &'a [AttrData<'b>]) -> Self { + impl<'a> WriteReq<'a> { + pub fn new(supress_response: bool, write_requests: &'a [AttrData<'a>]) -> Self { let mut w = Self { supress_response: None, write_requests: TLVArray::new(write_requests), @@ -223,7 +223,7 @@ pub mod msg { } // Report Data - #[derive(FromTLV, ToTLV)] + #[derive(FromTLV, ToTLV, Debug)] #[tlvargs(lifetime = "'a")] pub struct ReportDataMsg<'a> { pub subscription_id: Option, @@ -243,7 +243,7 @@ pub mod msg { } // Write Response - #[derive(ToTLV, FromTLV)] + #[derive(ToTLV, FromTLV, Debug)] #[tlvargs(lifetime = "'a")] pub struct WriteResp<'a> { pub write_responses: TLVArray<'a, AttrStatus>, @@ -255,10 +255,10 @@ pub mod msg { } pub mod ib { - use std::fmt::Debug; + use core::fmt::Debug; use crate::{ - data_model::objects::{AttrDetails, AttrId, ClusterId, EncodeValue, EndptId}, + data_model::objects::{AttrDetails, AttrId, ClusterId, CmdId, EncodeValue, EndptId}, error::Error, interaction_model::core::IMStatusCode, tlv::{FromTLV, Nullable, TLVElement, TLVWriter, TagType, ToTLV}, @@ -276,18 +276,6 @@ pub mod ib { } impl<'a> InvResp<'a> { - pub fn cmd_new( - endpoint: EndptId, - cluster: ClusterId, - cmd: u16, - data: EncodeValue<'a>, - ) -> Self { - Self::Cmd(CmdData::new( - CmdPath::new(Some(endpoint), Some(cluster), Some(cmd)), - data, - )) - } - pub fn status_new(cmd_path: CmdPath, status: IMStatusCode, cluster_status: u16) -> Self { Self::Status(CmdStatus { path: cmd_path, @@ -296,6 +284,23 @@ pub mod ib { } } + impl<'a> From> for InvResp<'a> { + fn from(value: CmdData<'a>) -> Self { + Self::Cmd(value) + } + } + + pub enum InvRespTag { + Cmd = 0, + Status = 1, + } + + impl<'a> From for InvResp<'a> { + fn from(value: CmdStatus) -> Self { + Self::Status(value) + } + } + #[derive(FromTLV, ToTLV, Copy, Clone, PartialEq, Debug)] pub struct CmdStatus { path: CmdPath, @@ -327,6 +332,11 @@ pub mod ib { } } + pub enum CmdDataTag { + Path = 0, + Data = 1, + } + // Status #[derive(Debug, Clone, Copy, PartialEq, FromTLV, ToTLV)] pub struct Status { @@ -352,10 +362,6 @@ pub mod ib { } impl<'a> AttrResp<'a> { - pub fn new(data_ver: u32, path: &AttrPath, data: EncodeValue<'a>) -> Self { - AttrResp::Data(AttrData::new(Some(data_ver), *path, data)) - } - pub fn unwrap_data(self) -> AttrData<'a> { match self { AttrResp::Data(d) => d, @@ -366,6 +372,23 @@ pub mod ib { } } + impl<'a> From> for AttrResp<'a> { + fn from(value: AttrData<'a>) -> Self { + Self::Data(value) + } + } + + impl<'a> From for AttrResp<'a> { + fn from(value: AttrStatus) -> Self { + Self::Status(value) + } + } + + pub enum AttrRespTag { + Status = 0, + Data = 1, + } + // Attribute Data #[derive(Clone, Copy, PartialEq, FromTLV, ToTLV, Debug)] #[tlvargs(lifetime = "'a")] @@ -385,6 +408,12 @@ pub mod ib { } } + pub enum AttrDataTag { + DataVer = 0, + Path = 1, + Data = 2, + } + #[derive(Debug)] /// Operations on an Interaction Model List pub enum ListOperation { @@ -399,13 +428,9 @@ pub mod ib { } /// Attribute Lists in Attribute Data are special. Infer the correct meaning using this function - pub fn attr_list_write( - attr: &AttrDetails, - data: &TLVElement, - mut f: F, - ) -> Result<(), IMStatusCode> + pub fn attr_list_write(attr: &AttrDetails, data: &TLVElement, mut f: F) -> Result<(), Error> where - F: FnMut(ListOperation, &TLVElement) -> Result<(), IMStatusCode>, + F: FnMut(ListOperation, &TLVElement) -> Result<(), Error>, { if let Some(Nullable::NotNull(index)) = attr.list_index { // If list index is valid, @@ -499,13 +524,13 @@ pub mod ib { pub fn new( endpoint: Option, cluster: Option, - command: Option, + command: Option, ) -> Self { Self { path: GenericPath { endpoint, cluster, - leaf: command.map(|a| a as u32), + leaf: command, }, } } @@ -532,20 +557,20 @@ pub mod ib { } } - #[derive(FromTLV, ToTLV, Copy, Clone)] + #[derive(FromTLV, ToTLV, Copy, Clone, Debug)] pub struct ClusterPath { pub node: Option, pub endpoint: EndptId, pub cluster: ClusterId, } - #[derive(FromTLV, ToTLV, Copy, Clone)] + #[derive(FromTLV, ToTLV, Copy, Clone, Debug)] pub struct DataVersionFilter { pub path: ClusterPath, pub data_ver: u32, } - #[derive(FromTLV, ToTLV, Copy, Clone)] + #[derive(FromTLV, ToTLV, Copy, Clone, Debug)] #[tlvargs(datatype = "list")] pub struct EventPath { pub node: Option, @@ -555,7 +580,7 @@ pub mod ib { pub is_urgent: Option, } - #[derive(FromTLV, ToTLV, Copy, Clone)] + #[derive(FromTLV, ToTLV, Copy, Clone, Debug)] pub struct EventFilter { pub node: Option, pub event_min: Option, diff --git a/matter/src/interaction_model/mod.rs b/matter/src/interaction_model/mod.rs index 2caf55f..22e0ee9 100644 --- a/matter/src/interaction_model/mod.rs +++ b/matter/src/interaction_model/mod.rs @@ -15,73 +15,5 @@ * limitations under the License. */ -use crate::{ - error::Error, - tlv::TLVWriter, - transport::{exchange::Exchange, proto_demux::ResponseRequired, session::SessionHandle}, -}; - -use self::{ - core::OpCode, - messages::msg::{InvReq, StatusResp, WriteReq}, -}; - -#[derive(PartialEq)] -pub enum TransactionState { - Ongoing, - Complete, - Terminate, -} -pub struct Transaction<'a, 'b> { - pub state: TransactionState, - pub session: &'a mut SessionHandle<'b>, - pub exch: &'a mut Exchange, -} - -pub trait InteractionConsumer { - fn consume_invoke_cmd( - &self, - req: &InvReq, - trans: &mut Transaction, - tw: &mut TLVWriter, - ) -> Result<(), Error>; - - fn consume_read_attr( - &self, - // TODO: This handling is different from the other APIs here, identify - // consistent options for this trait - req: &[u8], - trans: &mut Transaction, - tw: &mut TLVWriter, - ) -> Result<(), Error>; - - fn consume_write_attr( - &self, - req: &WriteReq, - trans: &mut Transaction, - tw: &mut TLVWriter, - ) -> Result<(), Error>; - - fn consume_status_report( - &self, - _req: &StatusResp, - _trans: &mut Transaction, - _tw: &mut TLVWriter, - ) -> Result<(OpCode, ResponseRequired), Error>; - - fn consume_subscribe( - &self, - _req: &[u8], - _trans: &mut Transaction, - _tw: &mut TLVWriter, - ) -> Result<(OpCode, ResponseRequired), Error>; -} - -pub struct InteractionModel { - consumer: Box, -} -pub mod command; pub mod core; pub mod messages; -pub mod read; -pub mod write; diff --git a/matter/src/interaction_model/read.rs b/matter/src/interaction_model/read.rs deleted file mode 100644 index 0985eea..0000000 --- a/matter/src/interaction_model/read.rs +++ /dev/null @@ -1,42 +0,0 @@ -/* - * - * Copyright (c) 2020-2022 Project CHIP Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -use crate::{ - error::Error, - interaction_model::core::OpCode, - tlv::TLVWriter, - transport::{packet::Packet, proto_demux::ResponseRequired}, -}; - -use super::{InteractionModel, Transaction}; - -impl InteractionModel { - pub fn handle_read_req( - &mut self, - trans: &mut Transaction, - rx_buf: &[u8], - proto_tx: &mut Packet, - ) -> Result { - proto_tx.set_proto_opcode(OpCode::ReportData as u8); - let proto_tx_wb = proto_tx.get_writebuf()?; - let mut tw = TLVWriter::new(proto_tx_wb); - - self.consumer.consume_read_attr(rx_buf, trans, &mut tw)?; - - Ok(ResponseRequired::Yes) - } -} diff --git a/matter/src/interaction_model/write.rs b/matter/src/interaction_model/write.rs deleted file mode 100644 index 48b7903..0000000 --- a/matter/src/interaction_model/write.rs +++ /dev/null @@ -1,58 +0,0 @@ -/* - * - * Copyright (c) 2020-2022 Project CHIP Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -use log::error; - -use crate::{ - error::Error, - tlv::{get_root_node_struct, FromTLV, TLVWriter, TagType}, - transport::{packet::Packet, proto_demux::ResponseRequired}, -}; - -use super::{core::OpCode, messages::msg::WriteReq, InteractionModel, Transaction}; - -impl InteractionModel { - pub fn handle_write_req( - &mut self, - trans: &mut Transaction, - rx_buf: &[u8], - proto_tx: &mut Packet, - ) -> Result { - if InteractionModel::req_timeout_handled(trans, proto_tx)? { - return Ok(ResponseRequired::Yes); - } - proto_tx.set_proto_opcode(OpCode::WriteResponse as u8); - - let mut tw = TLVWriter::new(proto_tx.get_writebuf()?); - let root = get_root_node_struct(rx_buf)?; - let write_req = WriteReq::from_tlv(&root)?; - let supress_response = write_req.supress_response.unwrap_or_default(); - - tw.start_struct(TagType::Anonymous)?; - self.consumer - .consume_write_attr(&write_req, trans, &mut tw)?; - tw.end_container()?; - - trans.complete(); - if supress_response { - error!("Supress response is set, is this the expected handling?"); - Ok(ResponseRequired::No) - } else { - Ok(ResponseRequired::Yes) - } - } -} diff --git a/matter/src/lib.rs b/matter/src/lib.rs index 9f03ac5..0d99cdb 100644 --- a/matter/src/lib.rs +++ b/matter/src/lib.rs @@ -23,7 +23,7 @@ //! Currently Ethernet based transport is supported. //! //! # Examples -//! ``` +//! TODO: Fix once new API has stabilized a bit //! use matter::{Matter, CommissioningData}; //! use matter::data_model::device_types::device_type_add_on_off_light; //! use matter::data_model::cluster_basic_information::BasicInfoConfig; @@ -65,8 +65,11 @@ //! } //! // Start the Matter Daemon //! // matter.start_daemon().unwrap(); -//! ``` +//! //! Start off exploring by going to the [Matter] object. +#![cfg_attr(not(feature = "std"), no_std)] +#![cfg_attr(feature = "nightly", feature(async_fn_in_trait))] +#![cfg_attr(feature = "nightly", allow(incomplete_features))] pub mod acl; pub mod cert; @@ -80,6 +83,7 @@ pub mod group_keys; pub mod interaction_model; pub mod mdns; pub mod pairing; +pub mod persist; pub mod secure_channel; pub mod sys; pub mod tlv; diff --git a/matter/src/mdns.rs b/matter/src/mdns.rs index f28bea0..71be231 100644 --- a/matter/src/mdns.rs +++ b/matter/src/mdns.rs @@ -15,34 +15,58 @@ * limitations under the License. */ -use std::sync::{Arc, Mutex, Once}; +use core::fmt::Write; -use crate::{ - error::Error, - sys::{sys_publish_service, SysMdnsService}, - transport::udp::MATTER_PORT, -}; +use crate::error::Error; -#[derive(Default)] -/// The mDNS service handler -pub struct MdnsInner { - /// Vendor ID - vid: u16, - /// Product ID - pid: u16, - /// Device name - device_name: String, +pub trait Mdns { + fn add( + &mut self, + name: &str, + service_type: &str, + port: u16, + txt_kvs: &[(&str, &str)], + ) -> Result<(), Error>; + + fn remove(&mut self, name: &str, service_type: &str, port: u16) -> Result<(), Error>; } -pub struct Mdns { - inner: Mutex, +impl Mdns for &mut T +where + T: Mdns, +{ + fn add( + &mut self, + name: &str, + service_type: &str, + port: u16, + txt_kvs: &[(&str, &str)], + ) -> Result<(), Error> { + (**self).add(name, service_type, port, txt_kvs) + } + + fn remove(&mut self, name: &str, service_type: &str, port: u16) -> Result<(), Error> { + (**self).remove(name, service_type, port) + } } -const SHORT_DISCRIMINATOR_MASK: u16 = 0xF00; -const SHORT_DISCRIMINATOR_SHIFT: u16 = 8; +pub struct DummyMdns; -static mut G_MDNS: Option> = None; -static INIT: Once = Once::new(); +impl Mdns for DummyMdns { + fn add( + &mut self, + _name: &str, + _service_type: &str, + _port: u16, + _txt_kvs: &[(&str, &str)], + ) -> Result<(), Error> { + Ok(()) + } + + fn remove(&mut self, _name: &str, _service_type: &str, _port: u16) -> Result<(), Error> { + Ok(()) + } +} pub enum ServiceMode { /// The commissioned state @@ -51,68 +75,108 @@ pub enum ServiceMode { Commissionable(u16), } -impl Mdns { - fn new() -> Self { +/// The mDNS service handler +pub struct MdnsMgr<'a> { + /// Vendor ID + vid: u16, + /// Product ID + pid: u16, + /// Device name + device_name: heapless::String<32>, + /// Matter port + matter_port: u16, + /// mDns service + mdns: &'a mut dyn Mdns, +} + +impl<'a> MdnsMgr<'a> { + pub fn new( + vid: u16, + pid: u16, + device_name: &str, + matter_port: u16, + mdns: &'a mut dyn Mdns, + ) -> Self { Self { - inner: Mutex::new(MdnsInner { - ..Default::default() - }), + vid, + pid, + device_name: device_name.chars().take(32).collect(), + matter_port, + mdns, } } - /// Get a handle to the globally unique mDNS instance - pub fn get() -> Result, Error> { - unsafe { - INIT.call_once(|| { - G_MDNS = Some(Arc::new(Mdns::new())); - }); - Ok(G_MDNS.as_ref().ok_or(Error::Invalid)?.clone()) - } - } - - /// Set mDNS service specific values - /// Values like vid, pid, discriminator etc - // TODO: More things like device-type etc can be added here - pub fn set_values(&self, vid: u16, pid: u16, device_name: &str) { - let mut inner = self.inner.lock().unwrap(); - inner.vid = vid; - inner.pid = pid; - inner.device_name = device_name.chars().take(32).collect(); - } - - /// Publish a mDNS service + /// Publish an mDNS service /// name - is the service name (comma separated subtypes may follow) /// mode - the current service mode #[allow(clippy::needless_pass_by_value)] - pub fn publish_service(&self, name: &str, mode: ServiceMode) -> Result { + pub fn publish_service(&mut self, name: &str, mode: ServiceMode) -> Result<(), Error> { match mode { - ServiceMode::Commissioned => { - sys_publish_service(name, "_matter._tcp", MATTER_PORT, &[]) - } + ServiceMode::Commissioned => self.mdns.add(name, "_matter._tcp", self.matter_port, &[]), ServiceMode::Commissionable(discriminator) => { - let inner = self.inner.lock().unwrap(); - let short = compute_short_discriminator(discriminator); - let serv_type = format!("_matterc._udp,_S{},_L{}", short, discriminator); + let discriminator_str = Self::get_discriminator_str(discriminator); + + let serv_type = self.get_service_type(discriminator); + let vp = self.get_vp(); - let str_discriminator = format!("{}", discriminator); let txt_kvs = [ - ["D", &str_discriminator], - ["CM", "1"], - ["DN", &inner.device_name], - ["VP", &format!("{}+{}", inner.vid, inner.pid)], - ["SII", "5000"], /* Sleepy Idle Interval */ - ["SAI", "300"], /* Sleepy Active Interval */ - ["PH", "33"], /* Pairing Hint */ - ["PI", ""], /* Pairing Instruction */ + ("D", discriminator_str.as_str()), + ("CM", "1"), + ("DN", self.device_name.as_str()), + ("VP", &vp), + ("SII", "5000"), /* Sleepy Idle Interval */ + ("SAI", "300"), /* Sleepy Active Interval */ + ("PH", "33"), /* Pairing Hint */ + ("PI", ""), /* Pairing Instruction */ ]; - sys_publish_service(name, &serv_type, MATTER_PORT, &txt_kvs) + self.mdns.add(name, &serv_type, self.matter_port, &txt_kvs) } } } -} -fn compute_short_discriminator(discriminator: u16) -> u16 { - (discriminator & SHORT_DISCRIMINATOR_MASK) >> SHORT_DISCRIMINATOR_SHIFT + pub fn unpublish_service(&mut self, name: &str, mode: ServiceMode) -> Result<(), Error> { + match mode { + ServiceMode::Commissioned => self.mdns.remove(name, "_matter._tcp", self.matter_port), + ServiceMode::Commissionable(discriminator) => { + let serv_type = self.get_service_type(discriminator); + + self.mdns.remove(name, &serv_type, self.matter_port) + } + } + } + + fn get_service_type(&self, discriminator: u16) -> heapless::String<32> { + let short = Self::compute_short_discriminator(discriminator); + let mut serv_type = heapless::String::new(); + + write!( + &mut serv_type, + "_matterc._udp,_S{},_L{}", + short, discriminator + ) + .unwrap(); + + serv_type + } + + fn get_vp(&self) -> heapless::String<11> { + let mut vp = heapless::String::new(); + + write!(&mut vp, "{}+{}", self.vid, self.pid).unwrap(); + + vp + } + + fn get_discriminator_str(discriminator: u16) -> heapless::String<5> { + discriminator.into() + } + + fn compute_short_discriminator(discriminator: u16) -> u16 { + const SHORT_DISCRIMINATOR_MASK: u16 = 0xF00; + const SHORT_DISCRIMINATOR_SHIFT: u16 = 8; + + (discriminator & SHORT_DISCRIMINATOR_MASK) >> SHORT_DISCRIMINATOR_SHIFT + } } #[cfg(test)] @@ -122,11 +186,11 @@ mod tests { #[test] fn can_compute_short_discriminator() { let discriminator: u16 = 0b0000_1111_0000_0000; - let short = compute_short_discriminator(discriminator); + let short = MdnsMgr::compute_short_discriminator(discriminator); assert_eq!(short, 0b1111); let discriminator: u16 = 840; - let short = compute_short_discriminator(discriminator); + let short = MdnsMgr::compute_short_discriminator(discriminator); assert_eq!(short, 3); } } diff --git a/matter/src/pairing/code.rs b/matter/src/pairing/code.rs index 83d90f3..16e4fea 100644 --- a/matter/src/pairing/code.rs +++ b/matter/src/pairing/code.rs @@ -15,56 +15,66 @@ * limitations under the License. */ +use core::fmt::Write; + use super::*; -pub(super) fn compute_pairing_code(comm_data: &CommissioningData) -> String { +pub(super) fn compute_pairing_code(comm_data: &CommissioningData) -> heapless::String<32> { // 0: no Vendor ID and Product ID present in Manual Pairing Code const VID_PID_PRESENT: u8 = 0; let passwd = passwd_from_comm_data(comm_data); let CommissioningData { discriminator, .. } = comm_data; - let mut digits = String::new(); - digits.push_str(&((VID_PID_PRESENT << 2) | (discriminator >> 10) as u8).to_string()); - digits.push_str(&format!( - "{:0>5}", - ((discriminator & 0x300) << 6) | (passwd & 0x3FFF) as u16 - )); - digits.push_str(&format!("{:0>4}", passwd >> 14)); + let mut digits = heapless::String::<32>::new(); + write!( + &mut digits, + "{}{:0>5}{:0>4}", + (VID_PID_PRESENT << 2) | (discriminator >> 10) as u8, + ((discriminator & 0x300) << 6) | (passwd & 0x3FFF) as u16, + passwd >> 14 + ) + .unwrap(); - let check_digit = digits.calculate_verhoeff_check_digit(); - digits.push_str(&check_digit.to_string()); + let mut final_digits = heapless::String::<32>::new(); + write!( + &mut final_digits, + "{}{}", + digits, + digits.calculate_verhoeff_check_digit() + ) + .unwrap(); - digits + final_digits } pub(super) fn pretty_print_pairing_code(pairing_code: &str) { assert!(pairing_code.len() == 11); - let mut pretty = String::new(); - pretty.push_str(&pairing_code[..4]); - pretty.push('-'); - pretty.push_str(&pairing_code[4..8]); - pretty.push('-'); - pretty.push_str(&pairing_code[8..]); + let mut pretty = heapless::String::<32>::new(); + pretty.push_str(&pairing_code[..4]).unwrap(); + pretty.push('-').unwrap(); + pretty.push_str(&pairing_code[4..8]).unwrap(); + pretty.push('-').unwrap(); + pretty.push_str(&pairing_code[8..]).unwrap(); info!("Pairing Code: {}", pretty); } #[cfg(test)] mod tests { use super::*; - use crate::secure_channel::spake2p::VerifierData; + use crate::{secure_channel::spake2p::VerifierData, utils::rand::dummy_rand}; #[test] fn can_compute_pairing_code() { let comm_data = CommissioningData { - verifier: VerifierData::new_with_pw(123456), + verifier: VerifierData::new_with_pw(123456, dummy_rand), discriminator: 250, }; let pairing_code = compute_pairing_code(&comm_data); assert_eq!(pairing_code, "00876800071"); let comm_data = CommissioningData { - verifier: VerifierData::new_with_pw(34567890), + verifier: VerifierData::new_with_pw(34567890, dummy_rand), discriminator: 2976, }; let pairing_code = compute_pairing_code(&comm_data); diff --git a/matter/src/pairing/qr.rs b/matter/src/pairing/qr.rs index 0a3509d..f1d844a 100644 --- a/matter/src/pairing/qr.rs +++ b/matter/src/pairing/qr.rs @@ -15,7 +15,7 @@ * limitations under the License. */ -use std::collections::BTreeMap; +use heapless::FnvIndexMap; use crate::{ tlv::{TLVWriter, TagType}, @@ -55,7 +55,7 @@ const SERIAL_NUMBER_TAG: u8 = 0x00; // const COMMISSIONING_TIMEOUT_TAG: u8 = 0x04; pub enum QRCodeInfoType { - String(String), + String(heapless::String<128>), // TODO: Big enough? Int32(i32), Int64(i64), UInt32(u32), @@ -63,7 +63,7 @@ pub enum QRCodeInfoType { } pub enum SerialNumber { - String(String), + String(heapless::String<128>), UInt32(u32), } @@ -78,10 +78,10 @@ pub struct QrSetupPayload<'data> { version: u8, flow_type: CommissionningFlowType, discovery_capabilities: DiscoveryCapabilities, - dev_det: &'data BasicInfoConfig, + dev_det: &'data BasicInfoConfig<'data>, comm_data: &'data CommissioningData, // we use a BTreeMap to keep the order of the optional data stable - optional_data: BTreeMap, + optional_data: heapless::FnvIndexMap, } impl<'data> QrSetupPayload<'data> { @@ -98,11 +98,11 @@ impl<'data> QrSetupPayload<'data> { discovery_capabilities, dev_det, comm_data, - optional_data: BTreeMap::new(), + optional_data: FnvIndexMap::new(), }; if !dev_det.serial_no.is_empty() { - result.add_serial_number(SerialNumber::String(dev_det.serial_no.clone())); + result.add_serial_number(SerialNumber::String(dev_det.serial_no.into())); } result @@ -137,7 +137,9 @@ impl<'data> QrSetupPayload<'data> { } self.optional_data - .insert(tag, OptionalQRCodeInfo { tag, data }); + .insert(tag, OptionalQRCodeInfo { tag, data }) + .map_err(|_| Error::NoSpace)?; + Ok(()) } @@ -155,11 +157,13 @@ impl<'data> QrSetupPayload<'data> { } self.optional_data - .insert(tag, OptionalQRCodeInfo { tag, data }); + .insert(tag, OptionalQRCodeInfo { tag, data }) + .map_err(|_| Error::NoSpace)?; + Ok(()) } - pub fn get_all_optional_data(&self) -> &BTreeMap { + pub fn get_all_optional_data(&self) -> &FnvIndexMap { &self.optional_data } @@ -388,7 +392,7 @@ fn generate_tlv_from_optional_data( ) -> Result<(), Error> { let size_needed = tlv_data.max_data_length_in_bytes as usize; let mut tlv_buffer = vec![0u8; size_needed]; - let mut wb = WriteBuf::new(&mut tlv_buffer, size_needed); + let mut wb = WriteBuf::new(&mut tlv_buffer); let mut tw = TLVWriter::new(&mut wb); tw.start_struct(TagType::Anonymous)?; @@ -532,16 +536,15 @@ fn is_common_tag(tag: u8) -> bool { #[cfg(test)] mod tests { - use super::*; - use crate::secure_channel::spake2p::VerifierData; + use crate::{secure_channel::spake2p::VerifierData, utils::rand::dummy_rand}; #[test] fn can_base38_encode() { const QR_CODE: &str = "MT:YNJV7VSC00CMVH7SR00"; let comm_data = CommissioningData { - verifier: VerifierData::new_with_pw(34567890), + verifier: VerifierData::new_with_pw(34567890, dummy_rand), discriminator: 2976, }; let dev_det = BasicInfoConfig { @@ -561,13 +564,13 @@ mod tests { const QR_CODE: &str = "MT:-24J0AFN00KA064IJ3P0IXZB0DK5N1K8SQ1RYCU1-A40"; let comm_data = CommissioningData { - verifier: VerifierData::new_with_pw(20202021), + verifier: VerifierData::new_with_pw(20202021, dummy_rand), discriminator: 3840, }; let dev_det = BasicInfoConfig { vid: 65521, pid: 32769, - serial_no: "1234567890".to_string(), + serial_no: "1234567890", ..Default::default() }; @@ -588,13 +591,13 @@ mod tests { const OPTIONAL_DEFAULT_INT_VALUE: i32 = 65550; let comm_data = CommissioningData { - verifier: VerifierData::new_with_pw(20202021), + verifier: VerifierData::new_with_pw(20202021, dummy_rand), discriminator: 3840, }; let dev_det = BasicInfoConfig { vid: 65521, pid: 32769, - serial_no: "1234567890".to_string(), + serial_no: "1234567890", ..Default::default() }; @@ -604,7 +607,7 @@ mod tests { qr_code_data .add_optional_vendor_data( OPTIONAL_DEFAULT_STRING_TAG, - QRCodeInfoType::String(OPTIONAL_DEFAULT_STRING_VALUE.to_string()), + QRCodeInfoType::String(OPTIONAL_DEFAULT_STRING_VALUE.into()), ) .expect("Failed to add optional data"); diff --git a/matter/src/persist.rs b/matter/src/persist.rs new file mode 100644 index 0000000..4bc8e24 --- /dev/null +++ b/matter/src/persist.rs @@ -0,0 +1,229 @@ +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +use crate::error::Error; + +pub trait Psm { + fn set_kv_slice(&mut self, key: &str, val: &[u8]) -> Result<(), Error>; + fn get_kv_slice<'a>(&self, key: &str, buf: &'a mut [u8]) -> Result<&'a [u8], Error>; + + fn set_kv_u64(&mut self, key: &str, val: u64) -> Result<(), Error>; + fn get_kv_u64(&self, key: &str) -> Result; + + fn remove(&mut self, key: &str) -> Result<(), Error>; +} + +impl Psm for &mut T +where + T: Psm, +{ + fn set_kv_slice(&mut self, key: &str, val: &[u8]) -> Result<(), Error> { + (**self).set_kv_slice(key, val) + } + + fn get_kv_slice<'a>(&self, key: &str, buf: &'a mut [u8]) -> Result<&'a [u8], Error> { + (**self).get_kv_slice(key, buf) + } + + fn set_kv_u64(&mut self, key: &str, val: u64) -> Result<(), Error> { + (**self).set_kv_u64(key, val) + } + + fn get_kv_u64(&self, key: &str) -> Result { + (**self).get_kv_u64(key) + } + + fn remove(&mut self, key: &str) -> Result<(), Error> { + (**self).remove(key) + } +} + +#[cfg(feature = "nightly")] +pub mod asynch { + use crate::error::Error; + + use super::Psm; + + pub trait AsyncPsm { + async fn set_kv_slice<'a>(&'a mut self, key: &'a str, val: &'a [u8]) -> Result<(), Error>; + async fn get_kv_slice<'a, 'b>( + &'a self, + key: &'a str, + buf: &'b mut [u8], + ) -> Result<&'b [u8], Error>; + + async fn set_kv_u64<'a>(&'a mut self, key: &'a str, val: u64) -> Result<(), Error>; + async fn get_kv_u64<'a>(&'a self, key: &'a str) -> Result; + + async fn remove<'a>(&'a mut self, key: &'a str) -> Result<(), Error>; + } + + impl AsyncPsm for &mut T + where + T: AsyncPsm, + { + async fn set_kv_slice<'a>(&'a mut self, key: &'a str, val: &'a [u8]) -> Result<(), Error> { + (**self).set_kv_slice(key, val).await + } + + async fn get_kv_slice<'a, 'b>( + &'a self, + key: &'a str, + buf: &'b mut [u8], + ) -> Result<&'b [u8], Error> { + (**self).get_kv_slice(key, buf).await + } + + async fn set_kv_u64<'a>(&'a mut self, key: &'a str, val: u64) -> Result<(), Error> { + (**self).set_kv_u64(key, val).await + } + + async fn get_kv_u64<'a>(&'a self, key: &'a str) -> Result { + (**self).get_kv_u64(key).await + } + + async fn remove<'a>(&'a mut self, key: &'a str) -> Result<(), Error> { + (**self).remove(key).await + } + } + + pub struct Asyncify(pub T); + + impl AsyncPsm for Asyncify + where + T: Psm, + { + async fn set_kv_slice<'a>(&'a mut self, key: &'a str, val: &'a [u8]) -> Result<(), Error> { + self.0.set_kv_slice(key, val) + } + + async fn get_kv_slice<'a, 'b>( + &'a self, + key: &'a str, + buf: &'b mut [u8], + ) -> Result<&'b [u8], Error> { + self.0.get_kv_slice(key, buf) + } + + async fn set_kv_u64<'a>(&'a mut self, key: &'a str, val: u64) -> Result<(), Error> { + self.0.set_kv_u64(key, val) + } + + async fn get_kv_u64<'a>(&'a self, key: &'a str) -> Result { + self.0.get_kv_u64(key) + } + + async fn remove<'a>(&'a mut self, key: &'a str) -> Result<(), Error> { + self.0.remove(key) + } + } +} + +#[cfg(feature = "std")] +pub mod std { + use std::fs::{self, DirBuilder, File}; + use std::io::{Read, Write}; + + use crate::error::Error; + + use super::Psm; + + pub struct FilePsm {} + + const PSM_DIR: &str = "/tmp/matter_psm"; + + macro_rules! psm_path { + ($key:ident) => { + format!("{}/{}", PSM_DIR, $key) + }; + } + + impl FilePsm { + pub fn new() -> Result { + let result = DirBuilder::new().create(PSM_DIR); + if let Err(e) = result { + if e.kind() != std::io::ErrorKind::AlreadyExists { + return Err(e.into()); + } + } + + Ok(Self {}) + } + + pub fn set_kv_slice(&mut self, key: &str, val: &[u8]) -> Result<(), Error> { + let mut f = File::create(psm_path!(key))?; + f.write_all(val)?; + Ok(()) + } + + pub fn get_kv_slice<'a>(&self, key: &str, buf: &'a mut [u8]) -> Result<&'a [u8], Error> { + let mut f = File::open(psm_path!(key))?; + let mut offset = 0; + + loop { + let len = f.read(&mut buf[offset..])?; + offset += len; + + if len == 0 { + break; + } + } + + Ok(&buf[..offset]) + } + + pub fn set_kv_u64(&mut self, key: &str, val: u64) -> Result<(), Error> { + let mut f = File::create(psm_path!(key))?; + f.write_all(&val.to_be_bytes())?; + Ok(()) + } + + pub fn get_kv_u64(&self, key: &str) -> Result { + let mut f = File::open(psm_path!(key))?; + let mut buf = [0; 8]; + f.read_exact(&mut buf)?; + Ok(u64::from_be_bytes(buf)) + } + + pub fn remove(&self, key: &str) -> Result<(), Error> { + fs::remove_file(psm_path!(key))?; + Ok(()) + } + } + + impl Psm for FilePsm { + fn set_kv_slice(&mut self, key: &str, val: &[u8]) -> Result<(), Error> { + FilePsm::set_kv_slice(self, key, val) + } + + fn get_kv_slice<'a>(&self, key: &str, buf: &'a mut [u8]) -> Result<&'a [u8], Error> { + FilePsm::get_kv_slice(self, key, buf) + } + + fn set_kv_u64(&mut self, key: &str, val: u64) -> Result<(), Error> { + FilePsm::set_kv_u64(self, key, val) + } + + fn get_kv_u64(&self, key: &str) -> Result { + FilePsm::get_kv_u64(self, key) + } + + fn remove(&mut self, key: &str) -> Result<(), Error> { + FilePsm::remove(self, key) + } + } +} diff --git a/matter/src/secure_channel/case.rs b/matter/src/secure_channel/case.rs index 58a5593..80802ed 100644 --- a/matter/src/secure_channel/case.rs +++ b/matter/src/secure_channel/case.rs @@ -15,27 +15,25 @@ * limitations under the License. */ -use std::sync::Arc; +use core::cell::RefCell; use log::{error, trace}; -use owning_ref::RwLockReadGuardRef; -use rand::prelude::*; use crate::{ cert::Cert, - crypto::{self, CryptoKeyPair, KeyPair, Sha256}, + crypto::{self, KeyPair, Sha256}, error::Error, - fabric::{Fabric, FabricMgr, FabricMgrInner}, + fabric::{Fabric, FabricMgr}, secure_channel::common::SCStatusCodes, secure_channel::common::{self, OpCode}, tlv::{get_root_node_struct, FromTLV, OctetStr, TLVElement, TLVWriter, TagType}, transport::{ network::Address, - proto_demux::{ProtoCtx, ResponseRequired}, + proto_ctx::ProtoCtx, queue::{Msg, WorkQ}, session::{CaseDetails, CloneData, NocCatIds, SessionMode}, }, - utils::writebuf::WriteBuf, + utils::{rand::Rand, writebuf::WriteBuf}, }; #[derive(PartialEq)] @@ -54,6 +52,7 @@ pub struct CaseSession { peer_pub_key: [u8; crypto::EC_POINT_LEN_BYTES], local_fabric_idx: usize, } + impl CaseSession { pub fn new(peer_sessid: u16, local_sessid: u16) -> Result { Ok(Self { @@ -69,40 +68,43 @@ impl CaseSession { } } -pub struct Case { - fabric_mgr: Arc, +pub struct Case<'a> { + fabric_mgr: &'a RefCell, + rand: Rand, } -impl Case { - pub fn new(fabric_mgr: Arc) -> Self { - Self { fabric_mgr } +impl<'a> Case<'a> { + pub fn new(fabric_mgr: &'a RefCell, rand: Rand) -> Self { + Self { fabric_mgr, rand } } - pub fn casesigma3_handler(&mut self, ctx: &mut ProtoCtx) -> Result { + pub fn casesigma3_handler(&mut self, ctx: &mut ProtoCtx) -> Result { let mut case_session = ctx .exch_ctx .exch - .take_data_boxed::() + .take_case_session::() .ok_or(Error::InvalidState)?; if case_session.state != State::Sigma1Rx { return Err(Error::Invalid); } case_session.state = State::Sigma3Rx; - let fabric = self.fabric_mgr.get_fabric(case_session.local_fabric_idx)?; + let fabric_mgr = self.fabric_mgr.borrow(); + + let fabric = fabric_mgr.get_fabric(case_session.local_fabric_idx)?; if fabric.is_none() { common::create_sc_status_report( - &mut ctx.tx, + ctx.tx, common::SCStatusCodes::NoSharedTrustRoots, None, )?; ctx.exch_ctx.exch.close(); - return Ok(ResponseRequired::Yes); + return Ok(true); } // Safe to unwrap here - let fabric = fabric.as_ref().as_ref().unwrap(); + let fabric = fabric.unwrap(); - let root = get_root_node_struct(ctx.rx.as_borrow_slice())?; + let root = get_root_node_struct(ctx.rx.as_slice())?; let encrypted = root.find_tag(1)?.slice()?; let mut decrypted: [u8; 800] = [0; 800]; @@ -126,13 +128,9 @@ impl Case { } if let Err(e) = Case::validate_certs(fabric, &initiator_noc, &initiator_icac) { error!("Certificate Chain doesn't match: {}", e); - common::create_sc_status_report( - &mut ctx.tx, - common::SCStatusCodes::InvalidParameter, - None, - )?; + common::create_sc_status_report(ctx.tx, common::SCStatusCodes::InvalidParameter, None)?; ctx.exch_ctx.exch.close(); - return Ok(ResponseRequired::Yes); + return Ok(true); } if Case::validate_sigma3_sign( @@ -145,19 +143,15 @@ impl Case { .is_err() { error!("Sigma3 Signature doesn't match"); - common::create_sc_status_report( - &mut ctx.tx, - common::SCStatusCodes::InvalidParameter, - None, - )?; + common::create_sc_status_report(ctx.tx, common::SCStatusCodes::InvalidParameter, None)?; ctx.exch_ctx.exch.close(); - return Ok(ResponseRequired::Yes); + return Ok(true); } // Only now do we add this message to the TT Hash let mut peer_catids: NocCatIds = Default::default(); initiator_noc.get_cat_ids(&mut peer_catids); - case_session.tt_hash.update(ctx.rx.as_borrow_slice())?; + case_session.tt_hash.update(ctx.rx.as_slice())?; let clone_data = Case::get_session_clone_data( fabric.ipk.op_key(), fabric.get_node_id(), @@ -169,40 +163,36 @@ impl Case { // Queue a transport mgr request to add a new session WorkQ::get()?.sync_send(Msg::NewSession(clone_data))?; - common::create_sc_status_report( - &mut ctx.tx, - SCStatusCodes::SessionEstablishmentSuccess, - None, - )?; - ctx.exch_ctx.exch.clear_data_boxed(); + common::create_sc_status_report(ctx.tx, SCStatusCodes::SessionEstablishmentSuccess, None)?; + ctx.exch_ctx.exch.clear_data(); ctx.exch_ctx.exch.close(); - - Ok(ResponseRequired::Yes) + Ok(true) } - pub fn casesigma1_handler(&mut self, ctx: &mut ProtoCtx) -> Result { + pub fn casesigma1_handler(&mut self, ctx: &mut ProtoCtx) -> Result { ctx.tx.set_proto_opcode(OpCode::CASESigma2 as u8); - let rx_buf = ctx.rx.as_borrow_slice(); + let rx_buf = ctx.rx.as_slice(); let root = get_root_node_struct(rx_buf)?; let r = Sigma1Req::from_tlv(&root)?; let local_fabric_idx = self .fabric_mgr + .borrow_mut() .match_dest_id(r.initiator_random.0, r.dest_id.0); if local_fabric_idx.is_err() { error!("Fabric Index mismatch"); common::create_sc_status_report( - &mut ctx.tx, + ctx.tx, common::SCStatusCodes::NoSharedTrustRoots, None, )?; ctx.exch_ctx.exch.close(); - return Ok(ResponseRequired::Yes); + return Ok(true); } let local_sessid = ctx.exch_ctx.sess.reserve_new_sess_id(); - let mut case_session = Box::new(CaseSession::new(r.initiator_sessid, local_sessid)?); + let mut case_session = CaseSession::new(r.initiator_sessid, local_sessid)?; case_session.tt_hash.update(rx_buf)?; case_session.local_fabric_idx = local_fabric_idx?; if r.peer_pub_key.0.len() != crypto::EC_POINT_LEN_BYTES { @@ -228,7 +218,7 @@ impl Case { // println!("Derived secret: {:x?} len: {}", secret, len); let mut our_random: [u8; 32] = [0; 32]; - rand::thread_rng().fill_bytes(&mut our_random); + (self.rand)(&mut our_random); // Derive the Encrypted Part const MAX_ENCRYPTED_SIZE: usize = 800; @@ -236,19 +226,21 @@ impl Case { let mut encrypted: [u8; MAX_ENCRYPTED_SIZE] = [0; MAX_ENCRYPTED_SIZE]; let encrypted_len = { let mut signature = [0u8; crypto::EC_SIGNATURE_LEN_BYTES]; - let fabric = self.fabric_mgr.get_fabric(case_session.local_fabric_idx)?; + let fabric_mgr = self.fabric_mgr.borrow(); + + let fabric = fabric_mgr.get_fabric(case_session.local_fabric_idx)?; if fabric.is_none() { common::create_sc_status_report( - &mut ctx.tx, + ctx.tx, common::SCStatusCodes::NoSharedTrustRoots, None, )?; ctx.exch_ctx.exch.close(); - return Ok(ResponseRequired::Yes); + return Ok(true); } let sign_len = Case::get_sigma2_sign( - &fabric, + fabric.unwrap(), &case_session.our_pub_key, &case_session.peer_pub_key, &mut signature, @@ -256,7 +248,8 @@ impl Case { let signature = &signature[..sign_len]; Case::get_sigma2_encryption( - &fabric, + fabric.unwrap(), + self.rand, &our_random, &mut case_session, signature, @@ -273,9 +266,9 @@ impl Case { tw.str8(TagType::Context(3), &case_session.our_pub_key)?; tw.str16(TagType::Context(4), encrypted)?; tw.end_container()?; - case_session.tt_hash.update(ctx.tx.as_borrow_slice())?; - ctx.exch_ctx.exch.set_data_boxed(case_session); - Ok(ResponseRequired::Yes) + case_session.tt_hash.update(ctx.tx.as_mut_slice())?; + ctx.exch_ctx.exch.set_case_session(case_session); + Ok(true) } fn get_session_clone_data( @@ -322,8 +315,8 @@ impl Case { case_session: &CaseSession, ) -> Result<(), Error> { const MAX_TBS_SIZE: usize = 800; - let mut buf: [u8; MAX_TBS_SIZE] = [0; MAX_TBS_SIZE]; - let mut write_buf = WriteBuf::new(&mut buf, MAX_TBS_SIZE); + let mut buf = [0; MAX_TBS_SIZE]; + let mut write_buf = WriteBuf::new(&mut buf); let mut tw = TLVWriter::new(&mut write_buf); tw.start_struct(TagType::Anonymous)?; tw.str16(TagType::Context(1), initiator_noc)?; @@ -335,7 +328,7 @@ impl Case { tw.end_container()?; let key = KeyPair::new_from_public(initiator_noc_cert.get_pubkey())?; - key.verify_msg(write_buf.as_slice(), sign)?; + key.verify_msg(write_buf.into_slice(), sign)?; Ok(()) } @@ -372,12 +365,12 @@ impl Case { if key.len() < 48 { return Err(Error::NoSpace); } - let mut salt = Vec::::with_capacity(256); - salt.extend_from_slice(ipk); + let mut salt = heapless::Vec::::new(); + salt.extend_from_slice(ipk).unwrap(); let tt = tt.clone(); let mut tt_hash = [0u8; crypto::SHA256_HASH_LEN_BYTES]; tt.finish(&mut tt_hash)?; - salt.extend_from_slice(&tt_hash); + salt.extend_from_slice(&tt_hash).unwrap(); // println!("Session Key: salt: {:x?}, len: {}", salt, salt.len()); crypto::hkdf_sha256(salt.as_slice(), shared_secret, &SEKEYS_INFO, key) @@ -420,14 +413,14 @@ impl Case { if key.len() < 16 { return Err(Error::NoSpace); } - let mut salt = Vec::::with_capacity(256); - salt.extend_from_slice(ipk); + let mut salt = heapless::Vec::::new(); + salt.extend_from_slice(ipk).unwrap(); let tt = tt.clone(); let mut tt_hash = [0u8; crypto::SHA256_HASH_LEN_BYTES]; tt.finish(&mut tt_hash)?; - salt.extend_from_slice(&tt_hash); + salt.extend_from_slice(&tt_hash).unwrap(); // println!("Sigma3Key: salt: {:x?}, len: {}", salt, salt.len()); crypto::hkdf_sha256(salt.as_slice(), shared_secret, &S3K_INFO, key) @@ -447,16 +440,16 @@ impl Case { if key.len() < 16 { return Err(Error::NoSpace); } - let mut salt = Vec::::with_capacity(256); - salt.extend_from_slice(ipk); - salt.extend_from_slice(our_random); - salt.extend_from_slice(&case_session.our_pub_key); + let mut salt = heapless::Vec::::new(); + salt.extend_from_slice(ipk).unwrap(); + salt.extend_from_slice(our_random).unwrap(); + salt.extend_from_slice(&case_session.our_pub_key).unwrap(); let tt = case_session.tt_hash.clone(); let mut tt_hash = [0u8; crypto::SHA256_HASH_LEN_BYTES]; tt.finish(&mut tt_hash)?; - salt.extend_from_slice(&tt_hash); + salt.extend_from_slice(&tt_hash).unwrap(); // println!("Sigma2Key: salt: {:x?}, len: {}", salt, salt.len()); crypto::hkdf_sha256(salt.as_slice(), &case_session.shared_secret, &S2K_INFO, key) @@ -467,17 +460,15 @@ impl Case { } fn get_sigma2_encryption( - fabric: &RwLockReadGuardRef>, + fabric: &Fabric, + rand: Rand, our_random: &[u8], case_session: &mut CaseSession, signature: &[u8], out: &mut [u8], ) -> Result { let mut resumption_id: [u8; 16] = [0; 16]; - rand::thread_rng().fill_bytes(&mut resumption_id); - - // We are guaranteed this unwrap will work - let fabric = fabric.as_ref().as_ref().unwrap(); + rand(&mut resumption_id); let mut sigma2_key = [0_u8; crypto::SYMM_KEY_LEN_BYTES]; Case::get_sigma2_key( @@ -487,7 +478,7 @@ impl Case { &mut sigma2_key, )?; - let mut write_buf = WriteBuf::new(out, out.len()); + let mut write_buf = WriteBuf::new(out); let mut tw = TLVWriter::new(&mut write_buf); tw.start_struct(TagType::Anonymous)?; tw.str16_as(TagType::Context(1), |buf| fabric.noc.as_tlv(buf))?; @@ -517,20 +508,19 @@ impl Case { cipher_text, cipher_text.len() - TAG_LEN, )?; - Ok(write_buf.as_slice().len()) + Ok(write_buf.into_slice().len()) } fn get_sigma2_sign( - fabric: &RwLockReadGuardRef>, + fabric: &Fabric, our_pub_key: &[u8], peer_pub_key: &[u8], signature: &mut [u8], ) -> Result { // We are guaranteed this unwrap will work - let fabric = fabric.as_ref().as_ref().unwrap(); const MAX_TBS_SIZE: usize = 800; - let mut buf: [u8; MAX_TBS_SIZE] = [0; MAX_TBS_SIZE]; - let mut write_buf = WriteBuf::new(&mut buf, MAX_TBS_SIZE); + let mut buf = [0; MAX_TBS_SIZE]; + let mut write_buf = WriteBuf::new(&mut buf); let mut tw = TLVWriter::new(&mut write_buf); tw.start_struct(TagType::Anonymous)?; tw.str16_as(TagType::Context(1), |buf| fabric.noc.as_tlv(buf))?; @@ -541,7 +531,7 @@ impl Case { tw.str8(TagType::Context(4), peer_pub_key)?; tw.end_container()?; //println!("TBS is {:x?}", write_buf.as_borrow_slice()); - fabric.sign_msg(write_buf.as_slice(), signature) + fabric.sign_msg(write_buf.into_slice(), signature) } } diff --git a/matter/src/secure_channel/common.rs b/matter/src/secure_channel/common.rs index 511bb5c..7049ba3 100644 --- a/matter/src/secure_channel/common.rs +++ b/matter/src/secure_channel/common.rs @@ -15,23 +15,14 @@ * limitations under the License. */ -use boxslab::Slab; -use log::info; use num_derive::FromPrimitive; -use crate::{ - error::Error, - transport::{ - exchange::Exchange, - packet::{Packet, PacketPool}, - session::SessionHandle, - }, -}; +use crate::{error::Error, transport::packet::Packet}; use super::status_report::{create_status_report, GeneralCode}; /* Interaction Model ID as per the Matter Spec */ -pub const PROTO_ID_SECURE_CHANNEL: usize = 0x00; +pub const PROTO_ID_SECURE_CHANNEL: u16 = 0x00; #[derive(FromPrimitive, Debug)] pub enum OpCode { @@ -88,14 +79,15 @@ pub fn create_sc_status_report( } pub fn create_mrp_standalone_ack(proto_tx: &mut Packet) { - proto_tx.set_proto_id(PROTO_ID_SECURE_CHANNEL as u16); + proto_tx.set_proto_id(PROTO_ID_SECURE_CHANNEL); proto_tx.set_proto_opcode(OpCode::MRPStandAloneAck as u8); proto_tx.unset_reliable(); } -pub fn send_mrp_standalone_ack(exch: &mut Exchange, sess: &mut SessionHandle) -> Result<(), Error> { - info!("Sending standalone ACK"); - let mut ack_packet = Slab::::try_new(Packet::new_tx()?).ok_or(Error::NoMemory)?; - create_mrp_standalone_ack(&mut ack_packet); - exch.send(ack_packet, sess) -} +// TODO +// pub fn send_mrp_standalone_ack(exch: &mut Exchange, sess: &mut SessionHandle) -> Result<(), Error> { +// info!("Sending standalone ACK"); +// let mut ack_packet = Slab::::try_new(Packet::new_tx()?).ok_or(Error::NoMemory)?; +// create_mrp_standalone_ack(&mut ack_packet); +// exch.send(ack_packet, sess) +// } diff --git a/matter/src/secure_channel/core.rs b/matter/src/secure_channel/core.rs index 9f7d16b..5ca1804 100644 --- a/matter/src/secure_channel/core.rs +++ b/matter/src/secure_channel/core.rs @@ -15,14 +15,11 @@ * limitations under the License. */ -use std::sync::Arc; +use core::cell::RefCell; use crate::{ - error::*, - fabric::FabricMgr, - secure_channel::common::*, - tlv, - transport::proto_demux::{self, ProtoCtx, ResponseRequired}, + error::*, fabric::FabricMgr, mdns::MdnsMgr, secure_channel::common::*, tlv, + transport::proto_ctx::ProtoCtx, utils::rand::Rand, }; use log::{error, info}; use num; @@ -32,48 +29,54 @@ use super::{case::Case, pake::PaseMgr}; /* Handle messages related to the Secure Channel */ -pub struct SecureChannel { - case: Case, - pase: PaseMgr, +pub struct SecureChannel<'a> { + case: Case<'a>, + pase: &'a RefCell, + mdns: &'a RefCell>, } -impl SecureChannel { - pub fn new(pase: PaseMgr, fabric_mgr: Arc) -> SecureChannel { +impl<'a> SecureChannel<'a> { + pub fn new( + pase: &'a RefCell, + fabric_mgr: &'a RefCell, + mdns: &'a RefCell>, + rand: Rand, + ) -> Self { SecureChannel { + case: Case::new(fabric_mgr, rand), pase, - case: Case::new(fabric_mgr), + mdns, } } -} -impl proto_demux::HandleProto for SecureChannel { - fn handle_proto_id(&mut self, ctx: &mut ProtoCtx) -> Result { + pub fn handle(&mut self, ctx: &mut ProtoCtx) -> Result { let proto_opcode: OpCode = num::FromPrimitive::from_u8(ctx.rx.get_proto_opcode()).ok_or(Error::Invalid)?; - ctx.tx.set_proto_id(PROTO_ID_SECURE_CHANNEL as u16); + ctx.tx.set_proto_id(PROTO_ID_SECURE_CHANNEL); info!("Received Opcode: {:?}", proto_opcode); info!("Received Data:"); - tlv::print_tlv_list(ctx.rx.as_borrow_slice()); - let result = match proto_opcode { - OpCode::MRPStandAloneAck => Ok(ResponseRequired::No), - OpCode::PBKDFParamRequest => self.pase.pbkdfparamreq_handler(ctx), - OpCode::PASEPake1 => self.pase.pasepake1_handler(ctx), - OpCode::PASEPake3 => self.pase.pasepake3_handler(ctx), + tlv::print_tlv_list(ctx.rx.as_slice()); + let reply = match proto_opcode { + OpCode::MRPStandAloneAck => Ok(true), + OpCode::PBKDFParamRequest => self.pase.borrow_mut().pbkdfparamreq_handler(ctx), + OpCode::PASEPake1 => self.pase.borrow_mut().pasepake1_handler(ctx), + OpCode::PASEPake3 => self + .pase + .borrow_mut() + .pasepake3_handler(ctx, &mut self.mdns.borrow_mut()), OpCode::CASESigma1 => self.case.casesigma1_handler(ctx), OpCode::CASESigma3 => self.case.casesigma3_handler(ctx), _ => { error!("OpCode Not Handled: {:?}", proto_opcode); Err(Error::InvalidOpcode) } - }; - if result == Ok(ResponseRequired::Yes) { - info!("Sending response"); - tlv::print_tlv_list(ctx.tx.as_borrow_slice()); - } - result - } + }?; - fn get_proto_id(&self) -> usize { - PROTO_ID_SECURE_CHANNEL + if reply { + info!("Sending response"); + tlv::print_tlv_list(ctx.tx.as_mut_slice()); + } + + Ok(reply) } } diff --git a/matter/src/secure_channel/crypto.rs b/matter/src/secure_channel/crypto.rs index f9481ba..45d8359 100644 --- a/matter/src/secure_channel/crypto.rs +++ b/matter/src/secure_channel/crypto.rs @@ -15,40 +15,15 @@ * limitations under the License. */ -use crate::error::Error; - -// This trait allows us to switch between crypto providers like OpenSSL and mbedTLS for Spake2 -// Currently this is only validate for a verifier(responder) - -// A verifier will typically do: -// Step 1: w0 and L -// set_w0_from_w0s -// set_L -// Step 2: get_pB -// Step 3: get_TT_as_verifier(pA) -// Step 4: Computation of cA and cB happens outside since it doesn't use either BigNum or EcPoint -pub trait CryptoSpake2 { - fn new() -> Result - where - Self: Sized; - - fn set_w0_from_w0s(&mut self, w0s: &[u8]) -> Result<(), Error>; - fn set_w1_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error>; - fn set_w0(&mut self, w0: &[u8]) -> Result<(), Error>; - fn set_w1(&mut self, w1: &[u8]) -> Result<(), Error>; - - #[allow(non_snake_case)] - fn set_L(&mut self, l: &[u8]) -> Result<(), Error>; - #[allow(non_snake_case)] - fn set_L_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error>; - #[allow(non_snake_case)] - fn get_pB(&mut self, pB: &mut [u8]) -> Result<(), Error>; - #[allow(non_snake_case)] - fn get_TT_as_verifier( - &mut self, - context: &[u8], - pA: &[u8], - pB: &[u8], - out: &mut [u8], - ) -> Result<(), Error>; -} +#[cfg(not(any( + feature = "crypto_openssl", + feature = "crypto_mbedtls", + feature = "crypto_esp_mbedtls" +)))] +pub use super::crypto_dummy::CryptoSpake2; +#[cfg(feature = "crypto_esp_mbedtls")] +pub use super::crypto_esp_mbedtls::CryptoSpake2; +#[cfg(feature = "crypto_mbedtls")] +pub use super::crypto_mbedtls::CryptoSpake2; +#[cfg(feature = "crypto_openssl")] +pub use super::crypto_openssl::CryptoSpake2; diff --git a/matter/src/secure_channel/crypto_dummy.rs b/matter/src/secure_channel/crypto_dummy.rs new file mode 100644 index 0000000..11ec852 --- /dev/null +++ b/matter/src/secure_channel/crypto_dummy.rs @@ -0,0 +1,73 @@ +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +use crate::error::Error; + +#[allow(non_snake_case)] + +pub struct CryptoSpake2 {} + +impl CryptoSpake2 { + #[allow(non_snake_case)] + pub fn new() -> Result { + Ok(Self {}) + } + + // Computes w0 from w0s respectively + pub fn set_w0_from_w0s(&mut self, _w0s: &[u8]) -> Result<(), Error> { + Err(Error::Invalid) + } + + pub fn set_w1_from_w1s(&mut self, _w1s: &[u8]) -> Result<(), Error> { + Err(Error::Invalid) + } + + pub fn set_w0(&mut self, _w0: &[u8]) -> Result<(), Error> { + Err(Error::Invalid) + } + + pub fn set_w1(&mut self, _w1: &[u8]) -> Result<(), Error> { + Err(Error::Invalid) + } + + #[allow(non_snake_case)] + pub fn set_L(&mut self, _l: &[u8]) -> Result<(), Error> { + Err(Error::Invalid) + } + + #[allow(non_snake_case)] + #[allow(dead_code)] + pub fn set_L_from_w1s(&mut self, _w1s: &[u8]) -> Result<(), Error> { + Err(Error::Invalid) + } + + #[allow(non_snake_case)] + pub fn get_pB(&mut self, _pB: &mut [u8]) -> Result<(), Error> { + Err(Error::Invalid) + } + + #[allow(non_snake_case)] + pub fn get_TT_as_verifier( + &mut self, + _context: &[u8], + _pA: &[u8], + _pB: &[u8], + _out: &mut [u8], + ) -> Result<(), Error> { + Err(Error::Invalid) + } +} diff --git a/matter/src/secure_channel/crypto_esp_mbedtls.rs b/matter/src/secure_channel/crypto_esp_mbedtls.rs index 632be2c..316276b 100644 --- a/matter/src/secure_channel/crypto_esp_mbedtls.rs +++ b/matter/src/secure_channel/crypto_esp_mbedtls.rs @@ -17,8 +17,6 @@ use crate::error::Error; -use super::crypto::CryptoSpake2; - const MATTER_M_BIN: [u8; 65] = [ 0x04, 0x88, 0x6e, 0x2f, 0x97, 0xac, 0xe4, 0x6e, 0x55, 0xba, 0x9d, 0xd7, 0x24, 0x25, 0x79, 0xf2, 0x99, 0x3b, 0x64, 0xe1, 0x6e, 0xf3, 0xdc, 0xab, 0x95, 0xaf, 0xd4, 0x97, 0x33, 0x3d, 0x8f, 0xa1, @@ -36,16 +34,16 @@ const MATTER_N_BIN: [u8; 65] = [ #[allow(non_snake_case)] -pub struct CryptoEspMbedTls {} +pub struct CryptoSpake2 {} -impl CryptoSpake2 for CryptoEspMbedTls { +impl CryptoSpake2 { #[allow(non_snake_case)] - fn new() -> Result { - Ok(CryptoEspMbedTls {}) + pub fn new() -> Result { + Ok(Self {}) } // Computes w0 from w0s respectively - fn set_w0_from_w0s(&mut self, w0s: &[u8]) -> Result<(), Error> { + pub fn set_w0_from_w0s(&mut self, w0s: &[u8]) -> Result<(), Error> { // From the Matter Spec, // w0 = w0s mod p // where p is the order of the curve @@ -53,7 +51,7 @@ impl CryptoSpake2 for CryptoEspMbedTls { Ok(()) } - fn set_w1_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error> { + pub fn set_w1_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error> { // From the Matter Spec, // w1 = w1s mod p // where p is the order of the curve @@ -61,17 +59,17 @@ impl CryptoSpake2 for CryptoEspMbedTls { Ok(()) } - fn set_w0(&mut self, w0: &[u8]) -> Result<(), Error> { + pub fn set_w0(&mut self, w0: &[u8]) -> Result<(), Error> { Ok(()) } - fn set_w1(&mut self, w1: &[u8]) -> Result<(), Error> { + pub fn set_w1(&mut self, w1: &[u8]) -> Result<(), Error> { Ok(()) } #[allow(non_snake_case)] #[allow(dead_code)] - fn set_L(&mut self, w1s: &[u8]) -> Result<(), Error> { + pub fn set_L(&mut self, w1s: &[u8]) -> Result<(), Error> { // From the Matter spec, // L = w1 * P // where P is the generator of the underlying elliptic curve @@ -79,7 +77,7 @@ impl CryptoSpake2 for CryptoEspMbedTls { } #[allow(non_snake_case)] - fn get_pB(&mut self, pB: &mut [u8]) -> Result<(), Error> { + pub fn get_pB(&mut self, pB: &mut [u8]) -> Result<(), Error> { // From the SPAKE2+ spec (https://datatracker.ietf.org/doc/draft-bar-cfrg-spake2plus/) // for y // - select random y between 0 to p @@ -90,7 +88,7 @@ impl CryptoSpake2 for CryptoEspMbedTls { } #[allow(non_snake_case)] - fn get_TT_as_verifier( + pub fn get_TT_as_verifier( &mut self, context: &[u8], pA: &[u8], @@ -101,13 +99,10 @@ impl CryptoSpake2 for CryptoEspMbedTls { } } -impl CryptoEspMbedTls {} - #[cfg(test)] mod tests { - use super::CryptoEspMbedTls; - use crate::secure_channel::crypto::CryptoSpake2; + use super::CryptoSpake2; use crate::secure_channel::spake2p_test_vectors::test_vectors::*; use openssl::bn::BigNum; use openssl::ec::{EcPoint, PointConversionForm}; @@ -116,13 +111,12 @@ mod tests { #[allow(non_snake_case)] fn test_get_X() { for t in RFC_T { - let mut c = CryptoEspMbedTls::new().unwrap(); + let mut c = CryptoSpake2::new().unwrap(); let x = BigNum::from_slice(&t.x).unwrap(); c.set_w0(&t.w0).unwrap(); let P = c.group.generator(); - let r = - CryptoEspMbedTls::do_add_mul(P, &x, &c.M, &c.w0, &c.group, &mut c.bn_ctx).unwrap(); + let r = CryptoSpake2::do_add_mul(P, &x, &c.M, &c.w0, &c.group, &mut c.bn_ctx).unwrap(); assert_eq!( t.X, r.to_bytes(&c.group, PointConversionForm::UNCOMPRESSED, &mut c.bn_ctx) @@ -136,12 +130,11 @@ mod tests { #[allow(non_snake_case)] fn test_get_Y() { for t in RFC_T { - let mut c = CryptoEspMbedTls::new().unwrap(); + let mut c = CryptoSpake2::new().unwrap(); let y = BigNum::from_slice(&t.y).unwrap(); c.set_w0(&t.w0).unwrap(); let P = c.group.generator(); - let r = - CryptoEspMbedTls::do_add_mul(P, &y, &c.N, &c.w0, &c.group, &mut c.bn_ctx).unwrap(); + let r = CryptoSpake2::do_add_mul(P, &y, &c.N, &c.w0, &c.group, &mut c.bn_ctx).unwrap(); assert_eq!( t.Y, r.to_bytes(&c.group, PointConversionForm::UNCOMPRESSED, &mut c.bn_ctx) @@ -155,12 +148,12 @@ mod tests { #[allow(non_snake_case)] fn test_get_ZV_as_prover() { for t in RFC_T { - let mut c = CryptoEspMbedTls::new().unwrap(); + let mut c = CryptoSpake2::new().unwrap(); let x = BigNum::from_slice(&t.x).unwrap(); c.set_w0(&t.w0).unwrap(); c.set_w1(&t.w1).unwrap(); let Y = EcPoint::from_bytes(&c.group, &t.Y, &mut c.bn_ctx).unwrap(); - let (Z, V) = CryptoEspMbedTls::get_ZV_as_prover( + let (Z, V) = CryptoSpake2::get_ZV_as_prover( &c.w0, &c.w1, &mut c.N, @@ -191,12 +184,12 @@ mod tests { #[allow(non_snake_case)] fn test_get_ZV_as_verifier() { for t in RFC_T { - let mut c = CryptoEspMbedTls::new().unwrap(); + let mut c = CryptoSpake2::new().unwrap(); let y = BigNum::from_slice(&t.y).unwrap(); c.set_w0(&t.w0).unwrap(); let X = EcPoint::from_bytes(&c.group, &t.X, &mut c.bn_ctx).unwrap(); let L = EcPoint::from_bytes(&c.group, &t.L, &mut c.bn_ctx).unwrap(); - let (Z, V) = CryptoEspMbedTls::get_ZV_as_verifier( + let (Z, V) = CryptoSpake2::get_ZV_as_verifier( &c.w0, &L, &mut c.M, diff --git a/matter/src/secure_channel/crypto_mbedtls.rs b/matter/src/secure_channel/crypto_mbedtls.rs index 7ac4c5a..27c9fc6 100644 --- a/matter/src/secure_channel/crypto_mbedtls.rs +++ b/matter/src/secure_channel/crypto_mbedtls.rs @@ -15,14 +15,11 @@ * limitations under the License. */ -use std::{ - ops::{Mul, Sub}, - sync::Arc, -}; +use alloc::sync::Arc; +use core::ops::{Mul, Sub}; use crate::error::Error; -use super::crypto::CryptoSpake2; use byteorder::{ByteOrder, LittleEndian}; use log::error; use mbedtls::{ @@ -33,6 +30,8 @@ use mbedtls::{ rng::{CtrDrbg, OsEntropy}, }; +extern crate alloc; + const MATTER_M_BIN: [u8; 65] = [ 0x04, 0x88, 0x6e, 0x2f, 0x97, 0xac, 0xe4, 0x6e, 0x55, 0xba, 0x9d, 0xd7, 0x24, 0x25, 0x79, 0xf2, 0x99, 0x3b, 0x64, 0xe1, 0x6e, 0xf3, 0xdc, 0xab, 0x95, 0xaf, 0xd4, 0x97, 0x33, 0x3d, 0x8f, 0xa1, @@ -50,7 +49,7 @@ const MATTER_N_BIN: [u8; 65] = [ #[allow(non_snake_case)] -pub struct CryptoMbedTLS { +pub struct CryptoSpake2 { group: EcGroup, order: Mpi, xy: Mpi, @@ -62,15 +61,15 @@ pub struct CryptoMbedTLS { pB: EcPoint, } -impl CryptoSpake2 for CryptoMbedTLS { +impl CryptoSpake2 { #[allow(non_snake_case)] - fn new() -> Result { + pub fn new() -> Result { let group = EcGroup::new(mbedtls::pk::EcGroupId::SecP256R1)?; let order = group.order()?; let M = EcPoint::from_binary(&group, &MATTER_M_BIN)?; let N = EcPoint::from_binary(&group, &MATTER_N_BIN)?; - Ok(CryptoMbedTLS { + Ok(Self { group, order, xy: Mpi::new(0)?, @@ -84,7 +83,7 @@ impl CryptoSpake2 for CryptoMbedTLS { } // Computes w0 from w0s respectively - fn set_w0_from_w0s(&mut self, w0s: &[u8]) -> Result<(), Error> { + pub fn set_w0_from_w0s(&mut self, w0s: &[u8]) -> Result<(), Error> { // From the Matter Spec, // w0 = w0s mod p // where p is the order of the curve @@ -94,7 +93,7 @@ impl CryptoSpake2 for CryptoMbedTLS { Ok(()) } - fn set_w1_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error> { + pub fn set_w1_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error> { // From the Matter Spec, // w1 = w1s mod p // where p is the order of the curve @@ -104,24 +103,25 @@ impl CryptoSpake2 for CryptoMbedTLS { Ok(()) } - fn set_w0(&mut self, w0: &[u8]) -> Result<(), Error> { + pub fn set_w0(&mut self, w0: &[u8]) -> Result<(), Error> { self.w0 = Mpi::from_binary(w0)?; Ok(()) } - fn set_w1(&mut self, w1: &[u8]) -> Result<(), Error> { + pub fn set_w1(&mut self, w1: &[u8]) -> Result<(), Error> { self.w1 = Mpi::from_binary(w1)?; Ok(()) } - fn set_L(&mut self, l: &[u8]) -> Result<(), Error> { + #[allow(non_snake_case)] + pub fn set_L(&mut self, l: &[u8]) -> Result<(), Error> { self.L = EcPoint::from_binary(&self.group, l)?; Ok(()) } #[allow(non_snake_case)] #[allow(dead_code)] - fn set_L_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error> { + pub fn set_L_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error> { // From the Matter spec, // L = w1 * P // where P is the generator of the underlying elliptic curve @@ -132,7 +132,7 @@ impl CryptoSpake2 for CryptoMbedTLS { } #[allow(non_snake_case)] - fn get_pB(&mut self, pB: &mut [u8]) -> Result<(), Error> { + pub fn get_pB(&mut self, pB: &mut [u8]) -> Result<(), Error> { // From the SPAKE2+ spec (https://datatracker.ietf.org/doc/draft-bar-cfrg-spake2plus/) // for y // - select random y between 0 to p @@ -157,7 +157,7 @@ impl CryptoSpake2 for CryptoMbedTLS { } #[allow(non_snake_case)] - fn get_TT_as_verifier( + pub fn get_TT_as_verifier( &mut self, context: &[u8], pA: &[u8], @@ -166,21 +166,21 @@ impl CryptoSpake2 for CryptoMbedTLS { ) -> Result<(), Error> { let mut TT = Md::new(mbedtls::hash::Type::Sha256)?; // context - CryptoMbedTLS::add_to_tt(&mut TT, context)?; + Self::add_to_tt(&mut TT, context)?; // 2 empty identifiers - CryptoMbedTLS::add_to_tt(&mut TT, &[])?; - CryptoMbedTLS::add_to_tt(&mut TT, &[])?; + Self::add_to_tt(&mut TT, &[])?; + Self::add_to_tt(&mut TT, &[])?; // M - CryptoMbedTLS::add_to_tt(&mut TT, &MATTER_M_BIN)?; + Self::add_to_tt(&mut TT, &MATTER_M_BIN)?; // N - CryptoMbedTLS::add_to_tt(&mut TT, &MATTER_N_BIN)?; + Self::add_to_tt(&mut TT, &MATTER_N_BIN)?; // X = pA - CryptoMbedTLS::add_to_tt(&mut TT, pA)?; + Self::add_to_tt(&mut TT, pA)?; // Y = pB - CryptoMbedTLS::add_to_tt(&mut TT, pB)?; + Self::add_to_tt(&mut TT, pB)?; let X = EcPoint::from_binary(&self.group, pA)?; - let (Z, V) = CryptoMbedTLS::get_ZV_as_verifier( + let (Z, V) = Self::get_ZV_as_verifier( &self.w0, &self.L, &mut self.M, @@ -193,24 +193,22 @@ impl CryptoSpake2 for CryptoMbedTLS { // Z let tmp = Z.to_binary(&self.group, false)?; let tmp = tmp.as_slice(); - CryptoMbedTLS::add_to_tt(&mut TT, tmp)?; + Self::add_to_tt(&mut TT, tmp)?; // V let tmp = V.to_binary(&self.group, false)?; let tmp = tmp.as_slice(); - CryptoMbedTLS::add_to_tt(&mut TT, tmp)?; + Self::add_to_tt(&mut TT, tmp)?; // w0 let tmp = self.w0.to_binary()?; let tmp = tmp.as_slice(); - CryptoMbedTLS::add_to_tt(&mut TT, tmp)?; + Self::add_to_tt(&mut TT, tmp)?; TT.finish(out)?; Ok(()) } -} -impl CryptoMbedTLS { fn add_to_tt(tt: &mut Md, buf: &[u8]) -> Result<(), Error> { let mut len_buf: [u8; 8] = [0; 8]; LittleEndian::write_u64(&mut len_buf, buf.len() as u64); @@ -247,7 +245,7 @@ impl CryptoMbedTLS { let mut tmp = x.mul(w0)?; tmp = tmp.modulo(order)?; - let inverted_N = CryptoMbedTLS::invert(group, N)?; + let inverted_N = Self::invert(group, N)?; let Z = EcPoint::muladd(group, Y, x, &inverted_N, &tmp)?; // Cofactor for P256 is 1, so that is a No-Op @@ -283,7 +281,7 @@ impl CryptoMbedTLS { let mut tmp = y.mul(w0)?; tmp = tmp.modulo(order)?; - let inverted_M = CryptoMbedTLS::invert(group, M)?; + let inverted_M = Self::invert(group, M)?; let Z = EcPoint::muladd(group, X, y, &inverted_M, &tmp)?; // Cofactor for P256 is 1, so that is a No-Op @@ -302,8 +300,7 @@ impl CryptoMbedTLS { #[cfg(test)] mod tests { - use super::CryptoMbedTLS; - use crate::secure_channel::crypto::CryptoSpake2; + use super::CryptoSpake2; use crate::secure_channel::spake2p_test_vectors::test_vectors::*; use mbedtls::bignum::Mpi; use mbedtls::ecp::EcPoint; @@ -312,7 +309,7 @@ mod tests { #[allow(non_snake_case)] fn test_get_X() { for t in RFC_T { - let mut c = CryptoMbedTLS::new().unwrap(); + let mut c = CryptoSpake2::new().unwrap(); let x = Mpi::from_binary(&t.x).unwrap(); c.set_w0(&t.w0).unwrap(); let P = c.group.generator().unwrap(); @@ -326,7 +323,7 @@ mod tests { #[allow(non_snake_case)] fn test_get_Y() { for t in RFC_T { - let mut c = CryptoMbedTLS::new().unwrap(); + let mut c = CryptoSpake2::new().unwrap(); let y = Mpi::from_binary(&t.y).unwrap(); c.set_w0(&t.w0).unwrap(); let P = c.group.generator().unwrap(); @@ -339,12 +336,12 @@ mod tests { #[allow(non_snake_case)] fn test_get_ZV_as_prover() { for t in RFC_T { - let mut c = CryptoMbedTLS::new().unwrap(); + let mut c = CryptoSpake2::new().unwrap(); let x = Mpi::from_binary(&t.x).unwrap(); c.set_w0(&t.w0).unwrap(); c.set_w1(&t.w1).unwrap(); let Y = EcPoint::from_binary(&c.group, &t.Y).unwrap(); - let (Z, V) = CryptoMbedTLS::get_ZV_as_prover( + let (Z, V) = CryptoSpake2::get_ZV_as_prover( &c.w0, &c.w1, &mut c.N, @@ -364,12 +361,12 @@ mod tests { #[allow(non_snake_case)] fn test_get_ZV_as_verifier() { for t in RFC_T { - let mut c = CryptoMbedTLS::new().unwrap(); + let mut c = CryptoSpake2::new().unwrap(); let y = Mpi::from_binary(&t.y).unwrap(); c.set_w0(&t.w0).unwrap(); let X = EcPoint::from_binary(&c.group, &t.X).unwrap(); let L = EcPoint::from_binary(&c.group, &t.L).unwrap(); - let (Z, V) = CryptoMbedTLS::get_ZV_as_verifier( + let (Z, V) = CryptoSpake2::get_ZV_as_verifier( &c.w0, &L, &mut c.M, diff --git a/matter/src/secure_channel/crypto_openssl.rs b/matter/src/secure_channel/crypto_openssl.rs index 84d6793..631cb6b 100644 --- a/matter/src/secure_channel/crypto_openssl.rs +++ b/matter/src/secure_channel/crypto_openssl.rs @@ -17,7 +17,6 @@ use crate::error::Error; -use super::crypto::CryptoSpake2; use byteorder::{ByteOrder, LittleEndian}; use log::error; use openssl::{ @@ -44,7 +43,7 @@ const MATTER_N_BIN: [u8; 65] = [ #[allow(non_snake_case)] -pub struct CryptoOpenSSL { +pub struct CryptoSpake2 { group: EcGroup, bn_ctx: BigNumContext, // Stores the randomly generated x or y depending upon who we are @@ -58,9 +57,9 @@ pub struct CryptoOpenSSL { order: BigNum, } -impl CryptoSpake2 for CryptoOpenSSL { +impl CryptoSpake2 { #[allow(non_snake_case)] - fn new() -> Result { + pub fn new() -> Result { let group = EcGroup::from_curve_name(Nid::X9_62_PRIME256V1)?; let mut bn_ctx = BigNumContext::new()?; let M = EcPoint::from_bytes(&group, &MATTER_M_BIN, &mut bn_ctx)?; @@ -70,7 +69,7 @@ impl CryptoSpake2 for CryptoOpenSSL { let mut order = BigNum::new()?; group.as_ref().order(&mut order, &mut bn_ctx)?; - Ok(CryptoOpenSSL { + Ok(Self { group, bn_ctx, xy: BigNum::new()?, @@ -85,7 +84,7 @@ impl CryptoSpake2 for CryptoOpenSSL { } // Computes w0 from w0s respectively - fn set_w0_from_w0s(&mut self, w0s: &[u8]) -> Result<(), Error> { + pub fn set_w0_from_w0s(&mut self, w0s: &[u8]) -> Result<(), Error> { // From the Matter Spec, // w0 = w0s mod p // where p is the order of the curve @@ -96,7 +95,7 @@ impl CryptoSpake2 for CryptoOpenSSL { Ok(()) } - fn set_w1_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error> { + pub fn set_w1_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error> { // From the Matter Spec, // w1 = w1s mod p // where p is the order of the curve @@ -107,24 +106,24 @@ impl CryptoSpake2 for CryptoOpenSSL { Ok(()) } - fn set_w0(&mut self, w0: &[u8]) -> Result<(), Error> { + pub fn set_w0(&mut self, w0: &[u8]) -> Result<(), Error> { self.w0 = BigNum::from_slice(w0)?; Ok(()) } - fn set_w1(&mut self, w1: &[u8]) -> Result<(), Error> { + pub fn set_w1(&mut self, w1: &[u8]) -> Result<(), Error> { self.w1 = BigNum::from_slice(w1)?; Ok(()) } - fn set_L(&mut self, l: &[u8]) -> Result<(), Error> { + pub fn set_L(&mut self, l: &[u8]) -> Result<(), Error> { self.L = EcPoint::from_bytes(&self.group, l, &mut self.bn_ctx)?; Ok(()) } #[allow(non_snake_case)] #[allow(dead_code)] - fn set_L_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error> { + pub fn set_L_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error> { // From the Matter spec, // L = w1 * P // where P is the generator of the underlying elliptic curve @@ -135,7 +134,7 @@ impl CryptoSpake2 for CryptoOpenSSL { } #[allow(non_snake_case)] - fn get_pB(&mut self, pB: &mut [u8]) -> Result<(), Error> { + pub fn get_pB(&mut self, pB: &mut [u8]) -> Result<(), Error> { // From the SPAKE2+ spec (https://datatracker.ietf.org/doc/draft-bar-cfrg-spake2plus/) // for y // - select random y between 0 to p @@ -143,7 +142,7 @@ impl CryptoSpake2 for CryptoOpenSSL { // - pB = Y self.order.rand_range(&mut self.xy)?; let P = self.group.generator(); - self.pB = CryptoOpenSSL::do_add_mul( + self.pB = Self::do_add_mul( P, &self.xy, &self.N, @@ -166,7 +165,7 @@ impl CryptoSpake2 for CryptoOpenSSL { } #[allow(non_snake_case)] - fn get_TT_as_verifier( + pub fn get_TT_as_verifier( &mut self, context: &[u8], pA: &[u8], @@ -175,21 +174,21 @@ impl CryptoSpake2 for CryptoOpenSSL { ) -> Result<(), Error> { let mut TT = Hasher::new(MessageDigest::sha256())?; // context - CryptoOpenSSL::add_to_tt(&mut TT, context)?; + Self::add_to_tt(&mut TT, context)?; // 2 empty identifiers - CryptoOpenSSL::add_to_tt(&mut TT, &[])?; - CryptoOpenSSL::add_to_tt(&mut TT, &[])?; + Self::add_to_tt(&mut TT, &[])?; + Self::add_to_tt(&mut TT, &[])?; // M - CryptoOpenSSL::add_to_tt(&mut TT, &MATTER_M_BIN)?; + Self::add_to_tt(&mut TT, &MATTER_M_BIN)?; // N - CryptoOpenSSL::add_to_tt(&mut TT, &MATTER_N_BIN)?; + Self::add_to_tt(&mut TT, &MATTER_N_BIN)?; // X = pA - CryptoOpenSSL::add_to_tt(&mut TT, pA)?; + Self::add_to_tt(&mut TT, pA)?; // Y = pB - CryptoOpenSSL::add_to_tt(&mut TT, pB)?; + Self::add_to_tt(&mut TT, pB)?; let X = EcPoint::from_bytes(&self.group, pA, &mut self.bn_ctx)?; - let (Z, V) = CryptoOpenSSL::get_ZV_as_verifier( + let (Z, V) = Self::get_ZV_as_verifier( &self.w0, &self.L, &mut self.M, @@ -207,7 +206,7 @@ impl CryptoSpake2 for CryptoOpenSSL { &mut self.bn_ctx, )?; let tmp = tmp.as_slice(); - CryptoOpenSSL::add_to_tt(&mut TT, tmp)?; + Self::add_to_tt(&mut TT, tmp)?; // V let tmp = V.to_bytes( @@ -216,20 +215,18 @@ impl CryptoSpake2 for CryptoOpenSSL { &mut self.bn_ctx, )?; let tmp = tmp.as_slice(); - CryptoOpenSSL::add_to_tt(&mut TT, tmp)?; + Self::add_to_tt(&mut TT, tmp)?; // w0 let tmp = self.w0.to_vec(); let tmp = tmp.as_slice(); - CryptoOpenSSL::add_to_tt(&mut TT, tmp)?; + Self::add_to_tt(&mut TT, tmp)?; let h = TT.finish()?; TT_hash.copy_from_slice(h.as_ref()); Ok(()) } -} -impl CryptoOpenSSL { fn add_to_tt(tt: &mut Hasher, buf: &[u8]) -> Result<(), Error> { let mut len_buf: [u8; 8] = [0; 8]; LittleEndian::write_u64(&mut len_buf, buf.len() as u64); @@ -286,11 +283,11 @@ impl CryptoOpenSSL { let mut tmp = BigNum::new()?; tmp.mod_mul(x, w0, order, bn_ctx)?; N.invert(group, bn_ctx)?; - let Z = CryptoOpenSSL::do_add_mul(Y, x, N, &tmp, group, bn_ctx)?; + let Z = Self::do_add_mul(Y, x, N, &tmp, group, bn_ctx)?; // Cofactor for P256 is 1, so that is a No-Op tmp.mod_mul(w1, w0, order, bn_ctx)?; - let V = CryptoOpenSSL::do_add_mul(Y, w1, N, &tmp, group, bn_ctx)?; + let V = Self::do_add_mul(Y, w1, N, &tmp, group, bn_ctx)?; Ok((Z, V)) } @@ -321,7 +318,7 @@ impl CryptoOpenSSL { let mut tmp = BigNum::new()?; tmp.mod_mul(y, w0, order, bn_ctx)?; M.invert(group, bn_ctx)?; - let Z = CryptoOpenSSL::do_add_mul(X, y, M, &tmp, group, bn_ctx)?; + let Z = Self::do_add_mul(X, y, M, &tmp, group, bn_ctx)?; // Cofactor for P256 is 1, so that is a No-Op let mut V = EcPoint::new(group)?; @@ -333,7 +330,7 @@ impl CryptoOpenSSL { #[cfg(test)] mod tests { - use super::CryptoOpenSSL; + use super::CryptoSpake2; use crate::secure_channel::crypto::CryptoSpake2; use crate::secure_channel::spake2p_test_vectors::test_vectors::*; use openssl::bn::BigNum; @@ -343,12 +340,12 @@ mod tests { #[allow(non_snake_case)] fn test_get_X() { for t in RFC_T { - let mut c = CryptoOpenSSL::new().unwrap(); + let mut c = CryptoSpake2::new().unwrap(); let x = BigNum::from_slice(&t.x).unwrap(); c.set_w0(&t.w0).unwrap(); let P = c.group.generator(); - let r = CryptoOpenSSL::do_add_mul(P, &x, &c.M, &c.w0, &c.group, &mut c.bn_ctx).unwrap(); + let r = CryptoSpake2::do_add_mul(P, &x, &c.M, &c.w0, &c.group, &mut c.bn_ctx).unwrap(); assert_eq!( t.X, r.to_bytes(&c.group, PointConversionForm::UNCOMPRESSED, &mut c.bn_ctx) @@ -362,11 +359,11 @@ mod tests { #[allow(non_snake_case)] fn test_get_Y() { for t in RFC_T { - let mut c = CryptoOpenSSL::new().unwrap(); + let mut c = CryptoSpake2::new().unwrap(); let y = BigNum::from_slice(&t.y).unwrap(); c.set_w0(&t.w0).unwrap(); let P = c.group.generator(); - let r = CryptoOpenSSL::do_add_mul(P, &y, &c.N, &c.w0, &c.group, &mut c.bn_ctx).unwrap(); + let r = CryptoSpake2::do_add_mul(P, &y, &c.N, &c.w0, &c.group, &mut c.bn_ctx).unwrap(); assert_eq!( t.Y, r.to_bytes(&c.group, PointConversionForm::UNCOMPRESSED, &mut c.bn_ctx) @@ -380,12 +377,12 @@ mod tests { #[allow(non_snake_case)] fn test_get_ZV_as_prover() { for t in RFC_T { - let mut c = CryptoOpenSSL::new().unwrap(); + let mut c = CryptoSpake2::new().unwrap(); let x = BigNum::from_slice(&t.x).unwrap(); c.set_w0(&t.w0).unwrap(); c.set_w1(&t.w1).unwrap(); let Y = EcPoint::from_bytes(&c.group, &t.Y, &mut c.bn_ctx).unwrap(); - let (Z, V) = CryptoOpenSSL::get_ZV_as_prover( + let (Z, V) = CryptoSpake2::get_ZV_as_prover( &c.w0, &c.w1, &mut c.N, @@ -416,12 +413,12 @@ mod tests { #[allow(non_snake_case)] fn test_get_ZV_as_verifier() { for t in RFC_T { - let mut c = CryptoOpenSSL::new().unwrap(); + let mut c = CryptoSpake2::new().unwrap(); let y = BigNum::from_slice(&t.y).unwrap(); c.set_w0(&t.w0).unwrap(); let X = EcPoint::from_bytes(&c.group, &t.X, &mut c.bn_ctx).unwrap(); let L = EcPoint::from_bytes(&c.group, &t.L, &mut c.bn_ctx).unwrap(); - let (Z, V) = CryptoOpenSSL::get_ZV_as_verifier( + let (Z, V) = CryptoSpake2::get_ZV_as_verifier( &c.w0, &L, &mut c.M, diff --git a/matter/src/secure_channel/mod.rs b/matter/src/secure_channel/mod.rs index 9328b25..15417b3 100644 --- a/matter/src/secure_channel/mod.rs +++ b/matter/src/secure_channel/mod.rs @@ -17,10 +17,17 @@ pub mod case; pub mod common; +#[cfg(not(any( + feature = "crypto_openssl", + feature = "crypto_mbedtls", + feature = "crypto_esp_mbedtls", + feature = "crypto_rustcrypto" +)))] +mod crypto_dummy; #[cfg(feature = "crypto_esp_mbedtls")] -pub mod crypto_esp_mbedtls; +mod crypto_esp_mbedtls; #[cfg(feature = "crypto_mbedtls")] -pub mod crypto_mbedtls; +mod crypto_mbedtls; #[cfg(feature = "crypto_openssl")] pub mod crypto_openssl; #[cfg(feature = "crypto_rustcrypto")] diff --git a/matter/src/secure_channel/pake.rs b/matter/src/secure_channel/pake.rs index 27aaeb0..ce05fb6 100644 --- a/matter/src/secure_channel/pake.rs +++ b/matter/src/secure_channel/pake.rs @@ -15,10 +15,7 @@ * limitations under the License. */ -use std::{ - sync::{Arc, Mutex}, - time::{Duration, SystemTime}, -}; +use core::{fmt::Write, time::Duration}; use super::{ common::{create_sc_status_report, SCStatusCodes}, @@ -27,97 +24,115 @@ use super::{ use crate::{ crypto, error::Error, - mdns::{self, Mdns}, + mdns::{MdnsMgr, ServiceMode}, secure_channel::common::OpCode, - sys::SysMdnsService, tlv::{self, get_root_node_struct, FromTLV, OctetStr, TLVElement, TLVWriter, TagType, ToTLV}, transport::{ exchange::ExchangeCtx, network::Address, - proto_demux::{ProtoCtx, ResponseRequired}, + proto_ctx::ProtoCtx, queue::{Msg, WorkQ}, session::{CloneData, SessionMode}, }, + utils::{epoch::Epoch, rand::Rand}, }; use log::{error, info}; -use rand::prelude::*; +#[allow(clippy::large_enum_variant)] enum PaseMgrState { - Enabled(PAKE, SysMdnsService), + Enabled(Pake, heapless::String<16>, u16), Disabled, } -pub struct PaseMgrInternal { +// Could this lock be avoided? +pub struct PaseMgr { state: PaseMgrState, + epoch: Epoch, + rand: Rand, } -#[derive(Clone)] -// Could this lock be avoided? -pub struct PaseMgr(Arc>); - impl PaseMgr { - pub fn new() -> Self { - Self(Arc::new(Mutex::new(PaseMgrInternal { + pub fn new(epoch: Epoch, rand: Rand) -> Self { + Self { state: PaseMgrState::Disabled, - }))) + epoch, + rand, + } } pub fn enable_pase_session( &mut self, verifier: VerifierData, discriminator: u16, + mdns: &mut MdnsMgr, ) -> Result<(), Error> { - let mut s = self.0.lock().unwrap(); - let name: u64 = rand::thread_rng().gen_range(0..0xFFFFFFFFFFFFFFFF); - let name = format!("{:016X}", name); - let mdns = Mdns::get()? - .publish_service(&name, mdns::ServiceMode::Commissionable(discriminator))?; - s.state = PaseMgrState::Enabled(PAKE::new(verifier), mdns); + let mut buf = [0; 8]; + (self.rand)(&mut buf); + let num = u64::from_be_bytes(buf); + + let mut mdns_service_name = heapless::String::<16>::new(); + write!(&mut mdns_service_name, "{:016X}", num).unwrap(); + + mdns.publish_service( + &mdns_service_name, + ServiceMode::Commissionable(discriminator), + )?; + self.state = PaseMgrState::Enabled( + Pake::new(verifier, self.epoch, self.rand), + mdns_service_name, + discriminator, + ); + Ok(()) } - pub fn disable_pase_session(&mut self) { - let mut s = self.0.lock().unwrap(); - s.state = PaseMgrState::Disabled; + pub fn disable_pase_session(&mut self, mdns: &mut MdnsMgr) -> Result<(), Error> { + if let PaseMgrState::Enabled(_, mdns_service_name, discriminator) = &self.state { + mdns.unpublish_service( + mdns_service_name, + ServiceMode::Commissionable(*discriminator), + )?; + } + + self.state = PaseMgrState::Disabled; + + Ok(()) } /// If the PASE Session is enabled, execute the closure, /// if not enabled, generate SC Status Report fn if_enabled(&mut self, ctx: &mut ProtoCtx, f: F) -> Result<(), Error> where - F: FnOnce(&mut PAKE, &mut ProtoCtx) -> Result<(), Error>, + F: FnOnce(&mut Pake, &mut ProtoCtx) -> Result<(), Error>, { - let mut s = self.0.lock().unwrap(); - if let PaseMgrState::Enabled(pake, _) = &mut s.state { + if let PaseMgrState::Enabled(pake, _, _) = &mut self.state { f(pake, ctx) } else { error!("PASE Not enabled"); - create_sc_status_report(&mut ctx.tx, SCStatusCodes::InvalidParameter, None) + create_sc_status_report(ctx.tx, SCStatusCodes::InvalidParameter, None) } } - pub fn pbkdfparamreq_handler(&mut self, ctx: &mut ProtoCtx) -> Result { + pub fn pbkdfparamreq_handler(&mut self, ctx: &mut ProtoCtx) -> Result { ctx.tx.set_proto_opcode(OpCode::PBKDFParamResponse as u8); self.if_enabled(ctx, |pake, ctx| pake.handle_pbkdfparamrequest(ctx))?; - Ok(ResponseRequired::Yes) + Ok(true) } - pub fn pasepake1_handler(&mut self, ctx: &mut ProtoCtx) -> Result { + pub fn pasepake1_handler(&mut self, ctx: &mut ProtoCtx) -> Result { ctx.tx.set_proto_opcode(OpCode::PASEPake2 as u8); self.if_enabled(ctx, |pake, ctx| pake.handle_pasepake1(ctx))?; - Ok(ResponseRequired::Yes) + Ok(true) } - pub fn pasepake3_handler(&mut self, ctx: &mut ProtoCtx) -> Result { + pub fn pasepake3_handler( + &mut self, + ctx: &mut ProtoCtx, + mdns: &mut MdnsMgr, + ) -> Result { self.if_enabled(ctx, |pake, ctx| pake.handle_pasepake3(ctx))?; - self.disable_pase_session(); - Ok(ResponseRequired::Yes) - } -} - -impl Default for PaseMgr { - fn default() -> Self { - Self::new() + self.disable_pase_session(mdns)?; + Ok(true) } } @@ -131,30 +146,31 @@ const PASE_DISCARD_TIMEOUT_SECS: Duration = Duration::from_secs(60); const SPAKE2_SESSION_KEYS_INFO: [u8; 11] = *b"SessionKeys"; struct SessionData { - start_time: SystemTime, + start_time: Duration, exch_id: u16, peer_addr: Address, - spake2p: Box, + spake2p: Spake2P, } impl SessionData { - fn is_sess_expired(&self) -> Result { - if SystemTime::now().duration_since(self.start_time)? > PASE_DISCARD_TIMEOUT_SECS { - Ok(true) - } else { - Ok(false) - } + fn is_sess_expired(&self, epoch: Epoch) -> Result { + Ok(epoch() - self.start_time > PASE_DISCARD_TIMEOUT_SECS) } } +#[allow(clippy::large_enum_variant)] enum PakeState { Idle, InProgress(SessionData), } impl PakeState { + const fn new() -> Self { + Self::Idle + } + fn take(&mut self) -> Result { - let new = std::mem::replace(self, PakeState::Idle); + let new = core::mem::replace(self, PakeState::Idle); if let PakeState::InProgress(s) = new { Ok(s) } else { @@ -163,7 +179,7 @@ impl PakeState { } fn is_idle(&self) -> bool { - std::mem::discriminant(self) == std::mem::discriminant(&PakeState::Idle) + core::mem::discriminant(self) == core::mem::discriminant(&PakeState::Idle) } fn take_sess_data(&mut self, exch_ctx: &ExchangeCtx) -> Result { @@ -175,9 +191,9 @@ impl PakeState { } } - fn make_in_progress(&mut self, spake2p: Box, exch_ctx: &ExchangeCtx) { + fn make_in_progress(&mut self, epoch: Epoch, spake2p: Spake2P, exch_ctx: &ExchangeCtx) { *self = PakeState::InProgress(SessionData { - start_time: SystemTime::now(), + start_time: epoch(), spake2p, exch_id: exch_ctx.exch.get_id(), peer_addr: exch_ctx.sess.get_peer_addr(), @@ -191,21 +207,25 @@ impl PakeState { impl Default for PakeState { fn default() -> Self { - Self::Idle + Self::new() } } -pub struct PAKE { - pub verifier: VerifierData, +struct Pake { + verifier: VerifierData, state: PakeState, + epoch: Epoch, + rand: Rand, } -impl PAKE { - pub fn new(verifier: VerifierData) -> Self { +impl Pake { + pub fn new(verifier: VerifierData, epoch: Epoch, rand: Rand) -> Self { // TODO: Can any PBKDF2 calculation be pre-computed here - PAKE { + Self { verifier, - state: Default::default(), + state: PakeState::new(), + epoch, + rand, } } @@ -213,14 +233,14 @@ impl PAKE { pub fn handle_pasepake3(&mut self, ctx: &mut ProtoCtx) -> Result<(), Error> { let mut sd = self.state.take_sess_data(&ctx.exch_ctx)?; - let cA = extract_pasepake_1_or_3_params(ctx.rx.as_borrow_slice())?; - let (status_code, Ke) = sd.spake2p.handle_cA(cA); + let cA = extract_pasepake_1_or_3_params(ctx.rx.as_slice())?; + let (status_code, ke) = sd.spake2p.handle_cA(cA); if status_code == SCStatusCodes::SessionEstablishmentSuccess { // Get the keys - let Ke = Ke.ok_or(Error::Invalid)?; + let ke = ke.ok_or(Error::Invalid)?; let mut session_keys: [u8; 48] = [0; 48]; - crypto::hkdf_sha256(&[], Ke, &SPAKE2_SESSION_KEYS_INFO, &mut session_keys) + crypto::hkdf_sha256(&[], ke, &SPAKE2_SESSION_KEYS_INFO, &mut session_keys) .map_err(|_x| Error::NoSpace)?; // Create a session @@ -245,7 +265,7 @@ impl PAKE { WorkQ::get()?.sync_send(Msg::NewSession(clone_data))?; } - create_sc_status_report(&mut ctx.tx, status_code, None)?; + create_sc_status_report(ctx.tx, status_code, None)?; ctx.exch_ctx.exch.close(); Ok(()) } @@ -254,7 +274,7 @@ impl PAKE { pub fn handle_pasepake1(&mut self, ctx: &mut ProtoCtx) -> Result<(), Error> { let mut sd = self.state.take_sess_data(&ctx.exch_ctx)?; - let pA = extract_pasepake_1_or_3_params(ctx.rx.as_borrow_slice())?; + let pA = extract_pasepake_1_or_3_params(ctx.rx.as_slice())?; let mut pB: [u8; 65] = [0; 65]; let mut cB: [u8; 32] = [0; 32]; sd.spake2p.start_verifier(&self.verifier)?; @@ -275,18 +295,18 @@ impl PAKE { pub fn handle_pbkdfparamrequest(&mut self, ctx: &mut ProtoCtx) -> Result<(), Error> { if !self.state.is_idle() { let sd = self.state.take()?; - if sd.is_sess_expired()? { + if sd.is_sess_expired(self.epoch)? { info!("Previous session expired, clearing it"); self.state = PakeState::Idle; } else { info!("Previous session in-progress, denying new request"); // little-endian timeout (here we've hardcoded 500ms) - create_sc_status_report(&mut ctx.tx, SCStatusCodes::Busy, Some(&[0xf4, 0x01]))?; + create_sc_status_report(ctx.tx, SCStatusCodes::Busy, Some(&[0xf4, 0x01]))?; return Ok(()); } } - let root = tlv::get_root_node(ctx.rx.as_borrow_slice())?; + let root = tlv::get_root_node(ctx.rx.as_slice())?; let a = PBKDFParamReq::from_tlv(&root)?; if a.passcode_id != 0 { error!("Can't yet handle passcode_id != 0"); @@ -294,11 +314,11 @@ impl PAKE { } let mut our_random: [u8; 32] = [0; 32]; - rand::thread_rng().fill_bytes(&mut our_random); + (self.rand)(&mut our_random); let local_sessid = ctx.exch_ctx.sess.reserve_new_sess_id(); let spake2p_data: u32 = ((local_sessid as u32) << 16) | a.initiator_ssid as u32; - let mut spake2p = Box::new(Spake2P::new()); + let mut spake2p = Spake2P::new(); spake2p.set_app_data(spake2p_data); // Generate response @@ -318,8 +338,9 @@ impl PAKE { } resp.to_tlv(&mut tw, TagType::Anonymous)?; - spake2p.set_context(ctx.rx.as_borrow_slice(), ctx.tx.as_borrow_slice())?; - self.state.make_in_progress(spake2p, &ctx.exch_ctx); + spake2p.set_context(ctx.rx.as_slice(), ctx.tx.as_mut_slice())?; + self.state + .make_in_progress(self.epoch, spake2p, &ctx.exch_ctx); Ok(()) } diff --git a/matter/src/secure_channel/spake2p.rs b/matter/src/secure_channel/spake2p.rs index 335d465..ba948f5 100644 --- a/matter/src/secure_channel/spake2p.rs +++ b/matter/src/secure_channel/spake2p.rs @@ -18,10 +18,10 @@ use crate::{ crypto::{self, HmacSha256}, sys, + utils::rand::Rand, }; use byteorder::{ByteOrder, LittleEndian}; use log::error; -use rand::prelude::*; use subtle::ConstantTimeEq; use crate::{ @@ -29,18 +29,6 @@ use crate::{ error::Error, }; -#[cfg(feature = "crypto_openssl")] -use super::crypto_openssl::CryptoOpenSSL; - -#[cfg(feature = "crypto_mbedtls")] -use super::crypto_mbedtls::CryptoMbedTLS; - -#[cfg(feature = "crypto_esp_mbedtls")] -use super::crypto_esp_mbedtls::CryptoEspMbedTls; - -#[cfg(feature = "crypto_rustcrypto")] -use super::crypto_rustcrypto::CryptoRustCrypto; - use super::{common::SCStatusCodes, crypto::CryptoSpake2}; // This file handle Spake2+ specific instructions. In itself, this file is @@ -74,7 +62,7 @@ pub struct Spake2P { context: Option, Ke: [u8; 16], cA: [u8; 32], - crypto_spake2: Option>, + crypto_spake2: Option, app_data: u32, } @@ -87,24 +75,8 @@ const CRYPTO_PUBLIC_KEY_SIZE_BYTES: usize = (2 * CRYPTO_GROUP_SIZE_BYTES) + 1; const MAX_SALT_SIZE_BYTES: usize = 32; const VERIFIER_SIZE_BYTES: usize = CRYPTO_GROUP_SIZE_BYTES + CRYPTO_PUBLIC_KEY_SIZE_BYTES; -#[cfg(feature = "crypto_openssl")] -fn crypto_spake2_new() -> Result, Error> { - Ok(Box::new(CryptoOpenSSL::new()?)) -} - -#[cfg(feature = "crypto_mbedtls")] -fn crypto_spake2_new() -> Result, Error> { - Ok(Box::new(CryptoMbedTLS::new()?)) -} - -#[cfg(feature = "crypto_esp_mbedtls")] -fn crypto_spake2_new() -> Result, Error> { - Ok(Box::new(CryptoEspMbedTls::new()?)) -} - -#[cfg(feature = "crypto_rustcrypto")] -fn crypto_spake2_new() -> Result, Error> { - Ok(Box::new(CryptoRustCrypto::new()?)) +fn crypto_spake2_new() -> Result { + CryptoSpake2::new() } impl Default for Spake2P { @@ -129,13 +101,13 @@ pub enum VerifierOption { } impl VerifierData { - pub fn new_with_pw(pw: u32) -> Self { + pub fn new_with_pw(pw: u32, rand: Rand) -> Self { let mut s = Self { salt: [0; MAX_SALT_SIZE_BYTES], count: sys::SPAKE2_ITERATION_COUNT, data: VerifierOption::Password(pw), }; - rand::thread_rng().fill_bytes(&mut s.salt); + rand(&mut s.salt); s } @@ -158,7 +130,7 @@ impl VerifierData { } impl Spake2P { - pub fn new() -> Self { + pub const fn new() -> Self { Spake2P { mode: Spake2Mode::Unknown, context: None, @@ -198,7 +170,7 @@ impl Spake2P { match verifier.data { VerifierOption::Password(pw) => { // Derive w0 and L from the password - let mut w0w1s: [u8; 2 * CRYPTO_W_SIZE_BYTES] = [0; 2 * CRYPTO_W_SIZE_BYTES]; + let mut w0w1s: [u8; 2 * CRYPTO_W_SIZE_BYTES] = [0; (2 * CRYPTO_W_SIZE_BYTES)]; Spake2P::get_w0w1s(pw, verifier.count, &verifier.salt, &mut w0w1s); let w0s_len = w0w1s.len() / 2; @@ -317,7 +289,7 @@ mod tests { 0x4, 0xa1, 0xd2, 0xc6, 0x11, 0xf0, 0xbd, 0x36, 0x78, 0x67, 0x79, 0x7b, 0xfe, 0x82, 0x36, 0x0, ]; - let mut w0w1s: [u8; 2 * CRYPTO_W_SIZE_BYTES] = [0; 2 * CRYPTO_W_SIZE_BYTES]; + let mut w0w1s: [u8; 2 * CRYPTO_W_SIZE_BYTES] = [0; (2 * CRYPTO_W_SIZE_BYTES)]; Spake2P::get_w0w1s(123456, 2000, &salt, &mut w0w1s); assert_eq!( w0w1s, diff --git a/matter/src/secure_channel/status_report.rs b/matter/src/secure_channel/status_report.rs index 050cd5b..477bcfa 100644 --- a/matter/src/secure_channel/status_report.rs +++ b/matter/src/secure_channel/status_report.rs @@ -46,7 +46,7 @@ pub fn create_status_report( proto_code: u16, proto_data: Option<&[u8]>, ) -> Result<(), Error> { - proto_tx.set_proto_id(PROTO_ID_SECURE_CHANNEL as u16); + proto_tx.set_proto_id(PROTO_ID_SECURE_CHANNEL); proto_tx.set_proto_opcode(OpCode::StatusReport as u8); let wb = proto_tx.get_writebuf()?; wb.le_u16(general_code as u16)?; diff --git a/matter/src/sys/mod.rs b/matter/src/sys/mod.rs index e8e59cb..9b5219e 100644 --- a/matter/src/sys/mod.rs +++ b/matter/src/sys/mod.rs @@ -25,7 +25,8 @@ mod sys_linux; #[cfg(target_os = "linux")] pub use self::sys_linux::*; -#[cfg(any(target_os = "macos", target_os = "linux"))] -mod posix; -#[cfg(any(target_os = "macos", target_os = "linux"))] -pub use self::posix::*; +pub const SPAKE2_ITERATION_COUNT: u32 = 2000; + +// The Packet Pool that is allocated from. POSIX systems can use +// higher values unlike embedded systems +pub const MAX_PACKET_POOL_SIZE: usize = 25; diff --git a/matter/src/sys/posix.rs b/matter/src/sys/posix.rs deleted file mode 100644 index 2736e51..0000000 --- a/matter/src/sys/posix.rs +++ /dev/null @@ -1,96 +0,0 @@ -/* - * - * Copyright (c) 2020-2022 Project CHIP Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -use std::{ - convert::TryInto, - fs::{remove_file, DirBuilder, File}, - io::{Read, Write}, - sync::{Arc, Mutex, Once}, -}; - -use crate::error::Error; - -pub const SPAKE2_ITERATION_COUNT: u32 = 2000; - -// The Packet Pool that is allocated from. POSIX systems can use -// higher values unlike embedded systems -pub const MAX_PACKET_POOL_SIZE: usize = 25; - -pub struct Psm {} - -static mut G_PSM: Option>> = None; -static INIT: Once = Once::new(); - -const PSM_DIR: &str = "/tmp/matter_psm"; - -macro_rules! psm_path { - ($key:ident) => { - format!("{}/{}", PSM_DIR, $key) - }; -} - -impl Psm { - fn new() -> Result { - let result = DirBuilder::new().create(PSM_DIR); - if let Err(e) = result { - if e.kind() != std::io::ErrorKind::AlreadyExists { - return Err(e.into()); - } - } - - Ok(Self {}) - } - - pub fn get() -> Result>, Error> { - unsafe { - INIT.call_once(|| { - G_PSM = Some(Arc::new(Mutex::new(Psm::new().unwrap()))); - }); - Ok(G_PSM.as_ref().ok_or(Error::Invalid)?.clone()) - } - } - - pub fn set_kv_slice(&self, key: &str, val: &[u8]) -> Result<(), Error> { - let mut f = File::create(psm_path!(key))?; - f.write_all(val)?; - Ok(()) - } - - pub fn get_kv_slice(&self, key: &str, val: &mut Vec) -> Result { - let mut f = File::open(psm_path!(key))?; - let len = f.read_to_end(val)?; - Ok(len) - } - - pub fn set_kv_u64(&self, key: &str, val: u64) -> Result<(), Error> { - let mut f = File::create(psm_path!(key))?; - f.write_all(&val.to_be_bytes())?; - Ok(()) - } - - pub fn get_kv_u64(&self, key: &str, val: &mut u64) -> Result<(), Error> { - let mut f = File::open(psm_path!(key))?; - let mut vec = Vec::new(); - let _ = f.read_to_end(&mut vec)?; - *val = u64::from_be_bytes(vec.as_slice().try_into()?); - Ok(()) - } - - pub fn rm(&self, key: &str) { - let _ = remove_file(psm_path!(key)); - } -} diff --git a/matter/src/tlv/parser.rs b/matter/src/tlv/parser.rs index a9b8b87..f8b9716 100644 --- a/matter/src/tlv/parser.rs +++ b/matter/src/tlv/parser.rs @@ -18,8 +18,8 @@ use crate::error::Error; use byteorder::{ByteOrder, LittleEndian}; +use core::fmt; use log::{error, info}; -use std::fmt; use super::{TagType, MAX_TAG_INDEX, TAG_MASK, TAG_SHIFT_BITS, TAG_SIZE_MAP, TYPE_MASK}; @@ -318,7 +318,7 @@ impl<'a> PartialEq for TLVElement<'a> { loop { let ours = our_iter.next(); let theirs = their.next(); - if std::mem::discriminant(&ours) != std::mem::discriminant(&theirs) { + if core::mem::discriminant(&ours) != core::mem::discriminant(&theirs) { // One of us reached end of list, but the other didn't, that's a mismatch return false; } @@ -341,8 +341,8 @@ impl<'a> PartialEq for TLVElement<'a> { // Only compare the discriminants in case of array/list/structures, // instead of actual element values. Those will be subsets within this same // list that will get validated anyway - if std::mem::discriminant(&ours.element_type) - != std::mem::discriminant(&theirs.element_type) + if core::mem::discriminant(&ours.element_type) + != core::mem::discriminant(&theirs.element_type) { return false; } @@ -438,6 +438,18 @@ impl<'a> TLVElement<'a> { } } + pub fn str(&self) -> Result<&'a str, Error> { + match self.element_type { + ElementType::Str8l(s) + | ElementType::Utf8l(s) + | ElementType::Str16l(s) + | ElementType::Utf16l(s) => { + Ok(core::str::from_utf8(s).map_err(|_| Error::InvalidData)?) + } + _ => Err(Error::TLVTypeMismatch), + } + } + pub fn bool(&self) -> Result { match self.element_type { ElementType::False => Ok(false), @@ -522,7 +534,7 @@ impl<'a> fmt::Display for TLVElement<'a> { | ElementType::Utf8l(a) | ElementType::Str16l(a) | ElementType::Utf16l(a) => { - if let Ok(s) = std::str::from_utf8(a) { + if let Ok(s) = core::str::from_utf8(a) { write!(f, "len[{}]\"{}\"", s.len(), s) } else { write!(f, "len[{}]{:x?}", a.len(), a) @@ -752,7 +764,7 @@ pub fn print_tlv_list(b: &[u8]) { match a.element_type { ElementType::Struct(_) => { if index < MAX_DEPTH { - println!("{}{}", space[index], a); + info!("{}{}", space[index], a); stack[index] = '}'; index += 1; } else { @@ -761,7 +773,7 @@ pub fn print_tlv_list(b: &[u8]) { } ElementType::Array(_) | ElementType::List(_) => { if index < MAX_DEPTH { - println!("{}{}", space[index], a); + info!("{}{}", space[index], a); stack[index] = ']'; index += 1; } else { @@ -771,19 +783,21 @@ pub fn print_tlv_list(b: &[u8]) { ElementType::EndCnt => { if index > 0 { index -= 1; - println!("{}{}", space[index], stack[index]); + info!("{}{}", space[index], stack[index]); } else { error!("Incorrect TLV List"); } } - _ => println!("{}{}", space[index], a), + _ => info!("{}{}", space[index], a), } } - println!("---------"); + info!("---------"); } #[cfg(test)] mod tests { + use log::info; + use super::{ get_root_node_list, get_root_node_struct, ElementType, Pointer, TLVElement, TLVList, TagType, @@ -1105,7 +1119,7 @@ mod tests { .unwrap() .enter() .unwrap(); - println!("Command list iterator: {:?}", cmd_list_iter); + info!("Command list iterator: {:?}", cmd_list_iter); // This is an array of CommandDataIB, but we'll only use the first element let cmd_data_ib = cmd_list_iter.next().unwrap(); @@ -1203,8 +1217,8 @@ mod tests { Some(a) => { assert_eq!(a.tag_type, verify_matrix[index].0); assert_eq!( - std::mem::discriminant(&a.element_type), - std::mem::discriminant(&verify_matrix[index].1) + core::mem::discriminant(&a.element_type), + core::mem::discriminant(&verify_matrix[index].1) ); } } diff --git a/matter/src/tlv/traits.rs b/matter/src/tlv/traits.rs index 2d3cedd..c7b5e35 100644 --- a/matter/src/tlv/traits.rs +++ b/matter/src/tlv/traits.rs @@ -17,9 +17,13 @@ use super::{ElementType, TLVContainerIterator, TLVElement, TLVWriter, TagType}; use crate::error::Error; +use alloc::borrow::ToOwned; +use alloc::{string::String, vec::Vec}; +use core::fmt::Debug; use core::slice::Iter; use log::error; -use std::fmt::Debug; + +extern crate alloc; pub trait FromTLV<'a> { fn from_tlv(t: &TLVElement<'a>) -> Result @@ -76,6 +80,15 @@ pub trait ToTLV { fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error>; } +impl ToTLV for &T +where + T: ToTLV, +{ + fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error> { + (**self).to_tlv(tw, tag) + } +} + macro_rules! totlv_for { ($($t:ident)*) => { $( @@ -116,10 +129,14 @@ totlv_for!(i8 u8 u16 u32 u64 bool); pub struct UtfStr<'a>(pub &'a [u8]); impl<'a> UtfStr<'a> { - pub fn new(str: &'a [u8]) -> Self { + pub const fn new(str: &'a [u8]) -> Self { Self(str) } + pub fn as_str(&self) -> Result<&str, Error> { + core::str::from_utf8(self.0).map_err(|_| Error::Invalid) + } + pub fn to_string(self) -> Result { String::from_utf8(self.0.to_vec()).map_err(|_| Error::Invalid) } @@ -396,7 +413,7 @@ impl<'a, T> FromTLV<'a> for TLVArray<'a, T> { } impl<'a, T: Debug + ToTLV + FromTLV<'a> + Copy> Debug for TLVArray<'a, T> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { for i in self.iter() { writeln!(f, "{:?}", i)?; } @@ -442,9 +459,8 @@ mod tests { } #[test] fn test_derive_totlv() { - let mut buf: [u8; 20] = [0; 20]; - let buf_len = buf.len(); - let mut writebuf = WriteBuf::new(&mut buf, buf_len); + let mut buf = [0; 20]; + let mut writebuf = WriteBuf::new(&mut buf); let mut tw = TLVWriter::new(&mut writebuf); let abc = TestDerive { @@ -525,9 +541,8 @@ mod tests { #[test] fn test_derive_totlv_fab_scoped() { - let mut buf: [u8; 20] = [0; 20]; - let buf_len = buf.len(); - let mut writebuf = WriteBuf::new(&mut buf, buf_len); + let mut buf = [0; 20]; + let mut writebuf = WriteBuf::new(&mut buf); let mut tw = TLVWriter::new(&mut writebuf); let abc = TestDeriveFabScoped { a: 20, fab_idx: 3 }; @@ -557,9 +572,8 @@ mod tests { enum_val = TestDeriveEnum::ValueB(10); // Test ToTLV - let mut buf: [u8; 20] = [0; 20]; - let buf_len = buf.len(); - let mut writebuf = WriteBuf::new(&mut buf, buf_len); + let mut buf = [0; 20]; + let mut writebuf = WriteBuf::new(&mut buf); let mut tw = TLVWriter::new(&mut writebuf); enum_val.to_tlv(&mut tw, TagType::Anonymous).unwrap(); diff --git a/matter/src/tlv/writer.rs b/matter/src/tlv/writer.rs index cdf914a..1db8421 100644 --- a/matter/src/tlv/writer.rs +++ b/matter/src/tlv/writer.rs @@ -50,11 +50,11 @@ enum WriteElementType { } pub struct TLVWriter<'a, 'b> { - buf: &'b mut WriteBuf<'a>, + buf: &'a mut WriteBuf<'b>, } impl<'a, 'b> TLVWriter<'a, 'b> { - pub fn new(buf: &'b mut WriteBuf<'a>) -> Self { + pub fn new(buf: &'a mut WriteBuf<'b>) -> Self { TLVWriter { buf } } @@ -265,7 +265,7 @@ impl<'a, 'b> TLVWriter<'a, 'b> { self.buf.rewind_tail_to(anchor); } - pub fn get_buf<'c>(&'c mut self) -> &'c mut WriteBuf<'a> { + pub fn get_buf(&mut self) -> &mut WriteBuf<'b> { self.buf } } @@ -277,9 +277,8 @@ mod tests { #[test] fn test_write_success() { - let mut buf: [u8; 20] = [0; 20]; - let buf_len = buf.len(); - let mut writebuf = WriteBuf::new(&mut buf, buf_len); + let mut buf = [0; 20]; + let mut writebuf = WriteBuf::new(&mut buf); let mut tw = TLVWriter::new(&mut writebuf); tw.start_struct(TagType::Anonymous).unwrap(); @@ -299,9 +298,8 @@ mod tests { #[test] fn test_write_overflow() { - let mut buf: [u8; 6] = [0; 6]; - let buf_len = buf.len(); - let mut writebuf = WriteBuf::new(&mut buf, buf_len); + let mut buf = [0; 6]; + let mut writebuf = WriteBuf::new(&mut buf); let mut tw = TLVWriter::new(&mut writebuf); tw.u8(TagType::Anonymous, 12).unwrap(); @@ -317,9 +315,8 @@ mod tests { #[test] fn test_put_str8() { - let mut buf: [u8; 20] = [0; 20]; - let buf_len = buf.len(); - let mut writebuf = WriteBuf::new(&mut buf, buf_len); + let mut buf = [0; 20]; + let mut writebuf = WriteBuf::new(&mut buf); let mut tw = TLVWriter::new(&mut writebuf); tw.u8(TagType::Context(1), 13).unwrap(); @@ -334,9 +331,8 @@ mod tests { #[test] fn test_put_str16_as() { - let mut buf: [u8; 20] = [0; 20]; - let buf_len = buf.len(); - let mut writebuf = WriteBuf::new(&mut buf, buf_len); + let mut buf = [0; 20]; + let mut writebuf = WriteBuf::new(&mut buf); let mut tw = TLVWriter::new(&mut writebuf); tw.u8(TagType::Context(1), 13).unwrap(); diff --git a/matter/src/transport/exchange.rs b/matter/src/transport/exchange.rs index ae5711d..e668a8d 100644 --- a/matter/src/transport/exchange.rs +++ b/matter/src/transport/exchange.rs @@ -15,43 +15,46 @@ * limitations under the License. */ -use boxslab::{BoxSlab, Slab}; use colored::*; +use core::any::Any; +use core::fmt; +use core::time::Duration; use log::{error, info, trace}; -use std::any::Any; -use std::fmt; -use std::time::SystemTime; use crate::error::Error; use crate::secure_channel; +use crate::secure_channel::case::CaseSession; +use crate::utils::epoch::Epoch; +use crate::utils::rand::Rand; use heapless::LinearMap; -use super::packet::PacketPool; use super::session::CloneData; use super::{mrp::ReliableMessage, packet::Packet, session::SessionHandle, session::SessionMgr}; pub struct ExchangeCtx<'a> { pub exch: &'a mut Exchange, pub sess: SessionHandle<'a>, + pub epoch: Epoch, } -#[derive(Debug, PartialEq, Eq, Copy, Clone)] +impl<'a> ExchangeCtx<'a> { + pub fn send(&mut self, tx: &mut Packet) -> Result<(), Error> { + self.exch.send(tx, &mut self.sess) + } +} + +#[derive(Debug, PartialEq, Eq, Copy, Clone, Default)] pub enum Role { + #[default] Initiator = 0, Responder = 1, } -impl Default for Role { - fn default() -> Self { - Role::Initiator - } -} - -/// State of the exchange -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Default)] enum State { /// The exchange is open and active + #[default] Open, /// The exchange is closed, but keys are active since retransmissions/acks may be pending Close, @@ -59,28 +62,17 @@ enum State { Terminate, } -impl Default for State { - fn default() -> Self { - State::Open - } -} - // Instead of just doing an Option<>, we create some special handling // where the commonly used higher layer data store does't have to do a Box -#[derive(Debug)] +#[derive(Default)] pub enum DataOption { - Boxed(Box), - Time(SystemTime), + CaseSession(CaseSession), + Time(Duration), + #[default] None, } -impl Default for DataOption { - fn default() -> Self { - DataOption::None - } -} - -#[derive(Debug, Default)] +#[derive(Default)] pub struct Exchange { id: u16, sess_idx: usize, @@ -136,48 +128,48 @@ impl Exchange { matches!(self.data, DataOption::None) } - pub fn set_data_boxed(&mut self, data: Box) { - self.data = DataOption::Boxed(data); + pub fn set_case_session(&mut self, session: CaseSession) { + self.data = DataOption::CaseSession(session); } - pub fn clear_data_boxed(&mut self) { + pub fn clear_data(&mut self) { self.data = DataOption::None; } - pub fn get_data_boxed(&mut self) -> Option<&mut T> { - if let DataOption::Boxed(a) = &mut self.data { - a.downcast_mut::() + pub fn get_case_session(&mut self) -> Option<&mut CaseSession> { + if let DataOption::CaseSession(session) = &mut self.data { + Some(session) } else { None } } - pub fn take_data_boxed(&mut self) -> Option> { - let old = std::mem::replace(&mut self.data, DataOption::None); - if let DataOption::Boxed(d) = old { - d.downcast::().ok() + pub fn take_case_session(&mut self) -> Option { + let old = core::mem::replace(&mut self.data, DataOption::None); + if let DataOption::CaseSession(session) = old { + Some(session) } else { self.data = old; None } } - pub fn set_data_time(&mut self, expiry_ts: Option) { + pub fn set_data_time(&mut self, expiry_ts: Option) { if let Some(t) = expiry_ts { self.data = DataOption::Time(t); } } - pub fn get_data_time(&self) -> Option { + pub fn get_data_time(&self) -> Option { match self.data { DataOption::Time(t) => Some(t), _ => None, } } - pub fn send( + pub(crate) fn send( &mut self, - mut proto_tx: BoxSlab, + tx: &mut Packet, session: &mut SessionHandle, ) -> Result<(), Error> { if self.state == State::Terminate { @@ -185,22 +177,22 @@ impl Exchange { return Ok(()); } - trace!("payload: {:x?}", proto_tx.as_borrow_slice()); + trace!("payload: {:x?}", tx.as_mut_slice()); info!( "{} with proto id: {} opcode: {}", "Sending".blue(), - proto_tx.get_proto_id(), - proto_tx.get_proto_opcode(), + tx.get_proto_id(), + tx.get_proto_opcode(), ); - proto_tx.proto.exch_id = self.id; + tx.proto.exch_id = self.id; if self.role == Role::Initiator { - proto_tx.proto.set_initiator(); + tx.proto.set_initiator(); } - session.pre_send(&mut proto_tx)?; - self.mrp.pre_send(&mut proto_tx)?; - session.send(proto_tx) + session.pre_send(tx)?; + self.mrp.pre_send(tx)?; + session.send(tx) } } @@ -208,8 +200,8 @@ impl fmt::Display for Exchange { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( f, - "exch_id: {:?}, sess_index: {}, role: {:?}, data: {:?}, mrp: {:?}, state: {:?}", - self.id, self.sess_idx, self.role, self.data, self.mrp, self.state + "exch_id: {:?}, sess_index: {}, role: {:?}, mrp: {:?}, state: {:?}", + self.id, self.sess_idx, self.role, self.mrp, self.state ) } } @@ -232,20 +224,21 @@ pub fn get_complementary_role(is_initiator: bool) -> Role { const MAX_EXCHANGES: usize = 8; -#[derive(Default)] pub struct ExchangeMgr { // keys: exch-id exchanges: LinearMap, sess_mgr: SessionMgr, + epoch: Epoch, } pub const MAX_MRP_ENTRIES: usize = 4; impl ExchangeMgr { - pub fn new(sess_mgr: SessionMgr) -> Self { + pub fn new(epoch: Epoch, rand: Rand) -> Self { Self { - sess_mgr, - exchanges: Default::default(), + sess_mgr: SessionMgr::new(epoch, rand), + exchanges: LinearMap::new(), + epoch, } } @@ -300,45 +293,33 @@ impl ExchangeMgr { } /// The Exchange Mgr receive is like a big processing function - pub fn recv(&mut self) -> Result, ExchangeCtx)>, Error> { + pub fn recv(&mut self, rx: &mut Packet) -> Result, Error> { // Get the session - let (mut proto_rx, index) = self.sess_mgr.recv()?; - - let index = if let Some(s) = index { - s - } else { - // The sessions were full, evict one session, and re-perform post-recv - let evict_index = self.sess_mgr.get_lru(); - self.evict_session(evict_index)?; - info!("Reattempting session creation"); - self.sess_mgr.post_recv(&proto_rx)?.ok_or(Error::Invalid)? - }; + let index = self.sess_mgr.post_recv(rx)?; let mut session = self.sess_mgr.get_session_handle(index); // Decrypt the message - session.recv(&mut proto_rx)?; + session.recv(self.epoch, rx)?; // Get the exchange let exch = ExchangeMgr::_get( &mut self.exchanges, index, - proto_rx.proto.exch_id, - get_complementary_role(proto_rx.proto.is_initiator()), + rx.proto.exch_id, + get_complementary_role(rx.proto.is_initiator()), // We create a new exchange, only if the peer is the initiator - proto_rx.proto.is_initiator(), + rx.proto.is_initiator(), )?; // Message Reliability Protocol - exch.mrp.recv(&proto_rx)?; + exch.mrp.recv(rx, self.epoch)?; if exch.is_state_open() { - Ok(Some(( - proto_rx, - ExchangeCtx { - exch, - sess: session, - }, - ))) + Ok(Some(ExchangeCtx { + exch, + sess: session, + epoch: self.epoch, + })) } else { // Instead of an error, we send None here, because it is likely that // we just processed an acknowledgement that cleared the exchange @@ -346,11 +327,11 @@ impl ExchangeMgr { } } - pub fn send(&mut self, exch_id: u16, proto_tx: BoxSlab) -> Result<(), Error> { + pub fn send(&mut self, exch_id: u16, tx: &mut Packet) -> Result<(), Error> { let exchange = ExchangeMgr::_get_with_id(&mut self.exchanges, exch_id).ok_or(Error::NoExchange)?; let mut session = self.sess_mgr.get_session_handle(exchange.sess_idx); - exchange.send(proto_tx, &mut session) + exchange.send(tx, &mut session) } pub fn purge(&mut self) { @@ -366,70 +347,66 @@ impl ExchangeMgr { } } - pub fn pending_acks(&mut self, expired_entries: &mut LinearMap) { - for (exch_id, exchange) in self.exchanges.iter() { - if exchange.mrp.is_ack_ready() { - expired_entries.insert(*exch_id, ()).unwrap(); - } - } + pub fn pending_ack(&mut self) -> Option { + self.exchanges + .iter() + .find(|(_, exchange)| exchange.mrp.is_ack_ready(self.epoch)) + .map(|(exch_id, _)| *exch_id) } - pub fn evict_session(&mut self, index: usize) -> Result<(), Error> { - info!("Sessions full, vacating session with index: {}", index); - // If we enter here, we have an LRU session that needs to be reclaimed - // As per the spec, we need to send a CLOSE here + pub fn evict_session(&mut self, tx: &mut Packet) -> Result { + if let Some(index) = self.sess_mgr.get_session_for_eviction() { + info!("Sessions full, vacating session with index: {}", index); + // If we enter here, we have an LRU session that needs to be reclaimed + // As per the spec, we need to send a CLOSE here - let mut session = self.sess_mgr.get_session_handle(index); - let mut tx = Slab::::try_new(Packet::new_tx()?).ok_or(Error::NoSpace)?; - secure_channel::common::create_sc_status_report( - &mut tx, - secure_channel::common::SCStatusCodes::CloseSession, - None, - )?; + let mut session = self.sess_mgr.get_session_handle(index); + secure_channel::common::create_sc_status_report( + tx, + secure_channel::common::SCStatusCodes::CloseSession, + None, + )?; - if let Some((_, exchange)) = self.exchanges.iter_mut().find(|(_, e)| e.sess_idx == index) { - // Send Close_session on this exchange, and then close the session - // Should this be done for all exchanges? - error!("Sending Close Session"); - exchange.send(tx, &mut session)?; - // TODO: This wouldn't actually send it out, because 'transport' isn't owned yet. + if let Some((_, exchange)) = + self.exchanges.iter_mut().find(|(_, e)| e.sess_idx == index) + { + // Send Close_session on this exchange, and then close the session + // Should this be done for all exchanges? + error!("Sending Close Session"); + exchange.send(tx, &mut session)?; + // TODO: This wouldn't actually send it out, because 'transport' isn't owned yet. + } + + let remove_exchanges: heapless::Vec = self + .exchanges + .iter() + .filter_map(|(eid, e)| { + if e.sess_idx == index { + Some(*eid) + } else { + None + } + }) + .collect(); + info!( + "Terminating the following exchanges: {:?}", + remove_exchanges + ); + for exch_id in remove_exchanges { + // Remove from exchange list + self.exchanges.remove(&exch_id); + } + self.sess_mgr.remove(index); + + Ok(true) + } else { + Ok(false) } - - let remove_exchanges: Vec = self - .exchanges - .iter() - .filter_map(|(eid, e)| { - if e.sess_idx == index { - Some(*eid) - } else { - None - } - }) - .collect(); - info!( - "Terminating the following exchanges: {:?}", - remove_exchanges - ); - for exch_id in remove_exchanges { - // Remove from exchange list - self.exchanges.remove(&exch_id); - } - self.sess_mgr.remove(index); - Ok(()) } pub fn add_session(&mut self, clone_data: &CloneData) -> Result { - let sess_idx = match self.sess_mgr.clone_session(clone_data) { - Ok(idx) => idx, - Err(Error::NoSpace) => { - let evict_index = self.sess_mgr.get_lru(); - self.evict_session(evict_index)?; - self.sess_mgr.clone_session(clone_data)? - } - Err(e) => { - return Err(e); - } - }; + let sess_idx = self.sess_mgr.clone_session(clone_data)?; + Ok(self.sess_mgr.get_session_handle(sess_idx)) } } @@ -449,12 +426,16 @@ impl fmt::Display for ExchangeMgr { #[cfg(test)] #[allow(clippy::bool_assert_comparison)] mod tests { - use crate::{ error::Error, transport::{ - network::{Address, NetworkInterface}, - session::{CloneData, SessionMgr, SessionMode, MAX_SESSIONS}, + network::Address, + packet::Packet, + session::{CloneData, SessionMode, MAX_SESSIONS}, + }, + utils::{ + epoch::{dummy_epoch, sys_epoch}, + rand::dummy_rand, }, }; @@ -462,8 +443,7 @@ mod tests { #[test] fn test_purge() { - let sess_mgr = SessionMgr::new(); - let mut mgr = ExchangeMgr::new(sess_mgr); + let mut mgr = ExchangeMgr::new(dummy_epoch, dummy_rand); let _ = ExchangeMgr::_get(&mut mgr.exchanges, 1, 2, Role::Responder, true).unwrap(); let _ = ExchangeMgr::_get(&mut mgr.exchanges, 1, 3, Role::Responder, true).unwrap(); @@ -519,33 +499,13 @@ mod tests { } } - pub struct DummyNetwork; - impl DummyNetwork { - pub fn new() -> Self { - Self {} - } - } - - impl NetworkInterface for DummyNetwork { - fn recv(&self, _in_buf: &mut [u8]) -> Result<(usize, Address), Error> { - Ok((0, Address::default())) - } - - fn send(&self, _out_buf: &[u8], _addr: Address) -> Result { - Ok(0) - } - } - #[test] /// We purposefuly overflow the sessions /// and when the overflow happens, we confirm that /// - The sessions are evicted in LRU /// - The exchanges associated with those sessions are evicted too fn test_sess_evict() { - let mut sess_mgr = SessionMgr::new(); - let transport = Box::new(DummyNetwork::new()); - sess_mgr.add_network_interface(transport).unwrap(); - let mut mgr = ExchangeMgr::new(sess_mgr); + let mut mgr = ExchangeMgr::new(sys_epoch, dummy_rand); // TODO fill_sessions(&mut mgr, MAX_SESSIONS + 1); // Sessions are now full from local session id 1 to 16 @@ -568,6 +528,14 @@ mod tests { for i in 1..(MAX_SESSIONS + 1) { // Now purposefully overflow the sessions by adding another session + let result = mgr.add_session(&get_clone_data(new_peer_sess_id, new_local_sess_id)); + assert!(matches!(result, Err(Error::NoSpace))); + + let mut buf = [0; 1500]; + let tx = &mut Packet::new_tx(&mut buf); + let evicted = mgr.evict_session(tx).unwrap(); + assert!(evicted); + let session = mgr .add_session(&get_clone_data(new_peer_sess_id, new_local_sess_id)) .unwrap(); diff --git a/matter/src/transport/mgr.rs b/matter/src/transport/mgr.rs index 76c506c..349cfde 100644 --- a/matter/src/transport/mgr.rs +++ b/matter/src/transport/mgr.rs @@ -15,161 +15,210 @@ * limitations under the License. */ -use async_channel::Receiver; -use boxslab::{BoxSlab, Slab}; -use heapless::LinearMap; -use log::{debug, error, info}; +use core::borrow::Borrow; +use core::cell::RefCell; + +use log::info; use crate::error::*; +use crate::fabric::FabricMgr; +use crate::mdns::MdnsMgr; +use crate::secure_channel::pake::PaseMgr; +use crate::secure_channel::common::PROTO_ID_SECURE_CHANNEL; +use crate::secure_channel::core::SecureChannel; use crate::transport::mrp::ReliableMessage; -use crate::transport::packet::PacketPool; -use crate::transport::{exchange, packet::Packet, proto_demux, queue, session, udp}; +use crate::transport::{exchange, packet::Packet}; +use crate::utils::epoch::Epoch; +use crate::utils::rand::Rand; -use super::proto_demux::ProtoCtx; -use super::queue::Msg; +use super::proto_ctx::ProtoCtx; -pub struct Mgr { - exch_mgr: exchange::ExchangeMgr, - proto_demux: proto_demux::ProtoDemux, - rx_q: Receiver, +#[derive(Copy, Clone, Eq, PartialEq)] +enum RecvState { + New, + OpenExchange, + EvictSession, + Ack, } -impl Mgr { - pub fn new() -> Result { - let mut sess_mgr = session::SessionMgr::new(); - let udp_transport = Box::new(udp::UdpListener::new()?); - sess_mgr.add_network_interface(udp_transport)?; - Ok(Mgr { - proto_demux: proto_demux::ProtoDemux::new(), - exch_mgr: exchange::ExchangeMgr::new(sess_mgr), - rx_q: queue::WorkQ::init()?, - }) - } +pub enum RecvAction<'r, 'p> { + Send(&'r [u8]), + Interact(ProtoCtx<'r, 'p>), +} - // Allows registration of different protocols with the Transport/Protocol Demux - pub fn register_protocol( - &mut self, - proto_id_handle: Box, - ) -> Result<(), Error> { - self.proto_demux.register(proto_id_handle) - } +pub struct RecvCompletion<'r, 'a, 'p> { + mgr: &'r mut TransportMgr<'a>, + rx: &'r mut Packet<'p>, + tx: &'r mut Packet<'p>, + state: RecvState, +} - fn send_to_exchange( - &mut self, - exch_id: u16, - proto_tx: BoxSlab, - ) -> Result<(), Error> { - self.exch_mgr.send(exch_id, proto_tx) - } - - fn handle_rxtx(&mut self) -> Result<(), Error> { - let result = self.exch_mgr.recv().map_err(|e| { - error!("Error in recv: {:?}", e); - e - })?; - - if result.is_none() { - // Nothing to process, return quietly - return Ok(()); - } - // result contains something worth processing, we can safely unwrap - // as we already checked for none above - let (rx, exch_ctx) = result.unwrap(); - - debug!("Exchange is {:?}", exch_ctx.exch); - let tx = Self::new_tx()?; - - let mut proto_ctx = ProtoCtx::new(exch_ctx, rx, tx); - // Proto Dispatch - match self.proto_demux.handle(&mut proto_ctx) { - Ok(r) => { - if let proto_demux::ResponseRequired::No = r { - // We need to send the Ack if reliability is enabled, in this case - return Ok(()); - } - } - Err(e) => { - error!("Error in proto_demux {:?}", e); - return Err(e); - } - } - - let ProtoCtx { - exch_ctx, - rx: _, - tx, - } = proto_ctx; - - // tx_ctx now contains the response payload, send the packet - let exch_id = exch_ctx.exch.get_id(); - self.send_to_exchange(exch_id, tx).map_err(|e| { - error!("Error in sending msg {:?}", e); - e - })?; - - Ok(()) - } - - fn handle_queue_msgs(&mut self) -> Result<(), Error> { - if let Ok(msg) = self.rx_q.try_recv() { - match msg { - Msg::NewSession(clone_data) => { - // If a new session was created, add it - let _ = self - .exch_mgr - .add_session(&clone_data) - .map_err(|e| error!("Error adding new session {:?}", e)); - } - _ => { - error!("Queue Message Type not yet handled {:?}", msg); - } - } - } - Ok(()) - } - - pub fn start(&mut self) -> Result<(), Error> { +impl<'r, 'a, 'p> RecvCompletion<'r, 'a, 'p> { + pub fn next_action(&mut self) -> Result>, Error> { loop { - // Handle network operations - if self.handle_rxtx().is_err() { - error!("Error in handle_rxtx"); - continue; + // Polonius will remove the need for unsafe one day + let this = unsafe { (self as *mut RecvCompletion).as_mut().unwrap() }; + + if let Some(action) = this.maybe_next_action()? { + return Ok(action); } - - if self.handle_queue_msgs().is_err() { - error!("Error in handle_queue_msg"); - continue; - } - - // Handle any pending acknowledgement send - let mut acks_to_send: LinearMap = - LinearMap::new(); - self.exch_mgr.pending_acks(&mut acks_to_send); - for exch_id in acks_to_send.keys() { - info!("Sending MRP Standalone ACK for exch {}", exch_id); - let mut proto_tx = match Self::new_tx() { - Ok(p) => p, - Err(e) => { - error!("Error creating proto_tx {:?}", e); - break; - } - }; - ReliableMessage::prepare_ack(*exch_id, &mut proto_tx); - if let Err(e) = self.send_to_exchange(*exch_id, proto_tx) { - error!("Error in sending Ack {:?}", e); - } - } - - // Handle exchange purging - // This need not be done in each turn of the loop, maybe once in 5 times or so? - self.exch_mgr.purge(); - - info!("Exchange Mgr: {}", self.exch_mgr); } } - fn new_tx() -> Result, Error> { - Slab::::try_new(Packet::new_tx()?).ok_or(Error::PacketPoolExhaust) + fn maybe_next_action(&mut self) -> Result>>, Error> { + self.mgr.exch_mgr.purge(); + + match self.state { + RecvState::New => { + self.mgr.exch_mgr.get_sess_mgr().decode(self.rx)?; + self.state = RecvState::OpenExchange; + Ok(None) + } + RecvState::OpenExchange => match self.mgr.exch_mgr.recv(self.rx) { + Ok(Some(exch_ctx)) => { + if self.rx.get_proto_id() == PROTO_ID_SECURE_CHANNEL { + let mut proto_ctx = ProtoCtx::new(exch_ctx, self.rx, self.tx); + + if self.mgr.secure_channel.handle(&mut proto_ctx)? { + proto_ctx.send()?; + + self.state = RecvState::Ack; + Ok(Some(Some(RecvAction::Send(self.tx.as_slice())))) + } else { + self.state = RecvState::Ack; + Ok(None) + } + } else { + let proto_ctx = ProtoCtx::new(exch_ctx, self.rx, self.tx); + self.state = RecvState::Ack; + + Ok(Some(Some(RecvAction::Interact(proto_ctx)))) + } + } + Ok(None) => { + self.state = RecvState::Ack; + Ok(None) + } + Err(Error::NoSpace) => { + self.state = RecvState::EvictSession; + Ok(None) + } + Err(err) => Err(err), + }, + RecvState::EvictSession => { + self.mgr.exch_mgr.evict_session(self.tx)?; + self.state = RecvState::OpenExchange; + Ok(Some(Some(RecvAction::Send(self.tx.as_slice())))) + } + RecvState::Ack => { + if let Some(exch_id) = self.mgr.exch_mgr.pending_ack() { + info!("Sending MRP Standalone ACK for exch {}", exch_id); + + ReliableMessage::prepare_ack(exch_id, self.tx); + + self.mgr.exch_mgr.send(exch_id, self.tx)?; + Ok(Some(Some(RecvAction::Send(self.tx.as_slice())))) + } else { + Ok(Some(None)) + } + } + } } } + +#[derive(Copy, Clone, Eq, PartialEq)] +enum NotifyState {} + +pub enum NotifyAction<'r, 'p> { + Send(&'r [u8]), + Notify(ProtoCtx<'r, 'p>), +} + +pub struct NotifyCompletion<'r, 'a, 'p> { + // TODO + _mgr: &'r mut TransportMgr<'a>, + _rx: &'r mut Packet<'p>, + _tx: &'r mut Packet<'p>, + _state: NotifyState, +} + +impl<'r, 'a, 'p> NotifyCompletion<'r, 'a, 'p> { + pub fn next_action(&mut self) -> Result>, Error> { + loop { + // Polonius will remove the need for unsafe one day + let this = unsafe { (self as *mut NotifyCompletion).as_mut().unwrap() }; + + if let Some(action) = this.maybe_next_action()? { + return Ok(action); + } + } + } + + fn maybe_next_action(&mut self) -> Result>>, Error> { + Ok(Some(None)) // TODO: Future + } +} + +pub struct TransportMgr<'a> { + exch_mgr: exchange::ExchangeMgr, + secure_channel: SecureChannel<'a>, +} + +impl<'a> TransportMgr<'a> { + pub fn new< + T: Borrow> + Borrow> + Borrow + Borrow, + >( + matter: &'a T, + mdns_mgr: &'a RefCell>, + ) -> Self { + Self::wrap( + SecureChannel::new(matter.borrow(), matter.borrow(), mdns_mgr, *matter.borrow()), + *matter.borrow(), + *matter.borrow(), + ) + } + + pub fn wrap(secure_channel: SecureChannel<'a>, epoch: Epoch, rand: Rand) -> Self { + Self { + exch_mgr: exchange::ExchangeMgr::new(epoch, rand), + secure_channel, + } + } + + pub fn recv<'r, 'p>( + &'r mut self, + rx: &'r mut Packet<'p>, + tx: &'r mut Packet<'p>, + ) -> RecvCompletion<'r, 'a, 'p> { + RecvCompletion { + mgr: self, + rx, + tx, + state: RecvState::New, + } + } + + pub fn notify(&mut self, _tx: &mut Packet) -> Result { + Ok(false) + } + + // async fn handle_queue_msgs(&mut self) -> Result<(), Error> { + // if let Ok(msg) = self.rx_q.try_recv() { + // match msg { + // Msg::NewSession(clone_data) => { + // // If a new session was created, add it + // let _ = self + // .exch_mgr + // .add_session(&clone_data) + // .await + // .map_err(|e| error!("Error adding new session {:?}", e)); + // } + // _ => { + // error!("Queue Message Type not yet handled {:?}", msg); + // } + // } + // } + // Ok(()) + // } +} diff --git a/matter/src/transport/mod.rs b/matter/src/transport/mod.rs index b3a2545..43acccd 100644 --- a/matter/src/transport/mod.rs +++ b/matter/src/transport/mod.rs @@ -22,7 +22,7 @@ pub mod mrp; pub mod network; pub mod packet; pub mod plain_hdr; -pub mod proto_demux; +pub mod proto_ctx; pub mod proto_hdr; pub mod queue; pub mod session; diff --git a/matter/src/transport/mrp.rs b/matter/src/transport/mrp.rs index 22cf9ad..2213a52 100644 --- a/matter/src/transport/mrp.rs +++ b/matter/src/transport/mrp.rs @@ -15,8 +15,8 @@ * limitations under the License. */ -use std::time::Duration; -use std::time::SystemTime; +use crate::utils::epoch::Epoch; +use core::time::Duration; use crate::{error::*, secure_channel, transport::packet::Packet}; use log::error; @@ -46,13 +46,13 @@ pub struct AckEntry { // The msg counter that we should acknowledge msg_ctr: u32, // The max time after which this entry must be ACK - ack_timeout: SystemTime, + ack_timeout: Duration, } impl AckEntry { - pub fn new(msg_ctr: u32) -> Result { + pub fn new(msg_ctr: u32, epoch: Epoch) -> Result { if let Some(ack_timeout) = - SystemTime::now().checked_add(Duration::from_millis(MRP_STANDALONE_ACK_TIMEOUT)) + epoch().checked_add(Duration::from_millis(MRP_STANDALONE_ACK_TIMEOUT)) { Ok(Self { msg_ctr, @@ -67,8 +67,8 @@ impl AckEntry { self.msg_ctr } - pub fn has_timed_out(&self) -> bool { - self.ack_timeout > SystemTime::now() + pub fn has_timed_out(&self, epoch: Epoch) -> bool { + self.ack_timeout > epoch() } } @@ -90,10 +90,10 @@ impl ReliableMessage { } // Check any pending acknowledgements / retransmissions and take action - pub fn is_ack_ready(&self) -> bool { + pub fn is_ack_ready(&self, epoch: Epoch) -> bool { // Acknowledgements if let Some(ack_entry) = self.ack { - ack_entry.has_timed_out() + ack_entry.has_timed_out(epoch) } else { false } @@ -132,7 +132,7 @@ impl ReliableMessage { * - there can be only one pending retransmission per exchange (so this is per-exchange) * - duplicate detection should happen per session (obviously), so that part is per-session */ - pub fn recv(&mut self, proto_rx: &Packet) -> Result<(), Error> { + pub fn recv(&mut self, proto_rx: &Packet, epoch: Epoch) -> Result<(), Error> { if proto_rx.proto.is_ack() { // Handle received Acks let ack_msg_ctr = proto_rx.proto.get_ack_msg_ctr().ok_or(Error::Invalid)?; @@ -153,7 +153,7 @@ impl ReliableMessage { return Err(Error::Invalid); } - self.ack = Some(AckEntry::new(proto_rx.plain.ctr)?); + self.ack = Some(AckEntry::new(proto_rx.plain.ctr, epoch)?); } Ok(()) } diff --git a/matter/src/transport/network.rs b/matter/src/transport/network.rs index 5b398ca..91645de 100644 --- a/matter/src/transport/network.rs +++ b/matter/src/transport/network.rs @@ -15,12 +15,8 @@ * limitations under the License. */ -use std::{ - fmt::{Debug, Display}, - net::{IpAddr, Ipv4Addr, SocketAddr}, -}; - -use crate::error::Error; +use core::fmt::{Debug, Display}; +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; #[derive(PartialEq, Copy, Clone)] pub enum Address { @@ -34,7 +30,7 @@ impl Default for Address { } impl Display for Address { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { match self { Address::Udp(addr) => writeln!(f, "{}", addr), } @@ -42,14 +38,9 @@ impl Display for Address { } impl Debug for Address { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { match self { Address::Udp(addr) => writeln!(f, "{}", addr), } } } - -pub trait NetworkInterface { - fn recv(&self, in_buf: &mut [u8]) -> Result<(usize, Address), Error>; - fn send(&self, out_buf: &[u8], addr: Address) -> Result; -} diff --git a/matter/src/transport/packet.rs b/matter/src/transport/packet.rs index 18af1b5..e39ac1c 100644 --- a/matter/src/transport/packet.rs +++ b/matter/src/transport/packet.rs @@ -99,46 +99,46 @@ pub struct Packet<'a> { pub proto: ProtoHdr, pub peer: Address, data: Direction<'a>, - buffer_index: usize, } impl<'a> Packet<'a> { const HDR_RESERVE: usize = plain_hdr::max_plain_hdr_len() + proto_hdr::max_proto_hdr_len(); - pub fn new_rx() -> Result { - let (buffer_index, buffer) = BufferPool::alloc().ok_or(Error::NoSpace)?; - let buf_len = buffer.len(); - Ok(Self { + pub fn new_rx(buf: &'a mut [u8]) -> Self { + Self { plain: Default::default(), proto: Default::default(), - buffer_index, peer: Address::default(), - data: Direction::Rx(ParseBuf::new(buffer, buf_len), RxState::Uninit), - }) + data: Direction::Rx(ParseBuf::new(buf), RxState::Uninit), + } } - pub fn new_tx() -> Result { - let (buffer_index, buffer) = BufferPool::alloc().ok_or(Error::NoSpace)?; - let buf_len = buffer.len(); + pub fn new_tx(buf: &'a mut [u8]) -> Self { + let mut wb = WriteBuf::new(buf); + wb.reserve(Packet::HDR_RESERVE).unwrap(); - let mut wb = WriteBuf::new(buffer, buf_len); - wb.reserve(Packet::HDR_RESERVE)?; + // Reliability on by default + let mut proto: ProtoHdr = Default::default(); + proto.set_reliable(); - let mut p = Self { + Self { plain: Default::default(), - proto: Default::default(), - buffer_index, + proto, peer: Address::default(), data: Direction::Tx(wb), - }; - // Reliability on by default - p.proto.set_reliable(); - Ok(p) + } } - pub fn as_borrow_slice(&mut self) -> &mut [u8] { + pub fn as_slice(&self) -> &[u8] { + match &self.data { + Direction::Rx(pb, _) => pb.as_slice(), + Direction::Tx(wb) => wb.as_slice(), + } + } + + pub fn as_mut_slice(&mut self) -> &mut [u8] { match &mut self.data { - Direction::Rx(pb, _) => pb.as_borrow_slice(), + Direction::Rx(pb, _) => pb.as_mut_slice(), Direction::Tx(wb) => wb.as_mut_slice(), } } @@ -229,11 +229,4 @@ impl<'a> Packet<'a> { } } -impl<'a> Drop for Packet<'a> { - fn drop(&mut self) { - BufferPool::free(self.buffer_index); - trace!("Dropping Packet......"); - } -} - box_slab!(PacketPool, Packet<'static>, MAX_PACKET_POOL_SIZE); diff --git a/matter/src/transport/plain_hdr.rs b/matter/src/transport/plain_hdr.rs index 5e54cd1..e51ddaf 100644 --- a/matter/src/transport/plain_hdr.rs +++ b/matter/src/transport/plain_hdr.rs @@ -21,18 +21,13 @@ use crate::utils::writebuf::WriteBuf; use bitflags::bitflags; use log::info; -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Default)] pub enum SessionType { + #[default] None, Encrypted, } -impl Default for SessionType { - fn default() -> SessionType { - SessionType::None - } -} - bitflags! { #[derive(Default)] pub struct MsgFlags: u8 { diff --git a/matter/src/transport/proto_ctx.rs b/matter/src/transport/proto_ctx.rs new file mode 100644 index 0000000..747a1e6 --- /dev/null +++ b/matter/src/transport/proto_ctx.rs @@ -0,0 +1,43 @@ +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +use crate::error::Error; + +use super::exchange::ExchangeCtx; +use super::packet::Packet; + +/// This is the context in which a receive packet is being processed +pub struct ProtoCtx<'a, 'b> { + /// This is the exchange context, that includes the exchange and the session + pub exch_ctx: ExchangeCtx<'a>, + /// This is the received buffer for this transaction + pub rx: &'a Packet<'b>, + /// This is the transmit buffer for this transaction + pub tx: &'a mut Packet<'b>, +} + +impl<'a, 'b> ProtoCtx<'a, 'b> { + pub fn new(exch_ctx: ExchangeCtx<'a>, rx: &'a Packet<'b>, tx: &'a mut Packet<'b>) -> Self { + Self { exch_ctx, rx, tx } + } + + pub fn send(&mut self) -> Result<&[u8], Error> { + self.exch_ctx.exch.send(self.tx, &mut self.exch_ctx.sess)?; + + Ok(self.tx.as_mut_slice()) + } +} diff --git a/matter/src/transport/proto_demux.rs b/matter/src/transport/proto_demux.rs deleted file mode 100644 index 263ffc9..0000000 --- a/matter/src/transport/proto_demux.rs +++ /dev/null @@ -1,95 +0,0 @@ -/* - * - * Copyright (c) 2020-2022 Project CHIP Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -use boxslab::BoxSlab; - -use crate::error::*; - -use super::exchange::ExchangeCtx; -use super::packet::PacketPool; - -const MAX_PROTOCOLS: usize = 4; - -#[derive(PartialEq, Debug)] -pub enum ResponseRequired { - Yes, - No, -} -pub struct ProtoDemux { - proto_id_handlers: [Option>; MAX_PROTOCOLS], -} - -/// This is the context in which a receive packet is being processed -pub struct ProtoCtx<'a> { - /// This is the exchange context, that includes the exchange and the session - pub exch_ctx: ExchangeCtx<'a>, - /// This is the received buffer for this transaction - pub rx: BoxSlab, - /// This is the transmit buffer for this transaction - pub tx: BoxSlab, -} - -impl<'a> ProtoCtx<'a> { - pub fn new( - exch_ctx: ExchangeCtx<'a>, - rx: BoxSlab, - tx: BoxSlab, - ) -> Self { - Self { exch_ctx, rx, tx } - } -} - -pub trait HandleProto { - fn handle_proto_id(&mut self, proto_ctx: &mut ProtoCtx) -> Result; - - fn get_proto_id(&self) -> usize; - - fn handle_session_event(&self) -> Result<(), Error> { - Ok(()) - } -} - -impl Default for ProtoDemux { - fn default() -> Self { - Self::new() - } -} - -impl ProtoDemux { - pub fn new() -> ProtoDemux { - ProtoDemux { - proto_id_handlers: [None, None, None, None], - } - } - - pub fn register(&mut self, proto_id_handle: Box) -> Result<(), Error> { - let proto_id = proto_id_handle.get_proto_id(); - self.proto_id_handlers[proto_id] = Some(proto_id_handle); - Ok(()) - } - - pub fn handle(&mut self, proto_ctx: &mut ProtoCtx) -> Result { - let proto_id = proto_ctx.rx.get_proto_id() as usize; - if proto_id >= MAX_PROTOCOLS { - return Err(Error::Invalid); - } - return self.proto_id_handlers[proto_id] - .as_mut() - .ok_or(Error::NoHandler)? - .handle_proto_id(proto_ctx); - } -} diff --git a/matter/src/transport/proto_hdr.rs b/matter/src/transport/proto_hdr.rs index 3eb8570..96928ac 100644 --- a/matter/src/transport/proto_hdr.rs +++ b/matter/src/transport/proto_hdr.rs @@ -16,7 +16,7 @@ */ use bitflags::bitflags; -use std::fmt; +use core::fmt; use crate::transport::plain_hdr; use crate::utils::parsebuf::ParseBuf; @@ -117,7 +117,7 @@ impl ProtoHdr { if self.is_ack() { self.ack_msg_ctr = Some(parsebuf.le_u32()?); } - trace!("[rx payload]: {:x?}", parsebuf.as_borrow_slice()); + trace!("[rx payload]: {:x?}", parsebuf.as_mut_slice()); Ok(()) } @@ -139,21 +139,21 @@ impl ProtoHdr { impl fmt::Display for ProtoHdr { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let mut flag_str: String = "".to_owned(); + let mut flag_str = heapless::String::<16>::new(); if self.is_vendor() { - flag_str.push_str("V|"); + flag_str.push_str("V|").unwrap(); } if self.is_security_ext() { - flag_str.push_str("SX|"); + flag_str.push_str("SX|").unwrap(); } if self.is_reliable() { - flag_str.push_str("R|"); + flag_str.push_str("R|").unwrap(); } if self.is_ack() { - flag_str.push_str("A|"); + flag_str.push_str("A|").unwrap(); } if self.is_initiator() { - flag_str.push_str("I|"); + flag_str.push_str("I|").unwrap(); } write!( f, @@ -165,7 +165,7 @@ impl fmt::Display for ProtoHdr { fn get_iv(recvd_ctr: u32, peer_nodeid: u64, iv: &mut [u8]) -> Result<(), Error> { // The IV is the source address (64-bit) followed by the message counter (32-bit) - let mut write_buf = WriteBuf::new(iv, iv.len()); + let mut write_buf = WriteBuf::new(iv); // For some reason, this is 0 in the 'bypass' mode write_buf.le_u8(0)?; write_buf.le_u32(recvd_ctr)?; @@ -224,7 +224,7 @@ fn decrypt_in_place( let mut iv = [0_u8; crypto::AEAD_NONCE_LEN_BYTES]; get_iv(recvd_ctr, peer_nodeid, &mut iv)?; - let cipher_text = parsebuf.as_borrow_slice(); + let cipher_text = parsebuf.as_mut_slice(); //println!("AAD: {:x?}", aad); //println!("Cipher Text: {:x?}", cipher_text); //println!("IV: {:x?}", iv); @@ -266,8 +266,7 @@ mod tests { 0x1f, 0xb0, 0x5e, 0xbe, 0xb5, 0x10, 0xad, 0xc6, 0x78, 0x94, 0x50, 0xe5, 0xd2, 0xe0, 0x80, 0xef, 0xa8, 0x3a, 0xf0, 0xa6, 0xaf, 0x1b, 0x2, 0x35, 0xa7, 0xd1, 0xc6, 0x32, ]; - let input_buf_len = input_buf.len(); - let mut parsebuf = ParseBuf::new(&mut input_buf, input_buf_len); + let mut parsebuf = ParseBuf::new(&mut input_buf); let key = [ 0x66, 0x63, 0x31, 0x97, 0x43, 0x9c, 0x17, 0xb9, 0x7e, 0x10, 0xee, 0x47, 0xc8, 0x8, 0x80, 0x4a, @@ -279,7 +278,7 @@ mod tests { decrypt_in_place(recvd_ctr, 0, &mut parsebuf, &key).unwrap(); assert_eq!( - parsebuf.as_slice(), + parsebuf.into_slice(), [ 0x5, 0x8, 0x70, 0x0, 0x1, 0x0, 0x15, 0x28, 0x0, 0x28, 0x1, 0x36, 0x2, 0x15, 0x37, 0x0, 0x24, 0x0, 0x0, 0x24, 0x1, 0x30, 0x24, 0x2, 0x2, 0x18, 0x35, 0x1, 0x24, 0x0, @@ -295,8 +294,7 @@ mod tests { let send_ctr = 41; let mut main_buf: [u8; 52] = [0; 52]; - let main_buf_len = main_buf.len(); - let mut writebuf = WriteBuf::new(&mut main_buf, main_buf_len); + let mut writebuf = WriteBuf::new(&mut main_buf); let plain_hdr: [u8; 8] = [0x0, 0x11, 0x0, 0x0, 0x29, 0x0, 0x0, 0x0]; @@ -313,7 +311,7 @@ mod tests { encrypt_in_place(send_ctr, 0, &plain_hdr, &mut writebuf, &key).unwrap(); assert_eq!( - writebuf.as_slice(), + writebuf.into_slice(), [ 189, 83, 250, 121, 38, 87, 97, 17, 153, 78, 243, 20, 36, 11, 131, 142, 136, 165, 227, 107, 204, 129, 193, 153, 42, 131, 138, 254, 22, 190, 76, 244, 116, 45, 156, diff --git a/matter/src/transport/session.rs b/matter/src/transport/session.rs index 8faf813..d4c4985 100644 --- a/matter/src/transport/session.rs +++ b/matter/src/transport/session.rs @@ -15,11 +15,14 @@ * limitations under the License. */ +use crate::data_model::sdm::noc::NocData; +use crate::utils::epoch::Epoch; +use crate::utils::rand::Rand; use core::fmt; -use std::{ +use core::time::Duration; +use core::{ any::Any, ops::{Deref, DerefMut}, - time::SystemTime, }; use crate::{ @@ -27,16 +30,10 @@ use crate::{ transport::{plain_hdr, proto_hdr}, utils::writebuf::WriteBuf, }; -use boxslab::{BoxSlab, Slab}; -use colored::*; use log::{info, trace}; -use rand::Rng; -use super::{ - dedup::RxCtrState, - network::{Address, NetworkInterface}, - packet::{Packet, PacketPool}, -}; +use super::dedup::RxCtrState; +use super::{network::Address, packet::Packet}; pub const MAX_CAT_IDS_PER_NOC: usize = 3; pub type NocCatIds = [u32; MAX_CAT_IDS_PER_NOC]; @@ -58,21 +55,15 @@ impl CaseDetails { } } -#[derive(Debug, PartialEq, Copy, Clone)] +#[derive(Debug, PartialEq, Copy, Clone, Default)] pub enum SessionMode { // The Case session will capture the local fabric index Case(CaseDetails), Pase, + #[default] PlainText, } -impl Default for SessionMode { - fn default() -> Self { - SessionMode::PlainText - } -} - -#[derive(Debug)] pub struct Session { peer_addr: Address, local_nodeid: u64, @@ -87,8 +78,8 @@ pub struct Session { msg_ctr: u32, rx_ctr_state: RxCtrState, mode: SessionMode, - data: Option>, - last_use: SystemTime, + data: Option, + last_use: Duration, } #[derive(Debug)] @@ -103,6 +94,7 @@ pub struct CloneData { peer_addr: Address, mode: SessionMode, } + impl CloneData { pub fn new( local_nodeid: u64, @@ -129,8 +121,8 @@ impl CloneData { const MATTER_MSG_CTR_RANGE: u32 = 0x0fffffff; impl Session { - pub fn new(peer_addr: Address, peer_nodeid: Option) -> Session { - Session { + pub fn new(peer_addr: Address, peer_nodeid: Option, epoch: Epoch, rand: Rand) -> Self { + Self { peer_addr, local_nodeid: 0, peer_nodeid, @@ -139,16 +131,16 @@ impl Session { att_challenge: [0; MATTER_AES128_KEY_SIZE], peer_sess_id: 0, local_sess_id: 0, - msg_ctr: rand::thread_rng().gen_range(0..MATTER_MSG_CTR_RANGE), + msg_ctr: Self::rand_msg_ctr(rand), rx_ctr_state: RxCtrState::new(0), mode: SessionMode::PlainText, data: None, - last_use: SystemTime::now(), + last_use: epoch(), } } // A new encrypted session always clones from a previous 'new' session - pub fn clone(clone_from: &CloneData) -> Session { + pub fn clone(clone_from: &CloneData, epoch: Epoch, rand: Rand) -> Session { Session { peer_addr: clone_from.peer_addr, local_nodeid: clone_from.local_nodeid, @@ -158,28 +150,28 @@ impl Session { att_challenge: clone_from.att_challenge, local_sess_id: clone_from.local_sess_id, peer_sess_id: clone_from.peer_sess_id, - msg_ctr: rand::thread_rng().gen_range(0..MATTER_MSG_CTR_RANGE), + msg_ctr: Self::rand_msg_ctr(rand), rx_ctr_state: RxCtrState::new(0), mode: clone_from.mode, data: None, - last_use: SystemTime::now(), + last_use: epoch(), } } - pub fn set_data(&mut self, data: Box) { + pub fn set_noc_data(&mut self, data: NocData) { self.data = Some(data); } - pub fn clear_data(&mut self) { + pub fn clear_noc_data(&mut self) { self.data = None; } - pub fn get_data(&mut self) -> Option<&mut T> { - self.data.as_mut()?.downcast_mut::() + pub fn get_noc_data(&mut self) -> Option<&mut NocData> { + self.data.as_mut() } - pub fn take_data(&mut self) -> Option> { - self.data.take()?.downcast::().ok() + pub fn take_noc_data(&mut self) -> Option { + self.data.take() } pub fn get_local_sess_id(&self) -> u16 { @@ -252,59 +244,65 @@ impl Session { &self.att_challenge } - pub fn recv(&mut self, proto_rx: &mut Packet) -> Result<(), Error> { - self.last_use = SystemTime::now(); - proto_rx.proto_decode(self.peer_nodeid.unwrap_or_default(), self.get_dec_key()) + pub fn recv(&mut self, epoch: Epoch, rx: &mut Packet) -> Result<(), Error> { + self.last_use = epoch(); + rx.proto_decode(self.peer_nodeid.unwrap_or_default(), self.get_dec_key()) } - pub fn pre_send(&mut self, proto_tx: &mut Packet) -> Result<(), Error> { - proto_tx.plain.sess_id = self.get_peer_sess_id(); - proto_tx.plain.ctr = self.get_msg_ctr(); + pub fn pre_send(&mut self, tx: &mut Packet) -> Result<(), Error> { + tx.plain.sess_id = self.get_peer_sess_id(); + tx.plain.ctr = self.get_msg_ctr(); if self.is_encrypted() { - proto_tx.plain.sess_type = plain_hdr::SessionType::Encrypted; + tx.plain.sess_type = plain_hdr::SessionType::Encrypted; } Ok(()) } // TODO: Most of this can now be moved into the 'Packet' module - fn do_send(&mut self, proto_tx: &mut Packet) -> Result<(), Error> { - self.last_use = SystemTime::now(); - proto_tx.peer = self.peer_addr; + fn do_send(&mut self, epoch: Epoch, tx: &mut Packet) -> Result<(), Error> { + self.last_use = epoch(); + tx.peer = self.peer_addr; // Generate encrypted header - let mut tmp_buf: [u8; proto_hdr::max_proto_hdr_len()] = [0; proto_hdr::max_proto_hdr_len()]; - let mut write_buf = WriteBuf::new(&mut tmp_buf[..], proto_hdr::max_proto_hdr_len()); - proto_tx.proto.encode(&mut write_buf)?; - proto_tx.get_writebuf()?.prepend(write_buf.as_slice())?; + let mut tmp_buf = [0_u8; proto_hdr::max_proto_hdr_len()]; + let mut write_buf = WriteBuf::new(&mut tmp_buf); + tx.proto.encode(&mut write_buf)?; + tx.get_writebuf()?.prepend(write_buf.into_slice())?; // Generate plain-text header if self.mode == SessionMode::PlainText { if let Some(d) = self.peer_nodeid { - proto_tx.plain.set_dest_u64(d); + tx.plain.set_dest_u64(d); } } - let mut tmp_buf: [u8; plain_hdr::max_plain_hdr_len()] = [0; plain_hdr::max_plain_hdr_len()]; - let mut write_buf = WriteBuf::new(&mut tmp_buf[..], plain_hdr::max_plain_hdr_len()); - proto_tx.plain.encode(&mut write_buf)?; - let plain_hdr_bytes = write_buf.as_slice(); + let mut tmp_buf = [0_u8; plain_hdr::max_plain_hdr_len()]; + let mut write_buf = WriteBuf::new(&mut tmp_buf); + tx.plain.encode(&mut write_buf)?; + let plain_hdr_bytes = write_buf.into_slice(); - trace!("unencrypted packet: {:x?}", proto_tx.as_borrow_slice()); - let ctr = proto_tx.plain.ctr; + trace!("unencrypted packet: {:x?}", tx.as_mut_slice()); + let ctr = tx.plain.ctr; let enc_key = self.get_enc_key(); if let Some(e) = enc_key { proto_hdr::encrypt_in_place( ctr, self.local_nodeid, plain_hdr_bytes, - proto_tx.get_writebuf()?, + tx.get_writebuf()?, e, )?; } - proto_tx.get_writebuf()?.prepend(plain_hdr_bytes)?; - trace!("Full encrypted packet: {:x?}", proto_tx.as_borrow_slice()); + tx.get_writebuf()?.prepend(plain_hdr_bytes)?; + trace!("Full encrypted packet: {:x?}", tx.as_mut_slice()); Ok(()) } + + fn rand_msg_ctr(rand: Rand) -> u32 { + let mut buf = [0; 4]; + rand(&mut buf); + u32::from_be_bytes(buf) & MATTER_MSG_CTR_RANGE + } } impl fmt::Display for Session { @@ -324,36 +322,23 @@ impl fmt::Display for Session { } pub const MAX_SESSIONS: usize = 16; + pub struct SessionMgr { next_sess_id: u16, sessions: [Option; MAX_SESSIONS], - network: Option>, -} - -impl Default for SessionMgr { - fn default() -> Self { - Self::new() - } + epoch: Epoch, + rand: Rand, } impl SessionMgr { - pub fn new() -> SessionMgr { - SessionMgr { - sessions: Default::default(), - next_sess_id: 1, - network: None, - } - } + pub fn new(epoch: Epoch, rand: Rand) -> Self { + const INIT: Option = None; - pub fn add_network_interface( - &mut self, - interface: Box, - ) -> Result<(), Error> { - if self.network.is_none() { - self.network = Some(interface); - Ok(()) - } else { - Err(Error::Invalid) + Self { + sessions: [INIT; MAX_SESSIONS], + next_sess_id: 1, + epoch, + rand, } } @@ -380,13 +365,21 @@ impl SessionMgr { next_sess_id } + pub fn get_session_for_eviction(&self) -> Option { + if self.get_empty_slot().is_none() { + Some(self.get_lru()) + } else { + None + } + } + fn get_empty_slot(&self) -> Option { self.sessions.iter().position(|x| x.is_none()) } - pub fn get_lru(&mut self) -> usize { + fn get_lru(&self) -> usize { let mut lru_index = 0; - let mut lru_ts = SystemTime::now(); + let mut lru_ts = (self.epoch)(); for i in 0..MAX_SESSIONS { if let Some(s) = &self.sessions[i] { if s.last_use < lru_ts { @@ -399,7 +392,7 @@ impl SessionMgr { } pub fn add(&mut self, peer_addr: Address, peer_nodeid: Option) -> Result { - let session = Session::new(peer_addr, peer_nodeid); + let session = Session::new(peer_addr, peer_nodeid, self.epoch, self.rand); self.add_session(session) } @@ -422,7 +415,7 @@ impl SessionMgr { } pub fn clone_session(&mut self, clone_data: &CloneData) -> Result { - let session = Session::clone(clone_data); + let session = Session::clone(clone_data, self.epoch, self.rand); self.add_session(session) } @@ -478,68 +471,50 @@ impl SessionMgr { // We will try to get a session for this Packet. If no session exists, we will try to add one // If the session list is full we will return a None - pub fn post_recv(&mut self, rx: &Packet) -> Result, Error> { - let sess_index = match self.get_or_add( + pub fn post_recv(&mut self, rx: &Packet) -> Result { + let sess_index = self.get_or_add( rx.plain.sess_id, rx.peer, rx.plain.get_src_u64(), rx.plain.is_encrypted(), - ) { - Ok(s) => { - let session = self.sessions[s].as_mut().unwrap(); - let is_encrypted = session.is_encrypted(); - let duplicate = session.rx_ctr_state.recv(rx.plain.ctr, is_encrypted); - if duplicate { - info!("Dropping duplicate packet"); - return Err(Error::Duplicate); - } else { - Some(s) - } - } - Err(Error::NoSpace) => None, - Err(e) => { - return Err(e); - } - }; - Ok(sess_index) + )?; + + let session = self.sessions[sess_index].as_mut().unwrap(); + let is_encrypted = session.is_encrypted(); + let duplicate = session.rx_ctr_state.recv(rx.plain.ctr, is_encrypted); + if duplicate { + info!("Dropping duplicate packet"); + Err(Error::Duplicate) + } else { + Ok(sess_index) + } } - pub fn recv(&mut self) -> Result<(BoxSlab, Option), Error> { - let mut rx = - Slab::::try_new(Packet::new_rx()?).ok_or(Error::PacketPoolExhaust)?; + pub fn decode(&mut self, rx: &mut Packet) -> Result<(), Error> { + // let network = self.network.as_ref().ok_or(Error::NoNetworkInterface)?; - let network = self.network.as_ref().ok_or(Error::NoNetworkInterface)?; + // let (len, src) = network.recv(rx.as_borrow_slice()).await?; + // rx.get_parsebuf()?.set_len(len); + // rx.peer = src; - let (len, src) = network.recv(rx.as_borrow_slice())?; - rx.get_parsebuf()?.set_len(len); - rx.peer = src; - - info!("{} from src: {}", "Received".blue(), src); - trace!("payload: {:x?}", rx.as_borrow_slice()); + // info!("{} from src: {}", "Received".blue(), src); + // trace!("payload: {:x?}", rx.as_borrow_slice()); // Read unencrypted packet header - rx.plain_hdr_decode()?; - - // Get session - let sess_handle = self.post_recv(&rx)?; - - Ok((rx, sess_handle)) + rx.plain_hdr_decode() } - pub fn send( - &mut self, - sess_idx: usize, - mut proto_tx: BoxSlab, - ) -> Result<(), Error> { + pub fn send(&mut self, sess_idx: usize, tx: &mut Packet) -> Result<(), Error> { self.sessions[sess_idx] .as_mut() .ok_or(Error::NoSession)? - .do_send(&mut proto_tx)?; + .do_send(self.epoch, tx)?; + + // let network = self.network.as_ref().ok_or(Error::NoNetworkInterface)?; + // let peer = proto_tx.peer; + // network.send(proto_tx.as_borrow_slice(), peer).await?; + // info!("Message Sent to {}", peer); - let network = self.network.as_ref().ok_or(Error::NoNetworkInterface)?; - let peer = proto_tx.peer; - network.send(proto_tx.as_borrow_slice(), peer)?; - println!("Message Sent to {}", peer); Ok(()) } @@ -568,40 +543,52 @@ pub struct SessionHandle<'a> { } impl<'a> SessionHandle<'a> { + pub fn session(&self) -> &Session { + self.sess_mgr.sessions[self.sess_idx].as_ref().unwrap() + } + + pub fn session_mut(&mut self) -> &mut Session { + self.sess_mgr.sessions[self.sess_idx].as_mut().unwrap() + } + pub fn reserve_new_sess_id(&mut self) -> u16 { self.sess_mgr.get_next_sess_id() } - pub fn send(&mut self, proto_tx: BoxSlab) -> Result<(), Error> { - self.sess_mgr.send(self.sess_idx, proto_tx) + pub fn send(&mut self, tx: &mut Packet) -> Result<(), Error> { + self.sess_mgr.send(self.sess_idx, tx) } } impl<'a> Deref for SessionHandle<'a> { type Target = Session; + fn deref(&self) -> &Self::Target { // There is no other option but to panic if this is None - self.sess_mgr.sessions[self.sess_idx].as_ref().unwrap() + self.session() } } impl<'a> DerefMut for SessionHandle<'a> { fn deref_mut(&mut self) -> &mut Self::Target { // There is no other option but to panic if this is None - self.sess_mgr.sessions[self.sess_idx].as_mut().unwrap() + self.session_mut() } } #[cfg(test)] mod tests { - use crate::transport::network::Address; + use crate::{ + transport::network::Address, + utils::{epoch::dummy_epoch, rand::dummy_rand}, + }; use super::SessionMgr; #[test] fn test_next_sess_id_doesnt_reuse() { - let mut sm = SessionMgr::new(); + let mut sm = SessionMgr::new(dummy_epoch, dummy_rand); let sess_idx = sm.add(Address::default(), None).unwrap(); let mut sess = sm.get_session_handle(sess_idx); sess.set_local_sess_id(1); @@ -615,7 +602,7 @@ mod tests { #[test] fn test_next_sess_id_overflows() { - let mut sm = SessionMgr::new(); + let mut sm = SessionMgr::new(dummy_epoch, dummy_rand); let sess_idx = sm.add(Address::default(), None).unwrap(); let mut sess = sm.get_session_handle(sess_idx); sess.set_local_sess_id(1); diff --git a/matter/src/transport/udp.rs b/matter/src/transport/udp.rs index 8abe9af..6f7a265 100644 --- a/matter/src/transport/udp.rs +++ b/matter/src/transport/udp.rs @@ -16,9 +16,10 @@ */ use crate::error::*; +use log::info; use smol::net::{Ipv6Addr, UdpSocket}; -use super::network::{Address, NetworkInterface}; +use super::network::Address; // We could get rid of the smol here, but keeping it around in case we have to process // any other events in this thread's context @@ -33,25 +34,26 @@ pub const MAX_RX_BUF_SIZE: usize = 1583; pub const MATTER_PORT: u16 = 5540; impl UdpListener { - pub fn new() -> Result { + pub async fn new() -> Result { Ok(UdpListener { - socket: smol::block_on(UdpSocket::bind((Ipv6Addr::UNSPECIFIED, MATTER_PORT)))?, + socket: UdpSocket::bind((Ipv6Addr::UNSPECIFIED, MATTER_PORT)).await?, }) } -} -impl NetworkInterface for UdpListener { - fn recv(&self, in_buf: &mut [u8]) -> Result<(usize, Address), Error> { - let (size, addr) = smol::block_on(self.socket.recv_from(in_buf)).map_err(|e| { - println!("Error on the network: {:?}", e); + pub async fn recv(&self, in_buf: &mut [u8]) -> Result<(usize, Address), Error> { + let (size, addr) = self.socket.recv_from(in_buf).await.map_err(|e| { + info!("Error on the network: {:?}", e); Error::Network })?; Ok((size, Address::Udp(addr))) } - fn send(&self, out_buf: &[u8], addr: Address) -> Result { + pub async fn send(&self, out_buf: &[u8], addr: Address) -> Result { match addr { - Address::Udp(addr) => Ok(smol::block_on(self.socket.send_to(out_buf, addr))?), + Address::Udp(addr) => self.socket.send_to(out_buf, addr).await.map_err(|e| { + info!("Error on the network: {:?}", e); + Error::Network + }), } } } diff --git a/matter/src/utils/epoch.rs b/matter/src/utils/epoch.rs new file mode 100644 index 0000000..999cdf3 --- /dev/null +++ b/matter/src/utils/epoch.rs @@ -0,0 +1,14 @@ +use core::time::Duration; + +pub type Epoch = fn() -> Duration; + +pub fn dummy_epoch() -> Duration { + Duration::from_secs(0) +} + +#[cfg(feature = "std")] +pub fn sys_epoch() -> Duration { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() +} diff --git a/matter/src/utils/mod.rs b/matter/src/utils/mod.rs index 9fc44a8..1e69b84 100644 --- a/matter/src/utils/mod.rs +++ b/matter/src/utils/mod.rs @@ -15,5 +15,7 @@ * limitations under the License. */ +pub mod epoch; pub mod parsebuf; +pub mod rand; pub mod writebuf; diff --git a/matter/src/utils/parsebuf.rs b/matter/src/utils/parsebuf.rs index b5342c0..d6a8b9a 100644 --- a/matter/src/utils/parsebuf.rs +++ b/matter/src/utils/parsebuf.rs @@ -25,11 +25,13 @@ pub struct ParseBuf<'a> { } impl<'a> ParseBuf<'a> { - pub fn new(buf: &'a mut [u8], len: usize) -> ParseBuf<'a> { - ParseBuf { - buf: &mut buf[..len], + pub fn new(buf: &'a mut [u8]) -> Self { + let left = buf.len(); + + Self { + buf, read_off: 0, - left: len, + left, } } @@ -38,12 +40,17 @@ impl<'a> ParseBuf<'a> { } // Return the data that is valid as a slice, consume self - pub fn as_slice(self) -> &'a mut [u8] { + pub fn into_slice(self) -> &'a mut [u8] { &mut self.buf[self.read_off..(self.read_off + self.left)] } // Return the data that is valid as a slice - pub fn as_borrow_slice(&mut self) -> &mut [u8] { + pub fn as_slice(&self) -> &[u8] { + &self.buf[self.read_off..(self.read_off + self.left)] + } + + // Return the data that is valid as a slice + pub fn as_mut_slice(&mut self) -> &mut [u8] { &mut self.buf[self.read_off..(self.read_off + self.left)] } @@ -101,19 +108,19 @@ mod tests { #[test] fn test_parse_with_success() { - let mut test_slice: [u8; 11] = [0x01, 65, 0, 0xbe, 0xba, 0xfe, 0xca, 0xa, 0xb, 0xc, 0xd]; - let mut buf = ParseBuf::new(&mut test_slice, 11); + let mut test_slice = [0x01, 65, 0, 0xbe, 0xba, 0xfe, 0xca, 0xa, 0xb, 0xc, 0xd]; + let mut buf = ParseBuf::new(&mut test_slice); assert_eq!(buf.le_u8().unwrap(), 0x01); assert_eq!(buf.le_u16().unwrap(), 65); assert_eq!(buf.le_u32().unwrap(), 0xcafebabe); - assert_eq!(buf.as_slice(), [0xa, 0xb, 0xc, 0xd]); + assert_eq!(buf.into_slice(), [0xa, 0xb, 0xc, 0xd]); } #[test] fn test_parse_with_overrun() { - let mut test_slice: [u8; 2] = [0x01, 65]; - let mut buf = ParseBuf::new(&mut test_slice, 2); + let mut test_slice = [0x01, 65]; + let mut buf = ParseBuf::new(&mut test_slice); assert_eq!(buf.le_u8().unwrap(), 0x01); @@ -131,29 +138,29 @@ mod tests { if buf.le_u8().is_ok() { panic!("This should have returned error") } - assert_eq!(buf.as_slice(), []); + assert_eq!(buf.into_slice(), []); } #[test] fn test_tail_with_success() { - let mut test_slice: [u8; 11] = [0x01, 65, 0, 0xbe, 0xba, 0xfe, 0xca, 0xa, 0xb, 0xc, 0xd]; - let mut buf = ParseBuf::new(&mut test_slice, 11); + let mut test_slice = [0x01, 65, 0, 0xbe, 0xba, 0xfe, 0xca, 0xa, 0xb, 0xc, 0xd]; + let mut buf = ParseBuf::new(&mut test_slice); assert_eq!(buf.le_u8().unwrap(), 0x01); assert_eq!(buf.le_u16().unwrap(), 65); assert_eq!(buf.le_u32().unwrap(), 0xcafebabe); assert_eq!(buf.tail(2).unwrap(), [0xc, 0xd]); - assert_eq!(buf.as_borrow_slice(), [0xa, 0xb]); + assert_eq!(buf.as_mut_slice(), [0xa, 0xb]); assert_eq!(buf.tail(2).unwrap(), [0xa, 0xb]); - assert_eq!(buf.as_slice(), []); + assert_eq!(buf.into_slice(), []); } #[test] fn test_tail_with_overrun() { - let mut test_slice: [u8; 11] = [0x01, 65, 0, 0xbe, 0xba, 0xfe, 0xca, 0xa, 0xb, 0xc, 0xd]; - let mut buf = ParseBuf::new(&mut test_slice, 11); + let mut test_slice = [0x01, 65, 0, 0xbe, 0xba, 0xfe, 0xca, 0xa, 0xb, 0xc, 0xd]; + let mut buf = ParseBuf::new(&mut test_slice); assert_eq!(buf.le_u8().unwrap(), 0x01); assert_eq!(buf.le_u16().unwrap(), 65); @@ -166,8 +173,8 @@ mod tests { #[test] fn test_parsed_as_slice() { - let mut test_slice: [u8; 11] = [0x01, 65, 0, 0xbe, 0xba, 0xfe, 0xca, 0xa, 0xb, 0xc, 0xd]; - let mut buf = ParseBuf::new(&mut test_slice, 11); + let mut test_slice = [0x01, 65, 0, 0xbe, 0xba, 0xfe, 0xca, 0xa, 0xb, 0xc, 0xd]; + let mut buf = ParseBuf::new(&mut test_slice); assert_eq!(buf.parsed_as_slice(), []); assert_eq!(buf.le_u8().unwrap(), 0x1); diff --git a/matter/src/utils/rand.rs b/matter/src/utils/rand.rs new file mode 100644 index 0000000..3cd698c --- /dev/null +++ b/matter/src/utils/rand.rs @@ -0,0 +1,3 @@ +pub type Rand = fn(&mut [u8]); + +pub fn dummy_rand(_buf: &mut [u8]) {} diff --git a/matter/src/utils/writebuf.rs b/matter/src/utils/writebuf.rs index cf28888..fae4481 100644 --- a/matter/src/utils/writebuf.rs +++ b/matter/src/utils/writebuf.rs @@ -53,9 +53,9 @@ pub struct WriteBuf<'a> { } impl<'a> WriteBuf<'a> { - pub fn new(buf: &'a mut [u8], len: usize) -> WriteBuf<'a> { - WriteBuf { - buf: &mut buf[..len], + pub fn new(buf: &'a mut [u8]) -> Self { + Self { + buf, start: 0, end: 0, } @@ -73,11 +73,11 @@ impl<'a> WriteBuf<'a> { self.end += new_offset } - pub fn as_borrow_slice(&self) -> &[u8] { + pub fn into_slice(self) -> &'a [u8] { &self.buf[self.start..self.end] } - pub fn as_slice(self) -> &'a [u8] { + pub fn as_slice(&self) -> &[u8] { &self.buf[self.start..self.end] } @@ -201,9 +201,8 @@ mod tests { #[test] fn test_append_le_with_success() { - let mut test_slice: [u8; 22] = [0; 22]; - let test_slice_len = test_slice.len(); - let mut buf = WriteBuf::new(&mut test_slice, test_slice_len); + let mut test_slice = [0; 22]; + let mut buf = WriteBuf::new(&mut test_slice); buf.reserve(5).unwrap(); buf.le_u8(1).unwrap(); @@ -222,8 +221,8 @@ mod tests { #[test] fn test_len_param() { - let mut test_slice: [u8; 20] = [0; 20]; - let mut buf = WriteBuf::new(&mut test_slice, 5); + let mut test_slice = [0; 20]; + let mut buf = WriteBuf::new(&mut test_slice[..5]); buf.reserve(5).unwrap(); let _ = buf.le_u8(1); @@ -236,8 +235,8 @@ mod tests { #[test] fn test_overrun() { - let mut test_slice: [u8; 20] = [0; 20]; - let mut buf = WriteBuf::new(&mut test_slice, 20); + let mut test_slice = [0; 20]; + let mut buf = WriteBuf::new(&mut test_slice); buf.reserve(4).unwrap(); buf.le_u64(0xcafebabecafebabe).unwrap(); buf.le_u64(0xcafebabecafebabe).unwrap(); @@ -262,8 +261,8 @@ mod tests { #[test] fn test_as_slice() { - let mut test_slice: [u8; 20] = [0; 20]; - let mut buf = WriteBuf::new(&mut test_slice, 20); + let mut test_slice = [0; 20]; + let mut buf = WriteBuf::new(&mut test_slice); buf.reserve(5).unwrap(); buf.le_u8(1).unwrap(); @@ -275,7 +274,7 @@ mod tests { buf.prepend(&new_slice).unwrap(); assert_eq!( - buf.as_slice(), + buf.into_slice(), [ 0xa, 0xb, 0xc, 1, 65, 0, 0xbe, 0xba, 0xfe, 0xca, 0xbe, 0xba, 0xfe, 0xca, 0xbe, 0xba, 0xfe, 0xca @@ -285,8 +284,8 @@ mod tests { #[test] fn test_copy_as_slice() { - let mut test_slice: [u8; 20] = [0; 20]; - let mut buf = WriteBuf::new(&mut test_slice, 20); + let mut test_slice = [0; 20]; + let mut buf = WriteBuf::new(&mut test_slice); buf.reserve(5).unwrap(); buf.le_u16(65).unwrap(); @@ -301,8 +300,8 @@ mod tests { #[test] fn test_copy_as_slice_overrun() { - let mut test_slice: [u8; 20] = [0; 20]; - let mut buf = WriteBuf::new(&mut test_slice, 7); + let mut test_slice = [0; 20]; + let mut buf = WriteBuf::new(&mut test_slice[..7]); buf.reserve(5).unwrap(); buf.le_u16(65).unwrap(); @@ -314,8 +313,8 @@ mod tests { #[test] fn test_prepend() { - let mut test_slice: [u8; 20] = [0; 20]; - let mut buf = WriteBuf::new(&mut test_slice, 20); + let mut test_slice = [0; 20]; + let mut buf = WriteBuf::new(&mut test_slice); buf.reserve(5).unwrap(); buf.le_u16(65).unwrap(); @@ -329,8 +328,8 @@ mod tests { #[test] fn test_prepend_overrun() { - let mut test_slice: [u8; 20] = [0; 20]; - let mut buf = WriteBuf::new(&mut test_slice, 20); + let mut test_slice = [0; 20]; + let mut buf = WriteBuf::new(&mut test_slice); buf.reserve(5).unwrap(); buf.le_u16(65).unwrap(); @@ -342,8 +341,8 @@ mod tests { #[test] fn test_rewind_tail() { - let mut test_slice: [u8; 20] = [0; 20]; - let mut buf = WriteBuf::new(&mut test_slice, 20); + let mut test_slice = [0; 20]; + let mut buf = WriteBuf::new(&mut test_slice); buf.reserve(5).unwrap(); buf.le_u16(65).unwrap(); @@ -352,13 +351,10 @@ mod tests { let new_slice: [u8; 5] = [0xaa, 0xbb, 0xcc, 0xdd, 0xee]; buf.copy_from_slice(&new_slice).unwrap(); - assert_eq!( - buf.as_borrow_slice(), - [65, 0, 0xaa, 0xbb, 0xcc, 0xdd, 0xee,] - ); + assert_eq!(buf.as_slice(), [65, 0, 0xaa, 0xbb, 0xcc, 0xdd, 0xee,]); buf.rewind_tail_to(anchor); buf.le_u16(66).unwrap(); - assert_eq!(buf.as_borrow_slice(), [65, 0, 66, 0,]); + assert_eq!(buf.as_slice(), [65, 0, 66, 0,]); } } diff --git a/matter/tests/common/attributes.rs b/matter/tests/common/attributes.rs index 1879adf..2ff95eb 100644 --- a/matter/tests/common/attributes.rs +++ b/matter/tests/common/attributes.rs @@ -115,8 +115,7 @@ impl TLVHolder { buf: [0; 100], used_len: 0, }; - let buf_len = s.buf.len(); - let mut wb = WriteBuf::new(&mut s.buf, buf_len); + let mut wb = WriteBuf::new(&mut s.buf); let mut tw = TLVWriter::new(&mut wb); let _ = tw.start_array(TagType::Context(ctx_tag)); for e in data { diff --git a/matter/tests/common/commands.rs b/matter/tests/common/commands.rs index 919565a..419b6ac 100644 --- a/matter/tests/common/commands.rs +++ b/matter/tests/common/commands.rs @@ -76,7 +76,7 @@ macro_rules! echo_req { CmdPath::new( Some($endpoint), Some(echo_cluster::ID), - Some(echo_cluster::Commands::EchoReq as u16), + Some(echo_cluster::Commands::EchoReq as u32), ), EncodeValue::Value(&($data as u32)), ) @@ -90,7 +90,7 @@ macro_rules! echo_resp { CmdPath::new( Some($endpoint), Some(echo_cluster::ID), - Some(echo_cluster::Commands::EchoResp as u16), + Some(echo_cluster::RespCommands::EchoResp as u32), ), $data, ) diff --git a/matter/tests/common/echo_cluster.rs b/matter/tests/common/echo_cluster.rs index cf07183..dd61a0e 100644 --- a/matter/tests/common/echo_cluster.rs +++ b/matter/tests/common/echo_cluster.rs @@ -15,31 +15,91 @@ * limitations under the License. */ -use std::sync::{Arc, Mutex, Once}; +use std::{ + convert::TryInto, + sync::{Arc, Mutex, Once}, +}; use matter::{ + attribute_enum, command_enum, data_model::objects::{ - Access, AttrDetails, AttrValue, Attribute, Cluster, ClusterType, EncodeValue, Encoder, - Quality, + Access, AttrData, AttrDataEncoder, AttrDataWriter, AttrDetails, AttrType, Attribute, + Cluster, CmdDataEncoder, CmdDataWriter, CmdDetails, Dataver, Handler, Quality, + ATTRIBUTE_LIST, FEATURE_MAP, }, error::Error, interaction_model::{ - command::CommandReq, - core::IMStatusCode, - messages::ib::{self, attr_list_write, ListOperation}, + core::Transaction, + messages::ib::{attr_list_write, ListOperation}, }, - tlv::{TLVElement, TLVWriter, TagType, ToTLV}, + tlv::{TLVElement, TagType}, + utils::rand::Rand, }; use num_derive::FromPrimitive; +use strum::{EnumDiscriminants, FromRepr}; pub const ID: u32 = 0xABCD; -#[derive(FromPrimitive)] +#[derive(FromRepr, EnumDiscriminants)] +#[repr(u16)] +pub enum Attributes { + Att1(AttrType) = 0, + Att2(AttrType) = 1, + AttWrite(AttrType) = 2, + AttCustom(AttrType) = 3, + AttWriteList(()) = 4, +} + +attribute_enum!(Attributes); + +#[derive(FromRepr)] +#[repr(u32)] pub enum Commands { EchoReq = 0x00, +} + +command_enum!(Commands); + +#[derive(FromPrimitive)] +pub enum RespCommands { EchoResp = 0x01, } +pub const CLUSTER: Cluster<'static> = Cluster { + id: ID, + feature_map: 0, + attributes: &[ + FEATURE_MAP, + ATTRIBUTE_LIST, + Attribute::new( + AttributesDiscriminants::Att1 as u16, + Access::RV, + Quality::NONE, + ), + Attribute::new( + AttributesDiscriminants::Att2 as u16, + Access::RV, + Quality::NONE, + ), + Attribute::new( + AttributesDiscriminants::AttWrite as u16, + Access::WRITE.union(Access::NEED_ADMIN), + Quality::NONE, + ), + Attribute::new( + AttributesDiscriminants::AttCustom as u16, + Access::READ.union(Access::NEED_VIEW), + Quality::NONE, + ), + Attribute::new( + AttributesDiscriminants::AttWriteList as u16, + Access::WRITE.union(Access::NEED_ADMIN), + Quality::NONE, + ), + ], + commands: &[Commands::EchoReq as _], +}; + /// This is used in the tests to validate any settings that may have happened /// to the custom data parts of the cluster pub struct TestChecker { @@ -68,167 +128,122 @@ impl TestChecker { } pub const WRITE_LIST_MAX: usize = 5; + pub struct EchoCluster { - pub base: Cluster, + pub data_ver: Dataver, pub multiplier: u8, -} - -#[derive(FromPrimitive)] -pub enum Attributes { - Att1 = 0, - Att2 = 1, - AttWrite = 2, - AttCustom = 3, - AttWriteList = 4, -} - -pub const ATTR_CUSTOM_VALUE: u32 = 0xcafebeef; -pub const ATTR_WRITE_DEFAULT_VALUE: u16 = 0xcafe; - -impl ClusterType for EchoCluster { - fn base(&self) -> &Cluster { - &self.base - } - - fn base_mut(&mut self) -> &mut Cluster { - &mut self.base - } - - fn read_custom_attribute(&self, encoder: &mut dyn Encoder, attr: &AttrDetails) { - match num::FromPrimitive::from_u16(attr.attr_id) { - Some(Attributes::AttCustom) => encoder.encode(EncodeValue::Closure(&|tag, tw| { - let _ = tw.u32(tag, ATTR_CUSTOM_VALUE); - })), - Some(Attributes::AttWriteList) => { - let tc_handle = TestChecker::get().unwrap(); - let tc = tc_handle.lock().unwrap(); - encoder.encode(EncodeValue::Closure(&|tag, tw| { - let _ = tw.start_array(tag); - for i in tc.write_list.iter().flatten() { - let _ = tw.u16(TagType::Anonymous, *i); - } - let _ = tw.end_container(); - })) - } - _ => (), - } - } - - fn write_attribute( - &mut self, - attr: &AttrDetails, - data: &TLVElement, - ) -> Result<(), IMStatusCode> { - match num::FromPrimitive::from_u16(attr.attr_id) { - Some(Attributes::AttWriteList) => { - attr_list_write(attr, data, |op, data| self.write_attr_list(&op, data)) - } - _ => self.base.write_attribute_from_tlv(attr.attr_id, data), - } - } - - fn handle_command(&mut self, cmd_req: &mut CommandReq) -> Result<(), IMStatusCode> { - let cmd = cmd_req - .cmd - .path - .leaf - .map(num::FromPrimitive::from_u32) - .ok_or(IMStatusCode::UnsupportedCommand)? - .ok_or(IMStatusCode::UnsupportedCommand)?; - match cmd { - // This will generate an echo response on the same endpoint - // with data multiplied by the multiplier - Commands::EchoReq => { - let a = cmd_req.data.u8().unwrap(); - let mut echo_response = cmd_req.cmd; - echo_response.path.leaf = Some(Commands::EchoResp as u32); - - let cmd_data = |tag: TagType, t: &mut TLVWriter| { - let _ = t.start_struct(tag); - // Echo = input * self.multiplier - let _ = t.u8(TagType::Context(0), a * self.multiplier); - let _ = t.end_container(); - }; - - let invoke_resp = ib::InvResp::Cmd(ib::CmdData::new( - echo_response, - EncodeValue::Closure(&cmd_data), - )); - let _ = invoke_resp.to_tlv(cmd_req.resp, TagType::Anonymous); - cmd_req.trans.complete(); - } - _ => { - return Err(IMStatusCode::UnsupportedCommand); - } - } - Ok(()) - } + pub att1: u16, + pub att2: u16, + pub att_write: u16, + pub att_custom: u32, } impl EchoCluster { - pub fn new(multiplier: u8) -> Result, Error> { - let mut c = Box::new(Self { - base: Cluster::new(ID)?, + pub fn new(multiplier: u8, rand: Rand) -> Self { + Self { + data_ver: Dataver::new(rand), multiplier, - }); - c.base.add_attribute(Attribute::new( - Attributes::Att1 as u16, - AttrValue::Uint16(0x1234), - Access::RV, - Quality::NONE, - ))?; - c.base.add_attribute(Attribute::new( - Attributes::Att2 as u16, - AttrValue::Uint16(0x5678), - Access::RV, - Quality::NONE, - ))?; - c.base.add_attribute(Attribute::new( - Attributes::AttWrite as u16, - AttrValue::Uint16(ATTR_WRITE_DEFAULT_VALUE), - Access::WRITE | Access::NEED_ADMIN, - Quality::NONE, - ))?; - c.base.add_attribute(Attribute::new( - Attributes::AttCustom as u16, - AttrValue::Custom, - Access::READ | Access::NEED_VIEW, - Quality::NONE, - ))?; - c.base.add_attribute(Attribute::new( - Attributes::AttWriteList as u16, - AttrValue::Custom, - Access::WRITE | Access::NEED_ADMIN, - Quality::NONE, - ))?; - Ok(c) + att1: 0x1234, + att2: 0x5678, + att_write: ATTR_WRITE_DEFAULT_VALUE, + att_custom: ATTR_CUSTOM_VALUE, + } } - fn write_attr_list( + pub fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> { + if let Some(mut writer) = encoder.with_dataver(self.data_ver.get())? { + if attr.is_system() { + CLUSTER.read(attr.attr_id, writer) + } else { + match attr.attr_id.try_into()? { + Attributes::Att1(codec) => codec.encode(writer, 0x1234), + Attributes::Att2(codec) => codec.encode(writer, 0x5678), + Attributes::AttWrite(codec) => codec.encode(writer, ATTR_WRITE_DEFAULT_VALUE), + Attributes::AttCustom(codec) => codec.encode(writer, ATTR_CUSTOM_VALUE), + Attributes::AttWriteList(_) => { + let tc_handle = TestChecker::get().unwrap(); + let tc = tc_handle.lock().unwrap(); + + writer.start_array(AttrDataWriter::TAG)?; + for i in tc.write_list.iter().flatten() { + writer.u16(TagType::Anonymous, *i)?; + } + writer.end_container()?; + + writer.complete() + } + } + } + } else { + Ok(()) + } + } + + pub fn write(&mut self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { + let data = data.with_dataver(self.data_ver.get())?; + + match attr.attr_id.try_into()? { + Attributes::Att1(codec) => self.att1 = codec.decode(data)?, + Attributes::Att2(codec) => self.att2 = codec.decode(data)?, + Attributes::AttWrite(codec) => self.att_write = codec.decode(data)?, + Attributes::AttCustom(codec) => self.att_custom = codec.decode(data)?, + Attributes::AttWriteList(_) => { + attr_list_write(attr, data, |op, data| self.write_attr_list(&op, data))? + } + } + + self.data_ver.changed(); + + Ok(()) + } + + pub fn invoke( &mut self, - op: &ListOperation, + _transaction: &mut Transaction, + cmd: &CmdDetails, data: &TLVElement, - ) -> Result<(), IMStatusCode> { + encoder: CmdDataEncoder, + ) -> Result<(), Error> { + match cmd.cmd_id.try_into()? { + // This will generate an echo response on the same endpoint + // with data multiplied by the multiplier + Commands::EchoReq => { + let a = data.u8()?; + + let mut writer = encoder.with_command(RespCommands::EchoResp as _)?; + + writer.start_struct(CmdDataWriter::TAG)?; + // Echo = input * self.multiplier + writer.u8(TagType::Context(0), a * self.multiplier)?; + writer.end_container()?; + + writer.complete() + } + } + } + + fn write_attr_list(&mut self, op: &ListOperation, data: &TLVElement) -> Result<(), Error> { let tc_handle = TestChecker::get().unwrap(); let mut tc = tc_handle.lock().unwrap(); match op { ListOperation::AddItem => { - let data = data.u16().map_err(|_| IMStatusCode::Failure)?; + let data = data.u16()?; for i in 0..WRITE_LIST_MAX { if tc.write_list[i].is_none() { tc.write_list[i] = Some(data); return Ok(()); } } - Err(IMStatusCode::ResourceExhausted) + + Err(Error::ResourceExhausted) } ListOperation::EditItem(index) => { - let data = data.u16().map_err(|_| IMStatusCode::Failure)?; + let data = data.u16()?; if tc.write_list[*index as usize].is_some() { tc.write_list[*index as usize] = Some(data); Ok(()) } else { - Err(IMStatusCode::InvalidAction) + Err(Error::InvalidAction) } } ListOperation::DeleteItem(index) => { @@ -236,7 +251,7 @@ impl EchoCluster { tc.write_list[*index as usize] = None; Ok(()) } else { - Err(IMStatusCode::InvalidAction) + Err(Error::InvalidAction) } } ListOperation::DeleteList => { @@ -248,3 +263,26 @@ impl EchoCluster { } } } + +pub const ATTR_CUSTOM_VALUE: u32 = 0xcafebeef; +pub const ATTR_WRITE_DEFAULT_VALUE: u16 = 0xcafe; + +impl Handler for EchoCluster { + fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> { + EchoCluster::read(self, attr, encoder) + } + + fn write(&mut self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { + EchoCluster::write(self, attr, data) + } + + fn invoke( + &mut self, + transaction: &mut Transaction, + cmd: &CmdDetails, + data: &TLVElement, + encoder: CmdDataEncoder, + ) -> Result<(), Error> { + EchoCluster::invoke(self, transaction, cmd, data, encoder) + } +} diff --git a/matter/tests/common/handlers.rs b/matter/tests/common/handlers.rs new file mode 100644 index 0000000..7235b8a --- /dev/null +++ b/matter/tests/common/handlers.rs @@ -0,0 +1,317 @@ +use core::time; +use std::thread; + +use log::{info, warn}; +use matter::{ + interaction_model::{ + core::{IMStatusCode, OpCode}, + messages::{ + ib::{AttrData, AttrPath, AttrResp, AttrStatus, CmdData, DataVersionFilter}, + msg::{ + self, InvReq, ReadReq, ReportDataMsg, StatusResp, TimedReq, WriteReq, WriteResp, + WriteRespTag, + }, + }, + }, + tlv::{self, FromTLV, TLVArray, ToTLV}, + transport::{ + exchange::{self, Exchange}, + session::NocCatIds, + }, + Matter, +}; + +use super::{ + attributes::assert_attr_report, + commands::{assert_inv_response, ExpectedInvResp}, + im_engine::{ImEngine, ImInput, IM_ENGINE_PEER_ID}, +}; + +pub enum WriteResponse<'a> { + TransactionError, + TransactionSuccess(&'a [AttrStatus]), +} + +pub enum TimedInvResponse<'a> { + TransactionError(IMStatusCode), + TransactionSuccess(&'a [ExpectedInvResp]), +} + +impl<'a> ImEngine<'a> { + // Helper for handling Read Req sequences for this file + pub fn handle_read_reqs( + &mut self, + peer_node_id: u64, + input: &[AttrPath], + expected: &[AttrResp], + ) { + let mut out_buf = [0u8; 400]; + let received = self.gen_read_reqs_output(peer_node_id, input, None, &mut out_buf); + assert_attr_report(&received, expected) + } + + pub fn new_with_read_reqs( + matter: &'a Matter<'a>, + input: &[AttrPath], + expected: &[AttrResp], + ) -> Self { + let mut im = Self::new(matter); + + let mut out_buf = [0u8; 400]; + let received = im.gen_read_reqs_output(IM_ENGINE_PEER_ID, input, None, &mut out_buf); + assert_attr_report(&received, expected); + + im + } + + pub fn gen_read_reqs_output<'b>( + &mut self, + peer_node_id: u64, + input: &[AttrPath], + dataver_filters: Option>, + out_buf: &'b mut [u8], + ) -> ReportDataMsg<'b> { + let mut read_req = ReadReq::new(true).set_attr_requests(input); + read_req.dataver_filters = dataver_filters; + + let mut input = ImInput::new(OpCode::ReadRequest, &read_req); + input.set_peer_node_id(peer_node_id); + + let (_, out_buf) = self.process(&input, out_buf); + + tlv::print_tlv_list(out_buf); + let root = tlv::get_root_node_struct(out_buf).unwrap(); + ReportDataMsg::from_tlv(&root).unwrap() + } + + pub fn handle_write_reqs( + &mut self, + peer_node_id: u64, + peer_cat_ids: Option<&NocCatIds>, + input: &[AttrData], + expected: &[AttrStatus], + ) { + let mut out_buf = [0u8; 400]; + let write_req = WriteReq::new(false, input); + + let mut input = ImInput::new(OpCode::WriteRequest, &write_req); + input.set_peer_node_id(peer_node_id); + if let Some(cat_ids) = peer_cat_ids { + input.set_cat_ids(cat_ids); + } + + let (_, out_buf) = self.process(&input, &mut out_buf); + + tlv::print_tlv_list(out_buf); + let root = tlv::get_root_node_struct(out_buf).unwrap(); + + let mut index = 0; + let response_iter = root + .find_tag(WriteRespTag::WriteResponses as u32) + .unwrap() + .confirm_array() + .unwrap() + .enter() + .unwrap(); + + for response in response_iter { + info!("Validating index {}", index); + let status = AttrStatus::from_tlv(&response).unwrap(); + assert_eq!(expected[index], status); + info!("Index {} success", index); + index += 1; + } + assert_eq!(index, expected.len()); + } + + pub fn new_with_write_reqs( + matter: &'a Matter<'a>, + input: &[AttrData], + expected: &[AttrStatus], + ) -> Self { + let mut im = Self::new(matter); + + im.handle_write_reqs(IM_ENGINE_PEER_ID, None, input, expected); + + im + } + + // Helper for handling Invoke Command sequences + pub fn handle_commands( + &mut self, + peer_node_id: u64, + input: &[CmdData], + expected: &[ExpectedInvResp], + ) { + let mut out_buf = [0u8; 400]; + let req = InvReq { + suppress_response: Some(false), + timed_request: Some(false), + inv_requests: Some(TLVArray::Slice(input)), + }; + + let mut input = ImInput::new(OpCode::InvokeRequest, &req); + input.set_peer_node_id(peer_node_id); + + let (_, out_buf) = self.process(&input, &mut out_buf); + tlv::print_tlv_list(out_buf); + let root = tlv::get_root_node_struct(out_buf).unwrap(); + let resp = msg::InvResp::from_tlv(&root).unwrap(); + assert_inv_response(&resp, expected) + } + + pub fn new_with_commands( + matter: &'a Matter<'a>, + input: &[CmdData], + expected: &[ExpectedInvResp], + ) -> Self { + let mut im = ImEngine::new(matter); + + im.handle_commands(IM_ENGINE_PEER_ID, input, expected); + + im + } + + fn handle_timed_reqs<'b>( + &mut self, + opcode: OpCode, + request: &dyn ToTLV, + timeout: u16, + delay: u16, + output: &'b mut [u8], + ) -> (u8, &'b [u8]) { + // Use the same exchange for all parts of the transaction + self.exch = Some(Exchange::new(1, 0, exchange::Role::Responder)); + + if timeout != 0 { + // Send Timed Req + let mut tmp_buf = [0u8; 400]; + let timed_req = TimedReq { timeout }; + let im_input = ImInput::new(OpCode::TimedRequest, &timed_req); + let (_, out_buf) = self.process(&im_input, &mut tmp_buf); + tlv::print_tlv_list(out_buf); + } else { + warn!("Skipping timed request"); + } + + // Process any delays + let delay = time::Duration::from_millis(delay.into()); + thread::sleep(delay); + + // Send Write Req + let input = ImInput::new(opcode, request); + let (resp_opcode, output) = self.process(&input, output); + (resp_opcode, output) + } + + // Helper for handling Write Attribute sequences + pub fn handle_timed_write_reqs( + &mut self, + input: &[AttrData], + expected: &WriteResponse, + timeout: u16, + delay: u16, + ) { + let mut out_buf = [0u8; 400]; + let write_req = WriteReq::new(false, input); + + let (resp_opcode, out_buf) = self.handle_timed_reqs( + OpCode::WriteRequest, + &write_req, + timeout, + delay, + &mut out_buf, + ); + tlv::print_tlv_list(out_buf); + let root = tlv::get_root_node_struct(out_buf).unwrap(); + + match expected { + WriteResponse::TransactionSuccess(t) => { + assert_eq!( + num::FromPrimitive::from_u8(resp_opcode), + Some(OpCode::WriteResponse) + ); + let resp = WriteResp::from_tlv(&root).unwrap(); + assert_eq!(resp.write_responses, t); + } + WriteResponse::TransactionError => { + assert_eq!( + num::FromPrimitive::from_u8(resp_opcode), + Some(OpCode::StatusResponse) + ); + let status_resp = StatusResp::from_tlv(&root).unwrap(); + assert_eq!(status_resp.status, IMStatusCode::Timeout); + } + } + } + + pub fn new_with_timed_write_reqs( + matter: &'a Matter<'a>, + input: &[AttrData], + expected: &WriteResponse, + timeout: u16, + delay: u16, + ) -> Self { + let mut im = ImEngine::new(matter); + + im.handle_timed_write_reqs(input, expected, timeout, delay); + + im + } + + // Helper for handling Invoke Command sequences + pub fn handle_timed_commands( + &mut self, + input: &[CmdData], + expected: &TimedInvResponse, + timeout: u16, + delay: u16, + set_timed_request: bool, + ) { + let mut out_buf = [0u8; 400]; + let req = InvReq { + suppress_response: Some(false), + timed_request: Some(set_timed_request), + inv_requests: Some(TLVArray::Slice(input)), + }; + + let (resp_opcode, out_buf) = + self.handle_timed_reqs(OpCode::InvokeRequest, &req, timeout, delay, &mut out_buf); + tlv::print_tlv_list(out_buf); + let root = tlv::get_root_node_struct(out_buf).unwrap(); + + match expected { + TimedInvResponse::TransactionSuccess(t) => { + assert_eq!( + num::FromPrimitive::from_u8(resp_opcode), + Some(OpCode::InvokeResponse) + ); + let resp = msg::InvResp::from_tlv(&root).unwrap(); + assert_inv_response(&resp, t) + } + TimedInvResponse::TransactionError(e) => { + assert_eq!( + num::FromPrimitive::from_u8(resp_opcode), + Some(OpCode::StatusResponse) + ); + let status_resp = StatusResp::from_tlv(&root).unwrap(); + assert_eq!(status_resp.status, *e); + } + } + } + + pub fn new_with_timed_commands( + matter: &'a Matter<'a>, + input: &[CmdData], + expected: &TimedInvResponse, + timeout: u16, + delay: u16, + set_timed_request: bool, + ) -> Self { + let mut im = ImEngine::new(matter); + + im.handle_timed_commands(input, expected, timeout, delay, set_timed_request); + + im + } +} diff --git a/matter/tests/common/im_engine.rs b/matter/tests/common/im_engine.rs index f91433c..348ce74 100644 --- a/matter/tests/common/im_engine.rs +++ b/matter/tests/common/im_engine.rs @@ -16,54 +16,60 @@ */ use crate::common::echo_cluster; -use boxslab::Slab; +use core::borrow::Borrow; use matter::{ - acl::{AclEntry, AclMgr, AuthMode}, + acl::{AclEntry, AuthMode}, data_model::{ - cluster_basic_information::BasicInfoConfig, + cluster_basic_information::{self, BasicInfoConfig}, + cluster_on_off::{self, OnOffCluster}, core::DataModel, - device_types::device_type_add_on_off_light, - objects::Privilege, - sdm::dev_att::{DataType, DevAttDataFetcher}, + device_types::{DEV_TYPE_ON_OFF_LIGHT, DEV_TYPE_ROOT_NODE}, + objects::{ChainedHandler, Endpoint, Node, Privilege}, + root_endpoint::{self, RootEndpointHandler}, + sdm::{ + admin_commissioning, + dev_att::{DataType, DevAttDataFetcher}, + general_commissioning, noc, nw_commissioning, + }, + system_model::access_control, }, error::Error, - fabric::FabricMgr, - interaction_model::{core::OpCode, InteractionModel}, - secure_channel::pake::PaseMgr, + interaction_model::core::{InteractionModel, OpCode}, + mdns::Mdns, tlv::{TLVWriter, TagType, ToTLV}, transport::packet::Packet, transport::{ exchange::{self, Exchange, ExchangeCtx}, network::Address, - packet::PacketPool, - proto_demux::ProtoCtx, - session::{CloneData, NocCatIds, SessionMgr, SessionMode}, + proto_ctx::ProtoCtx, + session::{CaseDetails, CloneData, NocCatIds, SessionMgr, SessionMode}, }, - transport::{proto_demux::HandleProto, session::CaseDetails}, - utils::writebuf::WriteBuf, + utils::{epoch::sys_epoch, rand::dummy_rand, writebuf::WriteBuf}, + Matter, }; -use std::{ - net::{Ipv4Addr, SocketAddr}, - sync::Arc, +use std::net::{Ipv4Addr, SocketAddr}; + +use super::echo_cluster::EchoCluster; + +const BASIC_INFO: BasicInfoConfig<'static> = BasicInfoConfig { + vid: 10, + pid: 11, + hw_ver: 12, + sw_ver: 13, + sw_ver_str: "13", + serial_no: "aabbccdd", + device_name: "Test Device", }; pub struct DummyDevAtt {} + impl DevAttDataFetcher for DummyDevAtt { fn get_devatt_data(&self, _data_type: DataType, _data: &mut [u8]) -> Result { Ok(2) } } -/// An Interaction Model Engine to facilitate easy testing -pub struct ImEngine { - pub dm: DataModel, - pub acl_mgr: Arc, - pub im: Box, - // By default, a new exchange is created for every run, if you wish to instead using a specific - // exchange, set this variable. This is helpful in situations where you have to run multiple - // actions in the same transaction (exchange) - pub exch: Option, -} +pub const IM_ENGINE_PEER_ID: u64 = 445566; pub struct ImInput<'a> { action: OpCode, @@ -72,7 +78,6 @@ pub struct ImInput<'a> { cat_ids: NocCatIds, } -pub const IM_ENGINE_PEER_ID: u64 = 445566; impl<'a> ImInput<'a> { pub fn new(action: OpCode, data: &'a dyn ToTLV) -> Self { Self { @@ -92,56 +97,86 @@ impl<'a> ImInput<'a> { } } -impl ImEngine { - /// Create the interaction model engine - pub fn new() -> Self { - let dev_det = BasicInfoConfig { - vid: 10, - pid: 11, - hw_ver: 12, - sw_ver: 13, - sw_ver_str: "13".to_string(), - serial_no: "aabbccdd".to_string(), - device_name: "Test Device".to_string(), - }; +pub type DmHandler<'a> = ChainedHandler< + OnOffCluster, + ChainedHandler>>, +>; - let dev_att = Box::new(DummyDevAtt {}); - let fabric_mgr = Arc::new(FabricMgr::new().unwrap()); - let acl_mgr = Arc::new(AclMgr::new_with(false).unwrap()); - let pase_mgr = PaseMgr::new(); - acl_mgr.erase_all(); +pub fn matter<'a>(mdns: &'a mut dyn Mdns) -> Matter<'_> { + Matter::new(&BASIC_INFO, mdns, sys_epoch, dummy_rand) +} + +/// An Interaction Model Engine to facilitate easy testing +pub struct ImEngine<'a> { + pub matter: &'a Matter<'a>, + pub im: InteractionModel>>, + // By default, a new exchange is created for every run, if you wish to instead using a specific + // exchange, set this variable. This is helpful in situations where you have to run multiple + // actions in the same transaction (exchange) + pub exch: Option, +} + +impl<'a> ImEngine<'a> { + /// Create the interaction model engine + pub fn new(matter: &'a Matter<'a>) -> Self { let mut default_acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); // Only allow the standard peer node id of the IM Engine default_acl.add_subject(IM_ENGINE_PEER_ID).unwrap(); - acl_mgr.add(default_acl).unwrap(); - let dm = DataModel::new(dev_det, dev_att, fabric_mgr, acl_mgr.clone(), pase_mgr).unwrap(); + matter.acl_mgr.borrow_mut().add(default_acl).unwrap(); - { - let mut d = dm.node.write().unwrap(); - let light_endpoint = device_type_add_on_off_light(&mut d).unwrap(); - d.add_cluster(0, echo_cluster::EchoCluster::new(2).unwrap()) - .unwrap(); - d.add_cluster(light_endpoint, echo_cluster::EchoCluster::new(3).unwrap()) - .unwrap(); - } - - let im = Box::new(InteractionModel::new(Box::new(dm.clone()))); + let dm = DataModel::new( + matter.borrow(), + &Node { + id: 0, + endpoints: &[ + Endpoint { + id: 0, + clusters: &[ + cluster_basic_information::CLUSTER, + general_commissioning::CLUSTER, + nw_commissioning::CLUSTER, + admin_commissioning::CLUSTER, + noc::CLUSTER, + access_control::CLUSTER, + echo_cluster::CLUSTER, + ], + device_type: DEV_TYPE_ROOT_NODE, + }, + Endpoint { + id: 1, + clusters: &[echo_cluster::CLUSTER, cluster_on_off::CLUSTER], + device_type: DEV_TYPE_ON_OFF_LIGHT, + }, + ], + }, + root_endpoint::handler(0, &DummyDevAtt {}, matter) + .chain(0, echo_cluster::ID, EchoCluster::new(2, *matter.borrow())) + .chain(1, echo_cluster::ID, EchoCluster::new(3, *matter.borrow())) + .chain(1, cluster_on_off::ID, OnOffCluster::new(*matter.borrow())), + ); Self { - dm, - acl_mgr, - im, + matter, + im: InteractionModel(dm), exch: None, } } + pub fn echo_cluster(&self, endpoint: u16) -> &EchoCluster { + match endpoint { + 0 => &self.im.0.handler.next.next.handler, + 1 => &self.im.0.handler.next.handler, + _ => panic!(), + } + } + /// Run a transaction through the interaction model engine - pub fn process<'a>(&mut self, input: &ImInput, data_out: &'a mut [u8]) -> (u8, &'a mut [u8]) { + pub fn process<'b>(&mut self, input: &ImInput, data_out: &'b mut [u8]) -> (u8, &'b [u8]) { let mut new_exch = Exchange::new(1, 0, exchange::Role::Responder); // Choose whether to use a new exchange, or use the one from the ImEngine configuration let exch = self.exch.as_mut().unwrap_or(&mut new_exch); - let mut sess_mgr: SessionMgr = Default::default(); + let mut sess_mgr = SessionMgr::new(*self.matter.borrow(), *self.matter.borrow()); let clone_data = CloneData::new( 123456, @@ -156,9 +191,15 @@ impl ImEngine { ); let sess_idx = sess_mgr.clone_session(&clone_data).unwrap(); let sess = sess_mgr.get_session_handle(sess_idx); - let exch_ctx = ExchangeCtx { exch, sess }; - let mut rx = Slab::::try_new(Packet::new_rx().unwrap()).unwrap(); - let tx = Slab::::try_new(Packet::new_tx().unwrap()).unwrap(); + let exch_ctx = ExchangeCtx { + exch, + sess, + epoch: *self.matter.borrow(), + }; + let mut tx_buf = [0; 1500]; + let mut rx_buf = [0; 1500]; + let mut rx = Packet::new_rx(&mut rx_buf); + let mut tx = Packet::new_tx(&mut tx_buf); // Create fake rx packet rx.set_proto_id(0x01); rx.set_proto_opcode(input.action as u8); @@ -166,36 +207,37 @@ impl ImEngine { { let mut buf = [0u8; 400]; - let buf_len = buf.len(); - let mut wb = WriteBuf::new(&mut buf, buf_len); + let mut wb = WriteBuf::new(&mut buf); let mut tw = TLVWriter::new(&mut wb); input.data.to_tlv(&mut tw, TagType::Anonymous).unwrap(); - let input_data = wb.as_borrow_slice(); + let input_data = wb.as_slice(); let in_data_len = input_data.len(); - let rx_buf = rx.as_borrow_slice(); + let rx_buf = rx.as_mut_slice(); rx_buf[..in_data_len].copy_from_slice(input_data); rx.get_parsebuf().unwrap().set_len(in_data_len); } - let mut ctx = ProtoCtx::new(exch_ctx, rx, tx); - self.im.handle_proto_id(&mut ctx).unwrap(); - let out_data_len = ctx.tx.as_borrow_slice().len(); - data_out[..out_data_len].copy_from_slice(ctx.tx.as_borrow_slice()); + let mut ctx = ProtoCtx::new(exch_ctx, &rx, &mut tx); + self.im.handle(&mut ctx).unwrap(); + let out_data_len = ctx.tx.as_slice().len(); + data_out[..out_data_len].copy_from_slice(ctx.tx.as_slice()); let response = ctx.tx.get_proto_opcode(); - (response, &mut data_out[..out_data_len]) + (response, &data_out[..out_data_len]) } } -// Create an Interaction Model, Data Model and run a rx/tx transaction through it -pub fn im_engine<'a>( - action: OpCode, - data: &dyn ToTLV, - data_out: &'a mut [u8], -) -> (DataModel, u8, &'a mut [u8]) { - let mut engine = ImEngine::new(); - let input = ImInput::new(action, data); - let (response, output) = engine.process(&input, data_out); - (engine.dm, response, output) -} +// TODO - Remove? +// // Create an Interaction Model, Data Model and run a rx/tx transaction through it +// pub fn im_engine<'a>( +// matter: &'a Matter, +// action: OpCode, +// data: &dyn ToTLV, +// data_out: &'a mut [u8], +// ) -> (DmHandler<'a>, u8, &'a mut [u8]) { +// let mut engine = ImEngine::new(matter); +// let input = ImInput::new(action, data); +// let (response, output) = engine.process(&input, data_out); +// (engine.dm.handler, response, output) +// } diff --git a/matter/tests/common/mod.rs b/matter/tests/common/mod.rs index dea136a..0d2cc9c 100644 --- a/matter/tests/common/mod.rs +++ b/matter/tests/common/mod.rs @@ -18,4 +18,5 @@ pub mod attributes; pub mod commands; pub mod echo_cluster; +pub mod handlers; pub mod im_engine; diff --git a/matter/tests/data_model/acl_and_dataver.rs b/matter/tests/data_model/acl_and_dataver.rs index 493a282..535555b 100644 --- a/matter/tests/data_model/acl_and_dataver.rs +++ b/matter/tests/data_model/acl_and_dataver.rs @@ -18,19 +18,16 @@ use matter::{ acl::{gen_noc_cat, AclEntry, AuthMode, Target}, data_model::{ - objects::{AttrValue, EncodeValue, Privilege}, + objects::{EncodeValue, Privilege}, system_model::access_control, }, interaction_model::{ - core::{IMStatusCode, OpCode}, - messages::{ - ib::{AttrData, AttrPath, AttrResp, AttrStatus, ClusterPath, DataVersionFilter}, - msg::{ReadReq, ReportDataMsg, WriteReq}, - }, - messages::{msg, GenericPath}, + core::IMStatusCode, + messages::ib::{AttrData, AttrPath, AttrResp, AttrStatus, ClusterPath, DataVersionFilter}, + messages::GenericPath, }, - tlv::{self, ElementType, FromTLV, TLVArray, TLVElement, TLVWriter, TagType}, - transport::session::NocCatIds, + mdns::DummyMdns, + tlv::{ElementType, TLVArray, TLVElement, TLVWriter, TagType}, }; use crate::{ @@ -38,81 +35,10 @@ use crate::{ common::{ attributes::*, echo_cluster::{self, ATTR_WRITE_DEFAULT_VALUE}, - im_engine::{ImEngine, ImInput}, + im_engine::{matter, ImEngine}, }, }; -// Helper for handling Read Req sequences for this file -fn handle_read_reqs( - im: &mut ImEngine, - peer_node_id: u64, - input: &[AttrPath], - expected: &[AttrResp], -) { - let mut out_buf = [0u8; 400]; - let received = gen_read_reqs_output(im, peer_node_id, input, None, &mut out_buf); - assert_attr_report(&received, expected) -} - -fn gen_read_reqs_output<'a>( - im: &mut ImEngine, - peer_node_id: u64, - input: &[AttrPath], - dataver_filters: Option>, - out_buf: &'a mut [u8], -) -> ReportDataMsg<'a> { - let mut read_req = ReadReq::new(true).set_attr_requests(input); - read_req.dataver_filters = dataver_filters; - - let mut input = ImInput::new(OpCode::ReadRequest, &read_req); - input.set_peer_node_id(peer_node_id); - - let (_, out_buf) = im.process(&input, out_buf); - - tlv::print_tlv_list(out_buf); - let root = tlv::get_root_node_struct(out_buf).unwrap(); - ReportDataMsg::from_tlv(&root).unwrap() -} - -// Helper for handling Write Attribute sequences -fn handle_write_reqs( - im: &mut ImEngine, - peer_node_id: u64, - peer_cat_ids: Option<&NocCatIds>, - input: &[AttrData], - expected: &[AttrStatus], -) { - let mut out_buf = [0u8; 400]; - let write_req = WriteReq::new(false, input); - - let mut input = ImInput::new(OpCode::WriteRequest, &write_req); - input.set_peer_node_id(peer_node_id); - if let Some(cat_ids) = peer_cat_ids { - input.set_cat_ids(cat_ids); - } - let (_, out_buf) = im.process(&input, &mut out_buf); - - tlv::print_tlv_list(out_buf); - let root = tlv::get_root_node_struct(out_buf).unwrap(); - - let mut index = 0; - let response_iter = root - .find_tag(msg::WriteRespTag::WriteResponses as u32) - .unwrap() - .confirm_array() - .unwrap() - .enter() - .unwrap(); - for response in response_iter { - println!("Validating index {}", index); - let status = AttrStatus::from_tlv(&response).unwrap(); - assert_eq!(expected[index], status); - println!("Index {} success", index); - index += 1; - } - assert_eq!(index, expected.len()); -} - #[test] /// Ensure that wildcard read attributes don't include error response /// and silently drop the data when access is not granted @@ -122,43 +48,45 @@ fn wc_read_attribute() { let wc_att1 = GenericPath::new( None, Some(echo_cluster::ID), - Some(echo_cluster::Attributes::Att1 as u32), + Some(echo_cluster::AttributesDiscriminants::Att1 as u32), ); let ep0_att1 = GenericPath::new( Some(0), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::Att1 as u32), + Some(echo_cluster::AttributesDiscriminants::Att1 as u32), ); let ep1_att1 = GenericPath::new( Some(1), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::Att1 as u32), + Some(echo_cluster::AttributesDiscriminants::Att1 as u32), ); let peer = 98765; - let mut im = ImEngine::new(); + let mut mdns = DummyMdns {}; + let matter = matter(&mut mdns); + let mut im = ImEngine::new(&matter); // Test1: Empty Response as no ACL matches let input = &[AttrPath::new(&wc_att1)]; let expected = &[]; - handle_read_reqs(&mut im, peer, input, expected); + im.handle_read_reqs(peer, input, expected); // Add ACL to allow our peer to only access endpoint 0 let mut acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); acl.add_subject(peer).unwrap(); acl.add_target(Target::new(Some(0), None, None)).unwrap(); - im.acl_mgr.add(acl).unwrap(); + im.matter.acl_mgr.borrow_mut().add(acl).unwrap(); // Test2: Only Single response as only single endpoint is allowed let input = &[AttrPath::new(&wc_att1)]; let expected = &[attr_data_path!(ep0_att1, ElementType::U16(0x1234))]; - handle_read_reqs(&mut im, peer, input, expected); + im.handle_read_reqs(peer, input, expected); - // Add ACL to allow our peer to only access endpoint 1 + // Add ACL to allow our peer to also access endpoint 1 let mut acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); acl.add_subject(peer).unwrap(); acl.add_target(Target::new(Some(1), None, None)).unwrap(); - im.acl_mgr.add(acl).unwrap(); + im.matter.acl_mgr.borrow_mut().add(acl).unwrap(); // Test3: Both responses are valid let input = &[AttrPath::new(&wc_att1)]; @@ -166,7 +94,7 @@ fn wc_read_attribute() { attr_data_path!(ep0_att1, ElementType::U16(0x1234)), attr_data_path!(ep1_att1, ElementType::U16(0x1234)), ]; - handle_read_reqs(&mut im, peer, input, expected); + im.handle_read_reqs(peer, input, expected); } #[test] @@ -178,48 +106,33 @@ fn exact_read_attribute() { let wc_att1 = GenericPath::new( Some(0), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::Att1 as u32), + Some(echo_cluster::AttributesDiscriminants::Att1 as u32), ); let ep0_att1 = GenericPath::new( Some(0), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::Att1 as u32), + Some(echo_cluster::AttributesDiscriminants::Att1 as u32), ); let peer = 98765; - let mut im = ImEngine::new(); + let mut mdns = DummyMdns {}; + let matter = matter(&mut mdns); + let mut im = ImEngine::new(&matter); // Test1: Unsupported Access error as no ACL matches let input = &[AttrPath::new(&wc_att1)]; let expected = &[attr_status!(&ep0_att1, IMStatusCode::UnsupportedAccess)]; - handle_read_reqs(&mut im, peer, input, expected); + im.handle_read_reqs(peer, input, expected); // Add ACL to allow our peer to access any endpoint let mut acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); acl.add_subject(peer).unwrap(); - im.acl_mgr.add(acl).unwrap(); + im.matter.acl_mgr.borrow_mut().add(acl).unwrap(); // Test2: Only Single response as only single endpoint is allowed let input = &[AttrPath::new(&wc_att1)]; let expected = &[attr_data_path!(ep0_att1, ElementType::U16(0x1234))]; - handle_read_reqs(&mut im, peer, input, expected); -} - -fn read_cluster_id_write_attr(im: &ImEngine, endpoint: u16) -> AttrValue { - let node = im.dm.node.read().unwrap(); - let echo = node.get_cluster(endpoint, echo_cluster::ID).unwrap(); - - echo.base() - .read_attribute_raw(echo_cluster::Attributes::AttWrite as u16) - .unwrap() - .clone() -} - -fn read_cluster_id_data_ver(im: &ImEngine, endpoint: u16) -> u32 { - let node = im.dm.node.read().unwrap(); - let echo = node.get_cluster(endpoint, echo_cluster::ID).unwrap(); - - echo.base().get_dataver() + im.handle_read_reqs(peer, input, expected); } #[test] @@ -239,17 +152,17 @@ fn wc_write_attribute() { let wc_att = GenericPath::new( None, Some(echo_cluster::ID), - Some(echo_cluster::Attributes::AttWrite as u32), + Some(echo_cluster::AttributesDiscriminants::AttWrite as u32), ); let ep0_att = GenericPath::new( Some(0), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::AttWrite as u32), + Some(echo_cluster::AttributesDiscriminants::AttWrite as u32), ); let ep1_att = GenericPath::new( Some(1), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::AttWrite as u32), + Some(echo_cluster::AttributesDiscriminants::AttWrite as u32), ); let input0 = &[AttrData::new( @@ -264,54 +177,41 @@ fn wc_write_attribute() { )]; let peer = 98765; - let mut im = ImEngine::new(); + let mut mdns = DummyMdns {}; + let matter = matter(&mut mdns); + let mut im = ImEngine::new(&matter); // Test 1: Wildcard write to an attribute without permission should return // no error - handle_write_reqs(&mut im, peer, None, input0, &[]); - { - let node = im.dm.node.read().unwrap(); - let echo = node.get_cluster(0, echo_cluster::ID).unwrap(); - assert_eq!( - AttrValue::Uint16(ATTR_WRITE_DEFAULT_VALUE), - *echo - .base() - .read_attribute_raw(echo_cluster::Attributes::AttWrite as u16) - .unwrap() - ); - } + im.handle_write_reqs(peer, None, input0, &[]); + assert_eq!(ATTR_WRITE_DEFAULT_VALUE, im.echo_cluster(0).att_write); // Add ACL to allow our peer to access one endpoint let mut acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); acl.add_subject(peer).unwrap(); acl.add_target(Target::new(Some(0), None, None)).unwrap(); - im.acl_mgr.add(acl).unwrap(); + im.matter.acl_mgr.borrow_mut().add(acl).unwrap(); // Test 2: Wildcard write to attributes will only return attributes // where the writes were successful - handle_write_reqs( - &mut im, + im.handle_write_reqs( peer, None, input0, &[AttrStatus::new(&ep0_att, IMStatusCode::Success, 0)], ); - assert_eq!(AttrValue::Uint16(val0), read_cluster_id_write_attr(&im, 0)); - assert_eq!( - AttrValue::Uint16(ATTR_WRITE_DEFAULT_VALUE), - read_cluster_id_write_attr(&im, 1) - ); + assert_eq!(val0, im.echo_cluster(0).att_write); + assert_eq!(ATTR_WRITE_DEFAULT_VALUE, im.echo_cluster(1).att_write); // Add ACL to allow our peer to access another endpoint let mut acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); acl.add_subject(peer).unwrap(); acl.add_target(Target::new(Some(1), None, None)).unwrap(); - im.acl_mgr.add(acl).unwrap(); + im.matter.acl_mgr.borrow_mut().add(acl).unwrap(); // Test 3: Wildcard write to attributes will return multiple attributes // where the writes were successful - handle_write_reqs( - &mut im, + im.handle_write_reqs( peer, None, input1, @@ -320,8 +220,8 @@ fn wc_write_attribute() { AttrStatus::new(&ep1_att, IMStatusCode::Success, 0), ], ); - assert_eq!(AttrValue::Uint16(val1), read_cluster_id_write_attr(&im, 0)); - assert_eq!(AttrValue::Uint16(val1), read_cluster_id_write_attr(&im, 1)); + assert_eq!(val1, im.echo_cluster(0).att_write); + assert_eq!(val1, im.echo_cluster(1).att_write); } #[test] @@ -337,7 +237,7 @@ fn exact_write_attribute() { let ep0_att = GenericPath::new( Some(0), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::AttWrite as u32), + Some(echo_cluster::AttributesDiscriminants::AttWrite as u32), ); let input = &[AttrData::new( @@ -353,25 +253,24 @@ fn exact_write_attribute() { let expected_success = &[AttrStatus::new(&ep0_att, IMStatusCode::Success, 0)]; let peer = 98765; - let mut im = ImEngine::new(); + let mut mdns = DummyMdns {}; + let matter = matter(&mut mdns); + let mut im = ImEngine::new(&matter); // Test 1: Exact write to an attribute without permission should return // Unsupported Access Error - handle_write_reqs(&mut im, peer, None, input, expected_fail); - assert_eq!( - AttrValue::Uint16(ATTR_WRITE_DEFAULT_VALUE), - read_cluster_id_write_attr(&im, 0) - ); + im.handle_write_reqs(peer, None, input, expected_fail); + assert_eq!(ATTR_WRITE_DEFAULT_VALUE, im.echo_cluster(0).att_write); // Add ACL to allow our peer to access any endpoint let mut acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); acl.add_subject(peer).unwrap(); - im.acl_mgr.add(acl).unwrap(); + im.matter.acl_mgr.borrow_mut().add(acl).unwrap(); // Test 1: Exact write to an attribute with permission should grant // access - handle_write_reqs(&mut im, peer, None, input, expected_success); - assert_eq!(AttrValue::Uint16(val0), read_cluster_id_write_attr(&im, 0)); + im.handle_write_reqs(peer, None, input, expected_success); + assert_eq!(val0, im.echo_cluster(0).att_write); } #[test] @@ -388,7 +287,7 @@ fn exact_write_attribute_noc_cat() { let ep0_att = GenericPath::new( Some(0), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::AttWrite as u32), + Some(echo_cluster::AttributesDiscriminants::AttWrite as u32), ); let input = &[AttrData::new( @@ -408,25 +307,24 @@ fn exact_write_attribute_noc_cat() { let noc_cat = gen_noc_cat(0xABCD, 2); let cat_in_acl = gen_noc_cat(0xABCD, 1); let cat_ids = [noc_cat, 0, 0]; - let mut im = ImEngine::new(); + let mut mdns = DummyMdns; + let matter = matter(&mut mdns); + let mut im = ImEngine::new(&matter); // Test 1: Exact write to an attribute without permission should return // Unsupported Access Error - handle_write_reqs(&mut im, peer, Some(&cat_ids), input, expected_fail); - assert_eq!( - AttrValue::Uint16(ATTR_WRITE_DEFAULT_VALUE), - read_cluster_id_write_attr(&im, 0) - ); + im.handle_write_reqs(peer, Some(&cat_ids), input, expected_fail); + assert_eq!(ATTR_WRITE_DEFAULT_VALUE, im.echo_cluster(0).att_write); // Add ACL to allow our peer to access any endpoint let mut acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); acl.add_subject_catid(cat_in_acl).unwrap(); - im.acl_mgr.add(acl).unwrap(); + im.matter.acl_mgr.borrow_mut().add(acl).unwrap(); // Test 1: Exact write to an attribute with permission should grant // access - handle_write_reqs(&mut im, peer, Some(&cat_ids), input, expected_success); - assert_eq!(AttrValue::Uint16(val0), read_cluster_id_write_attr(&im, 0)); + im.handle_write_reqs(peer, Some(&cat_ids), input, expected_success); + assert_eq!(val0, im.echo_cluster(0).att_write); } #[test] @@ -440,7 +338,7 @@ fn insufficient_perms_write() { let ep0_att = GenericPath::new( Some(0), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::AttWrite as u32), + Some(echo_cluster::AttributesDiscriminants::AttWrite as u32), ); let input0 = &[AttrData::new( None, @@ -449,17 +347,18 @@ fn insufficient_perms_write() { )]; let peer = 98765; - let mut im = ImEngine::new(); + let mut mdns = DummyMdns {}; + let matter = matter(&mut mdns); + let mut im = ImEngine::new(&matter); // Add ACL to allow our peer with only OPERATE permission let mut acl = AclEntry::new(1, Privilege::OPERATE, AuthMode::Case); acl.add_subject(peer).unwrap(); acl.add_target(Target::new(Some(0), None, None)).unwrap(); - im.acl_mgr.add(acl).unwrap(); + im.matter.acl_mgr.borrow_mut().add(acl).unwrap(); // Test: Not enough permission should return error - handle_write_reqs( - &mut im, + im.handle_write_reqs( peer, None, input0, @@ -469,10 +368,7 @@ fn insufficient_perms_write() { 0, )], ); - assert_eq!( - AttrValue::Uint16(ATTR_WRITE_DEFAULT_VALUE), - read_cluster_id_write_attr(&im, 0) - ); + assert_eq!(ATTR_WRITE_DEFAULT_VALUE, im.echo_cluster(0).att_write); } #[test] @@ -485,7 +381,9 @@ fn insufficient_perms_write() { fn write_with_runtime_acl_add() { let _ = env_logger::try_init(); let peer = 98765; - let mut im = ImEngine::new(); + let mut mdns = DummyMdns {}; + let matter = matter(&mut mdns); + let mut im = ImEngine::new(&matter); let val0 = 10; let attr_data0 = |tag, t: &mut TLVWriter| { @@ -494,7 +392,7 @@ fn write_with_runtime_acl_add() { let ep0_att = GenericPath::new( Some(0), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::AttWrite as u32), + Some(echo_cluster::AttributesDiscriminants::AttWrite as u32), ); let input0 = AttrData::new( None, @@ -509,7 +407,7 @@ fn write_with_runtime_acl_add() { let acl_att = GenericPath::new( Some(0), Some(access_control::ID), - Some(access_control::Attributes::Acl as u32), + Some(access_control::AttributesDiscriminants::Acl as u32), ); let acl_input = AttrData::new( None, @@ -523,11 +421,10 @@ fn write_with_runtime_acl_add() { basic_acl .add_target(Target::new(Some(0), Some(access_control::ID), None)) .unwrap(); - im.acl_mgr.add(basic_acl).unwrap(); + im.matter.acl_mgr.borrow_mut().add(basic_acl).unwrap(); // Test: deny write (with error), then ACL is added, then allow write - handle_write_reqs( - &mut im, + im.handle_write_reqs( peer, None, // write to echo-cluster attribute, write to acl attribute, write to echo-cluster attribute @@ -538,7 +435,7 @@ fn write_with_runtime_acl_add() { AttrStatus::new(&ep0_att, IMStatusCode::Success, 0), ], ); - assert_eq!(AttrValue::Uint16(val0), read_cluster_id_write_attr(&im, 0)); + assert_eq!(val0, im.echo_cluster(0).att_write); } #[test] @@ -551,16 +448,18 @@ fn test_read_data_ver() { // - 2 responses are expected let _ = env_logger::try_init(); let peer = 98765; - let mut im = ImEngine::new(); + let mut mdns = DummyMdns {}; + let matter = matter(&mut mdns); + let mut im = ImEngine::new(&matter); // Add ACL to allow our peer with only OPERATE permission let acl = AclEntry::new(1, Privilege::OPERATE, AuthMode::Case); - im.acl_mgr.add(acl).unwrap(); + im.matter.acl_mgr.borrow_mut().add(acl).unwrap(); let wc_ep_att1 = GenericPath::new( None, Some(echo_cluster::ID), - Some(echo_cluster::Attributes::Att1 as u32), + Some(echo_cluster::AttributesDiscriminants::Att1 as u32), ); let input = &[AttrPath::new(&wc_ep_att1)]; @@ -569,7 +468,7 @@ fn test_read_data_ver() { GenericPath::new( Some(0), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::Att1 as u32) + Some(echo_cluster::AttributesDiscriminants::Att1 as u32) ), ElementType::U16(0x1234) ), @@ -577,7 +476,7 @@ fn test_read_data_ver() { GenericPath::new( Some(1), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::Att1 as u32) + Some(echo_cluster::AttributesDiscriminants::Att1 as u32) ), ElementType::U16(0x1234) ), @@ -585,7 +484,7 @@ fn test_read_data_ver() { let mut out_buf = [0u8; 400]; // Test 1: Simple read to retrieve the current Data Version of Cluster at Endpoint 0 - let received = gen_read_reqs_output(&mut im, peer, input, None, &mut out_buf); + let received = im.gen_read_reqs_output(peer, input, None, &mut out_buf); assert_attr_report(&received, expected); let data_ver_cluster_at_0 = received @@ -607,8 +506,7 @@ fn test_read_data_ver() { }]; // Test 2: Add Dataversion filter for cluster at endpoint 0 only single entry should be retrieved - let received = gen_read_reqs_output( - &mut im, + let received = im.gen_read_reqs_output( peer, input, Some(TLVArray::Slice(&dataver_filter)), @@ -618,7 +516,7 @@ fn test_read_data_ver() { GenericPath::new( Some(1), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::Att1 as u32) + Some(echo_cluster::AttributesDiscriminants::Att1 as u32) ), ElementType::U16(0x1234) )]; @@ -629,11 +527,10 @@ fn test_read_data_ver() { let ep0_att1 = GenericPath::new( Some(0), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::Att1 as u32), + Some(echo_cluster::AttributesDiscriminants::Att1 as u32), ); let input = &[AttrPath::new(&ep0_att1)]; - let received = gen_read_reqs_output( - &mut im, + let received = im.gen_read_reqs_output( peer, input, Some(TLVArray::Slice(&dataver_filter)), @@ -654,21 +551,23 @@ fn test_write_data_ver() { // - 2 responses are expected let _ = env_logger::try_init(); let peer = 98765; - let mut im = ImEngine::new(); + let mut mdns = DummyMdns {}; + let matter = matter(&mut mdns); + let mut im = ImEngine::new(&matter); // Add ACL to allow our peer with only OPERATE permission let acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); - im.acl_mgr.add(acl).unwrap(); + im.matter.acl_mgr.borrow_mut().add(acl).unwrap(); let wc_ep_attwrite = GenericPath::new( None, Some(echo_cluster::ID), - Some(echo_cluster::Attributes::AttWrite as u32), + Some(echo_cluster::AttributesDiscriminants::AttWrite as u32), ); let ep0_attwrite = GenericPath::new( Some(0), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::AttWrite as u32), + Some(echo_cluster::AttributesDiscriminants::AttWrite as u32), ); let val0 = 10u16; @@ -676,7 +575,7 @@ fn test_write_data_ver() { let attr_data0 = EncodeValue::Value(&val0); let attr_data1 = EncodeValue::Value(&val1); - let initial_data_ver = read_cluster_id_data_ver(&im, 0); + let initial_data_ver = im.echo_cluster(0).data_ver.get(); // Test 1: Write with correct dataversion should succeed let input_correct_dataver = &[AttrData::new( @@ -684,14 +583,13 @@ fn test_write_data_ver() { AttrPath::new(&ep0_attwrite), attr_data0, )]; - handle_write_reqs( - &mut im, + im.handle_write_reqs( peer, None, input_correct_dataver, &[AttrStatus::new(&ep0_attwrite, IMStatusCode::Success, 0)], ); - assert_eq!(AttrValue::Uint16(val0), read_cluster_id_write_attr(&im, 0)); + assert_eq!(val0, im.echo_cluster(0).att_write); // Test 2: Write with incorrect dataversion should fail // Now the data version would have incremented due to the previous write @@ -700,8 +598,7 @@ fn test_write_data_ver() { AttrPath::new(&ep0_attwrite), attr_data1, )]; - handle_write_reqs( - &mut im, + im.handle_write_reqs( peer, None, input_correct_dataver, @@ -711,26 +608,25 @@ fn test_write_data_ver() { 0, )], ); - assert_eq!(AttrValue::Uint16(val0), read_cluster_id_write_attr(&im, 0)); + assert_eq!(val0, im.echo_cluster(0).att_write); // Test 3: Wildcard write with incorrect dataversion should ignore that cluster // In this case, while the data version is correct for endpoint 0, the endpoint 1's // data version would not match - let new_data_ver = read_cluster_id_data_ver(&im, 0); + let new_data_ver = im.echo_cluster(0).data_ver.get(); let input_correct_dataver = &[AttrData::new( Some(new_data_ver), AttrPath::new(&wc_ep_attwrite), attr_data1, )]; - handle_write_reqs( - &mut im, + im.handle_write_reqs( peer, None, input_correct_dataver, &[AttrStatus::new(&ep0_attwrite, IMStatusCode::Success, 0)], ); - assert_eq!(AttrValue::Uint16(val1), read_cluster_id_write_attr(&im, 0)); + assert_eq!(val1, im.echo_cluster(0).att_write); assert_eq!(initial_data_ver + 1, new_data_ver); } diff --git a/matter/tests/data_model/attribute_lists.rs b/matter/tests/data_model/attribute_lists.rs index 62e79c4..ace1f3d 100644 --- a/matter/tests/data_model/attribute_lists.rs +++ b/matter/tests/data_model/attribute_lists.rs @@ -16,36 +16,22 @@ */ use matter::{ - data_model::{core::DataModel, objects::EncodeValue}, + data_model::objects::EncodeValue, interaction_model::{ - core::{IMStatusCode, OpCode}, + core::IMStatusCode, + messages::ib::{AttrData, AttrPath, AttrStatus}, messages::GenericPath, - messages::{ - ib::{AttrData, AttrPath, AttrStatus}, - msg::{WriteReq, WriteResp}, - }, }, - tlv::{self, FromTLV, Nullable}, + mdns::DummyMdns, + tlv::Nullable, }; use crate::common::{ echo_cluster::{self, TestChecker}, - im_engine::im_engine, + im_engine::{matter, ImEngine}, }; // Helper for handling Write Attribute sequences -fn handle_write_reqs(input: &[AttrData], expected: &[AttrStatus]) -> DataModel { - let mut out_buf = [0u8; 400]; - let write_req = WriteReq::new(false, input); - - let (dm, _, out_buf) = im_engine(OpCode::WriteRequest, &write_req, &mut out_buf); - tlv::print_tlv_list(out_buf); - let root = tlv::get_root_node_struct(out_buf).unwrap(); - let resp = WriteResp::from_tlv(&root).unwrap(); - assert_eq!(resp.write_responses, expected); - dm -} - #[test] /// This tests all the attribute list operations /// add item, edit item, delete item, overwrite list, delete list @@ -67,14 +53,14 @@ fn attr_list_ops() { let att_data = GenericPath::new( Some(0), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::AttWriteList as u32), + Some(echo_cluster::AttributesDiscriminants::AttWriteList as u32), ); let mut att_path = AttrPath::new(&att_data); // Test 1: Add Operation - add val0 let input = &[AttrData::new(None, att_path, EncodeValue::Value(&val0))]; let expected = &[AttrStatus::new(&att_data, IMStatusCode::Success, 0)]; - let _ = handle_write_reqs(input, expected); + ImEngine::new_with_write_reqs(&matter(&mut DummyMdns), input, expected); { let tc = tc_handle.lock().unwrap(); @@ -84,7 +70,7 @@ fn attr_list_ops() { // Test 2: Another Add Operation - add val1 let input = &[AttrData::new(None, att_path, EncodeValue::Value(&val1))]; let expected = &[AttrStatus::new(&att_data, IMStatusCode::Success, 0)]; - let _ = handle_write_reqs(input, expected); + ImEngine::new_with_write_reqs(&matter(&mut DummyMdns), input, expected); { let tc = tc_handle.lock().unwrap(); @@ -95,7 +81,7 @@ fn attr_list_ops() { att_path.list_index = Some(Nullable::NotNull(1)); let input = &[AttrData::new(None, att_path, EncodeValue::Value(&val0))]; let expected = &[AttrStatus::new(&att_data, IMStatusCode::Success, 0)]; - let _ = handle_write_reqs(input, expected); + ImEngine::new_with_write_reqs(&matter(&mut DummyMdns), input, expected); { let tc = tc_handle.lock().unwrap(); @@ -106,7 +92,7 @@ fn attr_list_ops() { att_path.list_index = Some(Nullable::NotNull(0)); let input = &[AttrData::new(None, att_path, delete_item)]; let expected = &[AttrStatus::new(&att_data, IMStatusCode::Success, 0)]; - let _ = handle_write_reqs(input, expected); + ImEngine::new_with_write_reqs(&matter(&mut DummyMdns), input, expected); { let tc = tc_handle.lock().unwrap(); @@ -122,7 +108,7 @@ fn attr_list_ops() { EncodeValue::Value(&overwrite_val), )]; let expected = &[AttrStatus::new(&att_data, IMStatusCode::Success, 0)]; - let _ = handle_write_reqs(input, expected); + ImEngine::new_with_write_reqs(&matter(&mut DummyMdns), input, expected); { let tc = tc_handle.lock().unwrap(); @@ -133,7 +119,7 @@ fn attr_list_ops() { att_path.list_index = None; let input = &[AttrData::new(None, att_path, delete_all)]; let expected = &[AttrStatus::new(&att_data, IMStatusCode::Success, 0)]; - let _ = handle_write_reqs(input, expected); + ImEngine::new_with_write_reqs(&matter(&mut DummyMdns), input, expected); { let tc = tc_handle.lock().unwrap(); diff --git a/matter/tests/data_model/attributes.rs b/matter/tests/data_model/attributes.rs index 091b89f..17e4112 100644 --- a/matter/tests/data_model/attributes.rs +++ b/matter/tests/data_model/attributes.rs @@ -18,53 +18,26 @@ use matter::{ data_model::{ cluster_on_off, - core::DataModel, - objects::{AttrValue, EncodeValue, GlobalElements}, + objects::{EncodeValue, GlobalElements}, }, interaction_model::{ - core::{IMStatusCode, OpCode}, + core::IMStatusCode, + messages::ib::{AttrData, AttrPath, AttrResp, AttrStatus}, messages::GenericPath, - messages::{ - ib::{AttrData, AttrPath, AttrResp, AttrStatus}, - msg::{ReadReq, ReportDataMsg, WriteReq, WriteResp}, - }, }, - tlv::{self, ElementType, FromTLV, TLVElement, TLVWriter, TagType}, + mdns::DummyMdns, + tlv::{ElementType, TLVElement, TLVWriter, TagType}, }; use crate::{ attr_data, attr_data_path, attr_status, - common::{attributes::*, echo_cluster, im_engine::im_engine}, + common::{ + attributes::*, + echo_cluster, + im_engine::{matter, ImEngine}, + }, }; -fn handle_read_reqs(input: &[AttrPath], expected: &[AttrResp]) { - let mut out_buf = [0u8; 400]; - let received = gen_read_reqs_output(input, &mut out_buf); - assert_attr_report(&received, expected) -} - -// Helper for handling Read Req sequences -fn gen_read_reqs_output<'a>(input: &[AttrPath], out_buf: &'a mut [u8]) -> ReportDataMsg<'a> { - let read_req = ReadReq::new(true).set_attr_requests(input); - let (_, _, out_buf) = im_engine(OpCode::ReadRequest, &read_req, out_buf); - tlv::print_tlv_list(out_buf); - let root = tlv::get_root_node_struct(out_buf).unwrap(); - ReportDataMsg::from_tlv(&root).unwrap() -} - -// Helper for handling Write Attribute sequences -fn handle_write_reqs(input: &[AttrData], expected: &[AttrStatus]) -> DataModel { - let mut out_buf = [0u8; 400]; - let write_req = WriteReq::new(false, input); - - let (dm, _, out_buf) = im_engine(OpCode::WriteRequest, &write_req, &mut out_buf); - let root = tlv::get_root_node_struct(out_buf).unwrap(); - let response = WriteResp::from_tlv(&root).unwrap(); - assert_eq!(response.write_responses, expected); - - dm -} - #[test] fn test_read_success() { // 3 Attr Read Requests @@ -76,17 +49,17 @@ fn test_read_success() { let ep0_att1 = GenericPath::new( Some(0), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::Att1 as u32), + Some(echo_cluster::AttributesDiscriminants::Att1 as u32), ); let ep1_att2 = GenericPath::new( Some(1), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::Att2 as u32), + Some(echo_cluster::AttributesDiscriminants::Att2 as u32), ); let ep1_attcustom = GenericPath::new( Some(1), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::AttCustom as u32), + Some(echo_cluster::AttributesDiscriminants::AttCustom as u32), ); let input = &[ AttrPath::new(&ep0_att1), @@ -101,7 +74,7 @@ fn test_read_success() { ElementType::U32(echo_cluster::ATTR_CUSTOM_VALUE) ), ]; - handle_read_reqs(input, expected); + ImEngine::new_with_read_reqs(&matter(&mut DummyMdns), input, expected); } #[test] @@ -118,17 +91,17 @@ fn test_read_unsupported_fields() { let invalid_endpoint = GenericPath::new( Some(2), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::Att1 as u32), + Some(echo_cluster::AttributesDiscriminants::Att1 as u32), ); let invalid_cluster = GenericPath::new( Some(0), Some(0x1234), - Some(echo_cluster::Attributes::Att1 as u32), + Some(echo_cluster::AttributesDiscriminants::Att1 as u32), ); let invalid_cluster_wc_endpoint = GenericPath::new( None, Some(0x1234), - Some(echo_cluster::Attributes::AttCustom as u32), + Some(echo_cluster::AttributesDiscriminants::AttCustom as u32), ); let invalid_attribute = GenericPath::new(Some(0), Some(echo_cluster::ID), Some(0x1234)); let invalid_attribute_wc_endpoint = @@ -148,7 +121,7 @@ fn test_read_unsupported_fields() { attr_status!(&invalid_cluster, IMStatusCode::UnsupportedCluster), attr_status!(&invalid_attribute, IMStatusCode::UnsupportedAttribute), ]; - handle_read_reqs(input, expected); + ImEngine::new_with_read_reqs(&matter(&mut DummyMdns), input, expected); } #[test] @@ -161,7 +134,7 @@ fn test_read_wc_endpoint_all_have_clusters() { let wc_ep_att1 = GenericPath::new( None, Some(echo_cluster::ID), - Some(echo_cluster::Attributes::Att1 as u32), + Some(echo_cluster::AttributesDiscriminants::Att1 as u32), ); let input = &[AttrPath::new(&wc_ep_att1)]; @@ -169,17 +142,17 @@ fn test_read_wc_endpoint_all_have_clusters() { attr_data!( 0, echo_cluster::ID, - echo_cluster::Attributes::Att1, + echo_cluster::AttributesDiscriminants::Att1, ElementType::U16(0x1234) ), attr_data!( 1, echo_cluster::ID, - echo_cluster::Attributes::Att1, + echo_cluster::AttributesDiscriminants::Att1, ElementType::U16(0x1234) ), ]; - handle_read_reqs(input, expected); + ImEngine::new_with_read_reqs(&matter(&mut DummyMdns), input, expected); } #[test] @@ -192,7 +165,7 @@ fn test_read_wc_endpoint_only_1_has_cluster() { let wc_ep_onoff = GenericPath::new( None, Some(cluster_on_off::ID), - Some(cluster_on_off::Attributes::OnOff as u32), + Some(cluster_on_off::AttributesDiscriminants::OnOff as u32), ); let input = &[AttrPath::new(&wc_ep_onoff)]; @@ -200,11 +173,11 @@ fn test_read_wc_endpoint_only_1_has_cluster() { GenericPath::new( Some(1), Some(cluster_on_off::ID), - Some(cluster_on_off::Attributes::OnOff as u32) + Some(cluster_on_off::AttributesDiscriminants::OnOff as u32) ), ElementType::False )]; - handle_read_reqs(input, expected); + ImEngine::new_with_read_reqs(&matter(&mut DummyMdns), input, expected); } #[test] @@ -221,10 +194,10 @@ fn test_read_wc_endpoint_wc_attribute() { &[ GlobalElements::FeatureMap as u16, GlobalElements::AttributeList as u16, - echo_cluster::Attributes::Att1 as u16, - echo_cluster::Attributes::Att2 as u16, - echo_cluster::Attributes::AttWrite as u16, - echo_cluster::Attributes::AttCustom as u16, + echo_cluster::AttributesDiscriminants::Att1 as u16, + echo_cluster::AttributesDiscriminants::Att2 as u16, + echo_cluster::AttributesDiscriminants::AttWrite as u16, + echo_cluster::AttributesDiscriminants::AttCustom as u16, ], ); let attr_list_tlv = attr_list.to_tlv(); @@ -250,7 +223,7 @@ fn test_read_wc_endpoint_wc_attribute() { GenericPath::new( Some(0), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::Att1 as u32), + Some(echo_cluster::AttributesDiscriminants::Att1 as u32), ), ElementType::U16(0x1234) ), @@ -258,7 +231,7 @@ fn test_read_wc_endpoint_wc_attribute() { GenericPath::new( Some(0), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::Att2 as u32), + Some(echo_cluster::AttributesDiscriminants::Att2 as u32), ), ElementType::U16(0x5678) ), @@ -266,7 +239,7 @@ fn test_read_wc_endpoint_wc_attribute() { GenericPath::new( Some(0), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::AttCustom as u32), + Some(echo_cluster::AttributesDiscriminants::AttCustom as u32), ), ElementType::U32(echo_cluster::ATTR_CUSTOM_VALUE) ), @@ -290,7 +263,7 @@ fn test_read_wc_endpoint_wc_attribute() { GenericPath::new( Some(1), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::Att1 as u32), + Some(echo_cluster::AttributesDiscriminants::Att1 as u32), ), ElementType::U16(0x1234) ), @@ -298,7 +271,7 @@ fn test_read_wc_endpoint_wc_attribute() { GenericPath::new( Some(1), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::Att2 as u32), + Some(echo_cluster::AttributesDiscriminants::Att2 as u32), ), ElementType::U16(0x5678) ), @@ -306,12 +279,12 @@ fn test_read_wc_endpoint_wc_attribute() { GenericPath::new( Some(1), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::AttCustom as u32), + Some(echo_cluster::AttributesDiscriminants::AttCustom as u32), ), ElementType::U32(echo_cluster::ATTR_CUSTOM_VALUE) ), ]; - handle_read_reqs(input, expected); + ImEngine::new_with_read_reqs(&matter(&mut DummyMdns), input, expected); } #[test] @@ -332,12 +305,12 @@ fn test_write_success() { let ep0_att = GenericPath::new( Some(0), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::AttWrite as u32), + Some(echo_cluster::AttributesDiscriminants::AttWrite as u32), ); let ep1_att = GenericPath::new( Some(1), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::AttWrite as u32), + Some(echo_cluster::AttributesDiscriminants::AttWrite as u32), ); let input = &[ @@ -357,24 +330,11 @@ fn test_write_success() { AttrStatus::new(&ep1_att, IMStatusCode::Success, 0), ]; - let dm = handle_write_reqs(input, expected); - let node = dm.node.read().unwrap(); - let echo = node.get_cluster(0, echo_cluster::ID).unwrap(); - assert_eq!( - AttrValue::Uint16(val0), - *echo - .base() - .read_attribute_raw(echo_cluster::Attributes::AttWrite as u16) - .unwrap() - ); - let echo = node.get_cluster(1, echo_cluster::ID).unwrap(); - assert_eq!( - AttrValue::Uint16(val1), - *echo - .base() - .read_attribute_raw(echo_cluster::Attributes::AttWrite as u16) - .unwrap() - ); + let mut mdns = DummyMdns; + let matter = matter(&mut mdns); + let im = ImEngine::new_with_write_reqs(&matter, input, expected); + assert_eq!(val0, im.echo_cluster(0).att_write); + assert_eq!(val1, im.echo_cluster(1).att_write); } #[test] @@ -390,7 +350,7 @@ fn test_write_wc_endpoint() { let ep_att = GenericPath::new( None, Some(echo_cluster::ID), - Some(echo_cluster::Attributes::AttWrite as u32), + Some(echo_cluster::AttributesDiscriminants::AttWrite as u32), ); let input = &[AttrData::new( None, @@ -401,38 +361,23 @@ fn test_write_wc_endpoint() { let ep0_att = GenericPath::new( Some(0), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::AttWrite as u32), + Some(echo_cluster::AttributesDiscriminants::AttWrite as u32), ); let ep1_att = GenericPath::new( Some(1), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::AttWrite as u32), + Some(echo_cluster::AttributesDiscriminants::AttWrite as u32), ); let expected = &[ AttrStatus::new(&ep0_att, IMStatusCode::Success, 0), AttrStatus::new(&ep1_att, IMStatusCode::Success, 0), ]; - let dm = handle_write_reqs(input, expected); - assert_eq!( - AttrValue::Uint16(val0), - dm.read_attribute_raw( - 0, - echo_cluster::ID, - echo_cluster::Attributes::AttWrite as u16 - ) - .unwrap() - ); - assert_eq!( - AttrValue::Uint16(val0), - dm.read_attribute_raw( - 0, - echo_cluster::ID, - echo_cluster::Attributes::AttWrite as u16 - ) - .unwrap() - ); + let mut mdns = DummyMdns; + let matter = matter(&mut mdns); + let im = ImEngine::new_with_write_reqs(&matter, input, expected); + assert_eq!(val0, im.echo_cluster(0).att_write); } #[test] @@ -455,25 +400,25 @@ fn test_write_unsupported_fields() { let invalid_endpoint = GenericPath::new( Some(4), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::AttWrite as u32), + Some(echo_cluster::AttributesDiscriminants::AttWrite as u32), ); let invalid_cluster = GenericPath::new( Some(0), Some(0x1234), - Some(echo_cluster::Attributes::AttWrite as u32), + Some(echo_cluster::AttributesDiscriminants::AttWrite as u32), ); let invalid_attribute = GenericPath::new(Some(0), Some(echo_cluster::ID), Some(0x1234)); let wc_endpoint_invalid_cluster = GenericPath::new( None, Some(0x1234), - Some(echo_cluster::Attributes::AttWrite as u32), + Some(echo_cluster::AttributesDiscriminants::AttWrite as u32), ); let wc_endpoint_invalid_attribute = GenericPath::new(None, Some(echo_cluster::ID), Some(0x1234)); let wc_cluster = GenericPath::new( Some(0), None, - Some(echo_cluster::Attributes::AttWrite as u32), + Some(echo_cluster::AttributesDiscriminants::AttWrite as u32), ); let wc_attribute = GenericPath::new(Some(0), Some(echo_cluster::ID), None); @@ -521,14 +466,11 @@ fn test_write_unsupported_fields() { AttrStatus::new(&wc_cluster, IMStatusCode::UnsupportedCluster, 0), AttrStatus::new(&wc_attribute, IMStatusCode::UnsupportedAttribute, 0), ]; - let dm = handle_write_reqs(input, expected); + let mut mdns = DummyMdns; + let matter = matter(&mut mdns); + let im = ImEngine::new_with_write_reqs(&matter, input, expected); assert_eq!( - AttrValue::Uint16(echo_cluster::ATTR_WRITE_DEFAULT_VALUE), - dm.read_attribute_raw( - 0, - echo_cluster::ID, - echo_cluster::Attributes::AttWrite as u16 - ) - .unwrap() + echo_cluster::ATTR_WRITE_DEFAULT_VALUE, + im.echo_cluster(0).att_write ); } diff --git a/matter/tests/data_model/commands.rs b/matter/tests/data_model/commands.rs index 353b662..50c1a8a 100644 --- a/matter/tests/data_model/commands.rs +++ b/matter/tests/data_model/commands.rs @@ -17,39 +17,23 @@ use crate::{ cmd_data, - common::{commands::*, echo_cluster, im_engine::im_engine}, + common::{ + commands::*, + echo_cluster, + im_engine::{matter, ImEngine}, + }, echo_req, echo_resp, }; use matter::{ data_model::{cluster_on_off, objects::EncodeValue}, interaction_model::{ - core::{IMStatusCode, OpCode}, - messages::{ - ib::{CmdData, CmdPath, CmdStatus}, - msg, - msg::InvReq, - }, + core::IMStatusCode, + messages::ib::{CmdData, CmdPath, CmdStatus}, }, - tlv::{self, FromTLV, TLVArray}, + mdns::DummyMdns, }; -// Helper for handling Invoke Command sequences -fn handle_commands(input: &[CmdData], expected: &[ExpectedInvResp]) { - let mut out_buf = [0u8; 400]; - let req = InvReq { - suppress_response: Some(false), - timed_request: Some(false), - inv_requests: Some(TLVArray::Slice(input)), - }; - - let (_, _, out_buf) = im_engine(OpCode::InvokeRequest, &req, &mut out_buf); - tlv::print_tlv_list(out_buf); - let root = tlv::get_root_node_struct(out_buf).unwrap(); - let resp = msg::InvResp::from_tlv(&root).unwrap(); - assert_inv_response(&resp, expected) -} - #[test] fn test_invoke_cmds_success() { // 2 echo Requests @@ -59,7 +43,7 @@ fn test_invoke_cmds_success() { let input = &[echo_req!(0, 5), echo_req!(1, 10)]; let expected = &[echo_resp!(0, 10), echo_resp!(1, 30)]; - handle_commands(input, expected); + ImEngine::new_with_commands(&matter(&mut DummyMdns), input, expected); } #[test] @@ -75,17 +59,17 @@ fn test_invoke_cmds_unsupported_fields() { let invalid_endpoint = CmdPath::new( Some(2), Some(echo_cluster::ID), - Some(echo_cluster::Commands::EchoReq as u16), + Some(echo_cluster::Commands::EchoReq as u32), ); let invalid_cluster = CmdPath::new( Some(0), Some(0x1234), - Some(echo_cluster::Commands::EchoReq as u16), + Some(echo_cluster::Commands::EchoReq as u32), ); let invalid_cluster_wc_endpoint = CmdPath::new( None, Some(0x1234), - Some(echo_cluster::Commands::EchoReq as u16), + Some(echo_cluster::Commands::EchoReq as u32), ); let invalid_command = CmdPath::new(Some(0), Some(echo_cluster::ID), Some(0x1234)); let invalid_command_wc_endpoint = CmdPath::new(None, Some(echo_cluster::ID), Some(0x1234)); @@ -114,7 +98,7 @@ fn test_invoke_cmds_unsupported_fields() { 0, )), ]; - handle_commands(input, expected); + ImEngine::new_with_commands(&matter(&mut DummyMdns), input, expected); } #[test] @@ -125,11 +109,11 @@ fn test_invoke_cmd_wc_endpoint_all_have_clusters() { let path = CmdPath::new( None, Some(echo_cluster::ID), - Some(echo_cluster::Commands::EchoReq as u16), + Some(echo_cluster::Commands::EchoReq as u32), ); let input = &[cmd_data!(path, 5)]; let expected = &[echo_resp!(0, 10), echo_resp!(1, 15)]; - handle_commands(input, expected); + ImEngine::new_with_commands(&matter(&mut DummyMdns), input, expected); } #[test] @@ -141,12 +125,12 @@ fn test_invoke_cmd_wc_endpoint_only_1_has_cluster() { let target = CmdPath::new( None, Some(cluster_on_off::ID), - Some(cluster_on_off::Commands::On as u16), + Some(cluster_on_off::CommandsDiscriminants::On as u32), ); let expected_path = CmdPath::new( Some(1), Some(cluster_on_off::ID), - Some(cluster_on_off::Commands::On as u16), + Some(cluster_on_off::CommandsDiscriminants::On as u32), ); let input = &[cmd_data!(target, 1)]; let expected = &[ExpectedInvResp::Status(CmdStatus::new( @@ -154,5 +138,5 @@ fn test_invoke_cmd_wc_endpoint_only_1_has_cluster() { IMStatusCode::Success, 0, ))]; - handle_commands(input, expected); + ImEngine::new_with_commands(&matter(&mut DummyMdns), input, expected); } diff --git a/matter/tests/data_model/timed_requests.rs b/matter/tests/data_model/timed_requests.rs index 1c4a941..cf5ddbd 100644 --- a/matter/tests/data_model/timed_requests.rs +++ b/matter/tests/data_model/timed_requests.rs @@ -15,112 +15,27 @@ * limitations under the License. */ -use core::time; -use std::thread; - use matter::{ - data_model::{ - core::DataModel, - objects::{AttrValue, EncodeValue}, - }, + data_model::objects::EncodeValue, interaction_model::{ - core::{IMStatusCode, OpCode}, - messages::{ib::CmdData, ib::CmdPath, msg::InvReq, GenericPath}, - messages::{ - ib::{AttrData, AttrPath, AttrStatus}, - msg::{self, StatusResp, TimedReq, WriteReq, WriteResp}, - }, + core::IMStatusCode, + messages::ib::{AttrData, AttrPath, AttrStatus}, + messages::{ib::CmdData, ib::CmdPath, GenericPath}, }, - tlv::{self, FromTLV, TLVArray, TLVWriter, ToTLV}, - transport::exchange::{self, Exchange}, + mdns::DummyMdns, + tlv::TLVWriter, }; use crate::{ common::{ commands::*, echo_cluster, - im_engine::{ImEngine, ImInput}, + handlers::{TimedInvResponse, WriteResponse}, + im_engine::{matter, ImEngine}, }, echo_req, echo_resp, }; -fn handle_timed_reqs<'a>( - opcode: OpCode, - request: &dyn ToTLV, - timeout: u16, - delay: u16, - output: &'a mut [u8], -) -> (u8, DataModel, &'a [u8]) { - let mut im_engine = ImEngine::new(); - // Use the same exchange for all parts of the transaction - im_engine.exch = Some(Exchange::new(1, 0, exchange::Role::Responder)); - - if timeout != 0 { - // Send Timed Req - let mut tmp_buf = [0u8; 400]; - let timed_req = TimedReq { timeout }; - let im_input = ImInput::new(OpCode::TimedRequest, &timed_req); - let (_, out_buf) = im_engine.process(&im_input, &mut tmp_buf); - tlv::print_tlv_list(out_buf); - } else { - println!("Skipping timed request"); - } - - // Process any delays - let delay = time::Duration::from_millis(delay.into()); - thread::sleep(delay); - - // Send Write Req - let input = ImInput::new(opcode, request); - let (resp_opcode, output) = im_engine.process(&input, output); - (resp_opcode, im_engine.dm, output) -} -enum WriteResponse<'a> { - TransactionError, - TransactionSuccess(&'a [AttrStatus]), -} - -// Helper for handling Write Attribute sequences -fn handle_timed_write_reqs( - input: &[AttrData], - expected: &WriteResponse, - timeout: u16, - delay: u16, -) -> DataModel { - let mut out_buf = [0u8; 400]; - let write_req = WriteReq::new(false, input); - - let (resp_opcode, dm, out_buf) = handle_timed_reqs( - OpCode::WriteRequest, - &write_req, - timeout, - delay, - &mut out_buf, - ); - tlv::print_tlv_list(out_buf); - let root = tlv::get_root_node_struct(out_buf).unwrap(); - - match expected { - WriteResponse::TransactionSuccess(t) => { - assert_eq!( - num::FromPrimitive::from_u8(resp_opcode), - Some(OpCode::WriteResponse) - ); - let resp = WriteResp::from_tlv(&root).unwrap(); - assert_eq!(resp.write_responses, t); - } - WriteResponse::TransactionError => { - assert_eq!( - num::FromPrimitive::from_u8(resp_opcode), - Some(OpCode::StatusResponse) - ); - let status_resp = StatusResp::from_tlv(&root).unwrap(); - assert_eq!(status_resp.status, IMStatusCode::Timeout); - } - } - dm -} - #[test] fn test_timed_write_fail_and_success() { // - 1 Timed Attr Write Transaction should fail due to timeout @@ -134,7 +49,7 @@ fn test_timed_write_fail_and_success() { let ep_att = GenericPath::new( None, Some(echo_cluster::ID), - Some(echo_cluster::Attributes::AttWrite as u32), + Some(echo_cluster::AttributesDiscriminants::AttWrite as u32), ); let input = &[AttrData::new( None, @@ -145,13 +60,13 @@ fn test_timed_write_fail_and_success() { let ep0_att = GenericPath::new( Some(0), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::AttWrite as u32), + Some(echo_cluster::AttributesDiscriminants::AttWrite as u32), ); let ep1_att = GenericPath::new( Some(1), Some(echo_cluster::ID), - Some(echo_cluster::Attributes::AttWrite as u32), + Some(echo_cluster::AttributesDiscriminants::AttWrite as u32), ); let expected = &[ AttrStatus::new(&ep0_att, IMStatusCode::Success, 0), @@ -159,73 +74,25 @@ fn test_timed_write_fail_and_success() { ]; // Test with incorrect handling - handle_timed_write_reqs(input, &WriteResponse::TransactionError, 400, 500); + ImEngine::new_with_timed_write_reqs( + &matter(&mut DummyMdns), + input, + &WriteResponse::TransactionError, + 400, + 500, + ); // Test with correct handling - let dm = handle_timed_write_reqs(input, &WriteResponse::TransactionSuccess(expected), 400, 0); - assert_eq!( - AttrValue::Uint16(val0), - dm.read_attribute_raw( - 0, - echo_cluster::ID, - echo_cluster::Attributes::AttWrite as u16 - ) - .unwrap() + let mut mdns = DummyMdns; + let matter = matter(&mut mdns); + let im = ImEngine::new_with_timed_write_reqs( + &matter, + input, + &WriteResponse::TransactionSuccess(expected), + 400, + 0, ); - assert_eq!( - AttrValue::Uint16(val0), - dm.read_attribute_raw( - 0, - echo_cluster::ID, - echo_cluster::Attributes::AttWrite as u16 - ) - .unwrap() - ); -} - -enum TimedInvResponse<'a> { - TransactionError(IMStatusCode), - TransactionSuccess(&'a [ExpectedInvResp]), -} -// Helper for handling Invoke Command sequences -fn handle_timed_commands( - input: &[CmdData], - expected: &TimedInvResponse, - timeout: u16, - delay: u16, - set_timed_request: bool, -) -> DataModel { - let mut out_buf = [0u8; 400]; - let req = InvReq { - suppress_response: Some(false), - timed_request: Some(set_timed_request), - inv_requests: Some(TLVArray::Slice(input)), - }; - - let (resp_opcode, dm, out_buf) = - handle_timed_reqs(OpCode::InvokeRequest, &req, timeout, delay, &mut out_buf); - tlv::print_tlv_list(out_buf); - let root = tlv::get_root_node_struct(out_buf).unwrap(); - - match expected { - TimedInvResponse::TransactionSuccess(t) => { - assert_eq!( - num::FromPrimitive::from_u8(resp_opcode), - Some(OpCode::InvokeResponse) - ); - let resp = msg::InvResp::from_tlv(&root).unwrap(); - assert_inv_response(&resp, t) - } - TimedInvResponse::TransactionError(e) => { - assert_eq!( - num::FromPrimitive::from_u8(resp_opcode), - Some(OpCode::StatusResponse) - ); - let status_resp = StatusResp::from_tlv(&root).unwrap(); - assert_eq!(status_resp.status, *e); - } - } - dm + assert_eq!(val0, im.echo_cluster(0).att_write); } #[test] @@ -235,7 +102,8 @@ fn test_timed_cmd_success() { let input = &[echo_req!(0, 5), echo_req!(1, 10)]; let expected = &[echo_resp!(0, 10), echo_resp!(1, 30)]; - handle_timed_commands( + ImEngine::new_with_timed_commands( + &matter(&mut DummyMdns), input, &TimedInvResponse::TransactionSuccess(expected), 400, @@ -250,7 +118,8 @@ fn test_timed_cmd_timeout() { let _ = env_logger::try_init(); let input = &[echo_req!(0, 5), echo_req!(1, 10)]; - handle_timed_commands( + ImEngine::new_with_timed_commands( + &matter(&mut DummyMdns), input, &TimedInvResponse::TransactionError(IMStatusCode::Timeout), 400, @@ -265,7 +134,8 @@ fn test_timed_cmd_timedout_mismatch() { let _ = env_logger::try_init(); let input = &[echo_req!(0, 5), echo_req!(1, 10)]; - handle_timed_commands( + ImEngine::new_with_timed_commands( + &matter(&mut DummyMdns), input, &TimedInvResponse::TransactionError(IMStatusCode::TimedRequestMisMatch), 400, @@ -274,7 +144,8 @@ fn test_timed_cmd_timedout_mismatch() { ); let input = &[echo_req!(0, 5), echo_req!(1, 10)]; - handle_timed_commands( + ImEngine::new_with_timed_commands( + &matter(&mut DummyMdns), input, &TimedInvResponse::TransactionError(IMStatusCode::TimedRequestMisMatch), 0, diff --git a/matter/tests/data_model_tests.rs b/matter/tests/data_model_tests.rs index 392909f..803c4c5 100644 --- a/matter/tests/data_model_tests.rs +++ b/matter/tests/data_model_tests.rs @@ -22,6 +22,6 @@ mod data_model { mod attribute_lists; mod attributes; mod commands; - mod long_reads; + // TODO mod long_reads; mod timed_requests; } diff --git a/matter/tests/interaction_model.rs b/matter/tests/interaction_model.rs index 5e88f8a..07d114e 100644 --- a/matter/tests/interaction_model.rs +++ b/matter/tests/interaction_model.rs @@ -15,116 +15,70 @@ * limitations under the License. */ -use boxslab::Slab; +use matter::data_model::core::DataHandler; use matter::error::Error; +use matter::interaction_model::core::Interaction; +use matter::interaction_model::core::InteractionModel; use matter::interaction_model::core::OpCode; -use matter::interaction_model::messages::msg::InvReq; -use matter::interaction_model::messages::msg::WriteReq; -use matter::interaction_model::InteractionConsumer; -use matter::interaction_model::InteractionModel; -use matter::interaction_model::Transaction; -use matter::tlv::TLVWriter; +use matter::interaction_model::core::Transaction; use matter::transport::exchange::Exchange; use matter::transport::exchange::ExchangeCtx; use matter::transport::network::Address; use matter::transport::packet::Packet; -use matter::transport::packet::PacketPool; -use matter::transport::proto_demux::HandleProto; -use matter::transport::proto_demux::ProtoCtx; -use matter::transport::proto_demux::ResponseRequired; +use matter::transport::proto_ctx::ProtoCtx; use matter::transport::session::SessionMgr; +use matter::utils::epoch::dummy_epoch; +use matter::utils::rand::dummy_rand; use std::net::Ipv4Addr; use std::net::SocketAddr; -use std::sync::{Arc, Mutex}; struct Node { pub endpoint: u16, pub cluster: u32, - pub command: u32, + pub command: u16, pub variable: u8, } struct DataModel { - node: Arc>, + node: Node, } impl DataModel { pub fn new(node: Node) -> Self { - DataModel { - node: Arc::new(Mutex::new(node)), - } + DataModel { node } } } -impl Clone for DataModel { - fn clone(&self) -> Self { - Self { - node: self.node.clone(), - } - } -} - -impl InteractionConsumer for DataModel { - fn consume_invoke_cmd( - &self, - inv_req_msg: &InvReq, - _trans: &mut Transaction, - _tlvwriter: &mut TLVWriter, - ) -> Result<(), Error> { - if let Some(inv_requests) = &inv_req_msg.inv_requests { - for i in inv_requests.iter() { - let data = if let Some(data) = i.data.unwrap_tlv() { - data - } else { - continue; - }; - let cmd_path_ib = i.path; - let mut common_data = self.node.lock().unwrap(); - common_data.endpoint = cmd_path_ib.path.endpoint.unwrap_or(1); - common_data.cluster = cmd_path_ib.path.cluster.unwrap_or(0); - common_data.command = cmd_path_ib.path.leaf.unwrap_or(0); - data.confirm_struct().unwrap(); - common_data.variable = data.find_tag(0).unwrap().u8().unwrap(); +impl DataHandler for DataModel { + fn handle( + &mut self, + interaction: &Interaction, + _tx: &mut Packet, + _transaction: &mut Transaction, + ) -> Result { + match interaction { + Interaction::Invoke(req) => { + if let Some(inv_requests) = &req.inv_requests { + for i in inv_requests.iter() { + let data = if let Some(data) = i.data.unwrap_tlv() { + data + } else { + continue; + }; + let cmd_path_ib = i.path; + let mut common_data = &mut self.node; + common_data.endpoint = cmd_path_ib.path.endpoint.unwrap_or(1); + common_data.cluster = cmd_path_ib.path.cluster.unwrap_or(0); + common_data.command = cmd_path_ib.path.leaf.unwrap_or(0) as u16; + data.confirm_struct().unwrap(); + common_data.variable = data.find_tag(0).unwrap().u8().unwrap(); + } + } } + _ => (), } - Ok(()) - } - - fn consume_read_attr( - &self, - _req: &[u8], - _trans: &mut Transaction, - _tlvwriter: &mut TLVWriter, - ) -> Result<(), Error> { - Ok(()) - } - - fn consume_write_attr( - &self, - _req: &WriteReq, - _trans: &mut Transaction, - _tlvwriter: &mut TLVWriter, - ) -> Result<(), Error> { - Ok(()) - } - - fn consume_status_report( - &self, - _req: &matter::interaction_model::messages::msg::StatusResp, - _trans: &mut Transaction, - _tw: &mut TLVWriter, - ) -> Result<(OpCode, ResponseRequired), Error> { - Ok((OpCode::StatusResponse, ResponseRequired::No)) - } - - fn consume_subscribe( - &self, - _req: &[u8], - _trans: &mut Transaction, - _tw: &mut TLVWriter, - ) -> Result<(OpCode, matter::transport::proto_demux::ResponseRequired), Error> { - Ok((OpCode::StatusResponse, ResponseRequired::No)) + Ok(false) } } @@ -135,9 +89,9 @@ fn handle_data(action: OpCode, data_in: &[u8], data_out: &mut [u8]) -> (DataMode command: 0, variable: 0, }); - let mut interaction_model = InteractionModel::new(Box::new(data_model.clone())); + let mut interaction_model = InteractionModel(data_model); let mut exch: Exchange = Default::default(); - let mut sess_mgr: SessionMgr = Default::default(); + let mut sess_mgr = SessionMgr::new(dummy_epoch, dummy_rand); let sess_idx = sess_mgr .get_or_add( 0, @@ -153,24 +107,27 @@ fn handle_data(action: OpCode, data_in: &[u8], data_out: &mut [u8]) -> (DataMode let exch_ctx = ExchangeCtx { exch: &mut exch, sess, + epoch: dummy_epoch, }; - let mut rx = Slab::::try_new(Packet::new_rx().unwrap()).unwrap(); - let tx = Slab::::try_new(Packet::new_tx().unwrap()).unwrap(); + let mut rx_buf = [0; 1500]; + let mut tx_buf = [0; 1500]; + let mut rx = Packet::new_rx(&mut rx_buf); + let mut tx = Packet::new_tx(&mut tx_buf); // Create fake rx packet rx.set_proto_id(0x01); rx.set_proto_opcode(action as u8); rx.peer = Address::default(); let in_data_len = data_in.len(); - let rx_buf = rx.as_borrow_slice(); + let rx_buf = rx.as_mut_slice(); rx_buf[..in_data_len].copy_from_slice(data_in); - let mut ctx = ProtoCtx::new(exch_ctx, rx, tx); + let mut ctx = ProtoCtx::new(exch_ctx, &rx, &mut tx); - interaction_model.handle_proto_id(&mut ctx).unwrap(); + interaction_model.handle(&mut ctx).unwrap(); - let out_len = ctx.tx.as_borrow_slice().len(); - data_out[..out_len].copy_from_slice(ctx.tx.as_borrow_slice()); - (data_model, out_len) + let out_len = ctx.tx.as_mut_slice().len(); + data_out[..out_len].copy_from_slice(ctx.tx.as_mut_slice()); + (interaction_model.0, out_len) } #[test] @@ -186,7 +143,7 @@ fn test_valid_invoke_cmd() -> Result<(), Error> { let mut out_buf: [u8; 20] = [0; 20]; let (data_model, _) = handle_data(OpCode::InvokeRequest, &b, &mut out_buf); - let data = data_model.node.lock().unwrap(); + let data = &data_model.node; assert_eq!(data.endpoint, 0); assert_eq!(data.cluster, 49); assert_eq!(data.command, 12);