From 7437cf2c94152560e945d5d058d137d3c302615b Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Sat, 29 Apr 2023 10:03:34 +0000 Subject: [PATCH] Simple persistance via TLV --- examples/onoff_light/src/main.rs | 23 +- matter/src/acl.rs | 64 +--- matter/src/core.rs | 18 + matter/src/crypto/crypto_dummy.rs | 8 +- matter/src/crypto/crypto_esp_mbedtls.rs | 9 +- matter/src/crypto/mod.rs | 36 ++ matter/src/fabric.rs | 434 ++++-------------------- matter/src/group_keys.rs | 14 +- matter/src/persist.rs | 239 +++---------- matter/src/tlv/traits.rs | 63 +++- 10 files changed, 262 insertions(+), 646 deletions(-) diff --git a/examples/onoff_light/src/main.rs b/examples/onoff_light/src/main.rs index 7ffc70e..b6d2588 100644 --- a/examples/onoff_light/src/main.rs +++ b/examples/onoff_light/src/main.rs @@ -27,6 +27,7 @@ use matter::data_model::root_endpoint; use matter::data_model::sdm::dev_att::DevAttDataFetcher; use matter::data_model::system_model::descriptor; use matter::interaction_model::core::InteractionModel; +use matter::persist; use matter::secure_channel::spake2p::VerifierData; use matter::transport::{ mgr::RecvAction, mgr::TransportMgr, packet::MAX_RX_BUF_SIZE, packet::MAX_TX_BUF_SIZE, @@ -56,6 +57,18 @@ fn main() { let dev_att = dev_att::HardCodedDevAtt::new(); + let psm = persist::FilePsm::new(std::env::temp_dir().join("matter-iot")).unwrap(); + + let mut buf = [0; 4096]; + + if let Some(data) = psm.load("fabrics", &mut buf).unwrap() { + matter.load_fabrics(data).unwrap(); + } + + if let Some(data) = psm.load("acls", &mut buf).unwrap() { + matter.load_acls(data).unwrap(); + } + matter .start::<4096>( CommissioningData { @@ -63,7 +76,7 @@ fn main() { verifier: VerifierData::new_with_pw(123456, *matter.borrow()), discriminator: 250, }, - &mut [0; 4096], + &mut buf, ) .unwrap(); @@ -114,6 +127,14 @@ fn main() { } } } + + if let Some(data) = matter.store_fabrics(&mut buf).unwrap() { + psm.store("fabrics", data).unwrap(); + } + + if let Some(data) = matter.store_acls(&mut buf).unwrap() { + psm.store("acls", data).unwrap(); + } } }); } diff --git a/matter/src/acl.rs b/matter/src/acl.rs index d73ce47..dea592b 100644 --- a/matter/src/acl.rs +++ b/matter/src/acl.rs @@ -22,7 +22,6 @@ use crate::{ error::Error, fabric, interaction_model::messages::GenericPath, - persist::Psm, tlv::{FromTLV, TLVElement, TLVList, TLVWriter, TagType, ToTLV}, transport::session::{Session, SessionMode, MAX_CAT_IDS_PER_NOC}, utils::writebuf::WriteBuf, @@ -390,10 +389,8 @@ impl AclEntry { } const MAX_ACL_ENTRIES: usize = ENTRIES_PER_FABRIC * fabric::MAX_SUPPORTED_FABRICS; -type AclEntries = [Option; MAX_ACL_ENTRIES]; -const ACL_KV_ENTRY: &str = "acl"; -const ACL_KV_MAX_SIZE: usize = 300; +type AclEntries = [Option; MAX_ACL_ENTRIES]; pub struct AclMgr { entries: AclEntries, @@ -505,30 +502,8 @@ impl AclMgr { false } - pub fn store(&mut self, mut psm: T) -> Result<(), Error> - where - T: Psm, - { - if self.changed { - let mut buf = [0u8; ACL_KV_MAX_SIZE]; - let mut wb = WriteBuf::new(&mut buf); - let mut tw = TLVWriter::new(&mut wb); - self.entries.to_tlv(&mut tw, TagType::Anonymous)?; - psm.set_kv_slice(ACL_KV_ENTRY, wb.as_slice())?; - - self.changed = false; - } - - Ok(()) - } - - pub fn load(&mut self, psm: T) -> Result<(), Error> - where - T: Psm, - { - let mut buf = [0u8; ACL_KV_MAX_SIZE]; - let acl_tlvs = psm.get_kv_slice(ACL_KV_ENTRY, &mut buf)?; - let root = TLVList::new(acl_tlvs).iter().next().ok_or(Error::Invalid)?; + pub fn load(&mut self, data: &[u8]) -> Result<(), Error> { + let root = TLVList::new(data).iter().next().ok_or(Error::Invalid)?; self.entries = AclEntries::from_tlv(&root)?; self.changed = false; @@ -536,37 +511,20 @@ impl AclMgr { Ok(()) } - #[cfg(feature = "nightly")] - pub async fn store_async(&mut self, mut psm: T) -> Result<(), Error> - where - T: crate::persist::asynch::AsyncPsm, - { + pub fn store<'a>(&mut self, buf: &'a mut [u8]) -> Result, Error> { if self.changed { - let mut buf = [0u8; ACL_KV_MAX_SIZE]; - let mut wb = WriteBuf::new(&mut buf); + let mut wb = WriteBuf::new(buf); let mut tw = TLVWriter::new(&mut wb); self.entries.to_tlv(&mut tw, TagType::Anonymous)?; - psm.set_kv_slice(ACL_KV_ENTRY, wb.as_slice()).await?; self.changed = false; + + let len = tw.get_tail(); + + Ok(Some(&buf[..len])) + } else { + Ok(None) } - - Ok(()) - } - - #[cfg(feature = "nightly")] - pub async fn load_async(&mut self, psm: T) -> Result<(), Error> - where - T: crate::persist::asynch::AsyncPsm, - { - let mut buf = [0u8; ACL_KV_MAX_SIZE]; - let acl_tlvs = psm.get_kv_slice(ACL_KV_ENTRY, &mut buf).await?; - let root = TLVList::new(acl_tlvs).iter().next().ok_or(Error::Invalid)?; - - self.entries = AclEntries::from_tlv(&root)?; - self.changed = false; - - Ok(()) } /// Traverse fabric specific entries to find the index diff --git a/matter/src/core.rs b/matter/src/core.rs index 0939a4a..e2e6b59 100644 --- a/matter/src/core.rs +++ b/matter/src/core.rs @@ -98,6 +98,24 @@ impl<'a> Matter<'a> { self.dev_det } + pub fn load_fabrics(&self, data: &[u8]) -> Result<(), Error> { + self.fabric_mgr + .borrow_mut() + .load(data, &mut self.mdns_mgr.borrow_mut()) + } + + pub fn load_acls(&self, data: &[u8]) -> Result<(), Error> { + self.acl_mgr.borrow_mut().load(data) + } + + pub fn store_fabrics<'b>(&self, buf: &'b mut [u8]) -> Result, Error> { + self.fabric_mgr.borrow_mut().store(buf) + } + + pub fn store_acls<'b>(&self, buf: &'b mut [u8]) -> Result, Error> { + self.acl_mgr.borrow_mut().store(buf) + } + pub fn start( &self, dev_comm: CommissioningData, diff --git a/matter/src/crypto/crypto_dummy.rs b/matter/src/crypto/crypto_dummy.rs index f193b20..acdae09 100644 --- a/matter/src/crypto/crypto_dummy.rs +++ b/matter/src/crypto/crypto_dummy.rs @@ -68,8 +68,6 @@ impl KeyPair { } pub fn new_from_components(_pub_key: &[u8], _priv_key: &[u8]) -> Result { - error!("This API should never get called"); - Ok(Self {}) } @@ -85,13 +83,11 @@ impl KeyPair { } pub fn get_public_key(&self, _pub_key: &mut [u8]) -> Result { - error!("This API should never get called"); - Err(Error::Invalid) + Ok(0) } pub fn get_private_key(&self, _pub_key: &mut [u8]) -> Result { - error!("This API should never get called"); - Err(Error::Invalid) + Ok(0) } pub fn derive_secret(self, _peer_pub_key: &[u8], _secret: &mut [u8]) -> Result { diff --git a/matter/src/crypto/crypto_esp_mbedtls.rs b/matter/src/crypto/crypto_esp_mbedtls.rs index fe72337..4eee8a7 100644 --- a/matter/src/crypto/crypto_esp_mbedtls.rs +++ b/matter/src/crypto/crypto_esp_mbedtls.rs @@ -70,8 +70,6 @@ impl KeyPair { } pub fn new_from_components(_pub_key: &[u8], priv_key: &[u8]) -> Result { - error!("This API should never get called"); - Ok(Self {}) } @@ -87,8 +85,11 @@ impl KeyPair { } pub fn get_public_key(&self, _pub_key: &mut [u8]) -> Result { - error!("This API should never get called"); - Err(Error::Invalid) + Ok(0) + } + + pub fn get_private_key(&self, priv_key: &mut [u8]) -> Result { + Ok(0) } pub fn derive_secret(self, _peer_pub_key: &[u8], _secret: &mut [u8]) -> Result { diff --git a/matter/src/crypto/mod.rs b/matter/src/crypto/mod.rs index 5c73ff2..27ba187 100644 --- a/matter/src/crypto/mod.rs +++ b/matter/src/crypto/mod.rs @@ -14,6 +14,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +use crate::{ + error::Error, + tlv::{FromTLV, TLVWriter, TagType, ToTLV}, +}; pub const SYMM_KEY_LEN_BITS: usize = 128; pub const SYMM_KEY_LEN_BYTES: usize = SYMM_KEY_LEN_BITS / 8; @@ -68,6 +72,38 @@ pub mod crypto_dummy; )))] pub use self::crypto_dummy::*; +impl<'a> FromTLV<'a> for KeyPair { + fn from_tlv(t: &crate::tlv::TLVElement<'a>) -> Result + where + Self: Sized, + { + t.confirm_array()?.enter(); + + if let Some(mut array) = t.enter() { + let pub_key = array.next().ok_or(Error::Invalid)?.slice()?; + let priv_key = array.next().ok_or(Error::Invalid)?.slice()?; + + KeyPair::new_from_components(pub_key, priv_key) + } else { + Err(Error::Invalid) + } + } +} + +impl ToTLV for KeyPair { + fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error> { + let mut buf = [0; 1024]; // TODO + + tw.start_array(tag)?; + + let size = self.get_public_key(&mut buf)?; + tw.str16(TagType::Anonymous, &buf[..size])?; + + let size = self.get_private_key(&mut buf)?; + tw.str16(TagType::Anonymous, &buf[..size]) + } +} + #[cfg(test)] mod tests { use crate::error::Error; diff --git a/matter/src/fabric.rs b/matter/src/fabric.rs index 688c56c..6f3ff0e 100644 --- a/matter/src/fabric.rs +++ b/matter/src/fabric.rs @@ -18,7 +18,8 @@ use core::fmt::Write; use byteorder::{BigEndian, ByteOrder, LittleEndian}; -use log::{error, info}; +use heapless::{String, Vec}; +use log::info; use crate::{ cert::{Cert, MAX_CERT_TLV_LEN}, @@ -26,32 +27,12 @@ use crate::{ error::Error, group_keys::KeySet, mdns::{MdnsMgr, ServiceMode}, - persist::Psm, - tlv::{OctetStr, TLVWriter, TagType, ToTLV, UtfStr}, + tlv::{FromTLV, OctetStr, TLVElement, TLVList, TLVWriter, TagType, ToTLV, UtfStr}, + utils::writebuf::WriteBuf, }; const COMPRESSED_FABRIC_ID_LEN: usize = 8; -macro_rules! fb_key { - ($index:ident, $key:ident, $buf:expr) => {{ - use core::fmt::Write; - - $buf = "".into(); - write!(&mut $buf, "fb{}{}", $index, $key).unwrap(); - - &$buf - }}; -} - -const ST_VID: &str = "vid"; -const ST_RCA: &str = "rca"; -const ST_ICA: &str = "ica"; -const ST_NOC: &str = "noc"; -const ST_IPK: &str = "ipk"; -const ST_LBL: &str = "label"; -const ST_PBKEY: &str = "pubkey"; -const ST_PRKEY: &str = "privkey"; - #[allow(dead_code)] #[derive(Debug, ToTLV)] #[tlvargs(lifetime = "'a", start = 1)] @@ -66,18 +47,18 @@ pub struct FabricDescriptor<'a> { pub fab_idx: Option, } -#[derive(Debug)] +#[derive(Debug, ToTLV, FromTLV)] pub struct Fabric { node_id: u64, fabric_id: u64, vendor_id: u16, key_pair: KeyPair, - pub root_ca: heapless::Vec, - pub icac: Option>, - pub noc: heapless::Vec, + pub root_ca: Vec, + pub icac: Option>, + pub noc: Vec, pub ipk: KeySet, - label: heapless::String<32>, - mdns_service_name: heapless::String<33>, + label: String<32>, + mdns_service_name: String<33>, } impl Fabric { @@ -199,234 +180,14 @@ impl Fabric { Ok(desc) } - - fn store(&self, index: usize, mut psm: T) -> Result<(), Error> - where - T: Psm, - { - let mut _kb = heapless::String::<32>::new(); - - psm.set_kv_slice(fb_key!(index, ST_RCA, _kb), &self.root_ca)?; - psm.set_kv_slice( - fb_key!(index, ST_ICA, _kb), - self.icac.as_deref().unwrap_or(&[]), - )?; - - psm.set_kv_slice(fb_key!(index, ST_NOC, _kb), &self.noc)?; - psm.set_kv_slice(fb_key!(index, ST_IPK, _kb), self.ipk.epoch_key())?; - psm.set_kv_slice(fb_key!(index, ST_LBL, _kb), self.label.as_bytes())?; - - let mut buf = [0_u8; crypto::EC_POINT_LEN_BYTES]; - let len = self.key_pair.get_public_key(&mut buf)?; - let key = &buf[..len]; - psm.set_kv_slice(fb_key!(index, ST_PBKEY, _kb), key)?; - - let mut buf = [0_u8; crypto::BIGNUM_LEN_BYTES]; - let len = self.key_pair.get_private_key(&mut buf)?; - let key = &buf[..len]; - psm.set_kv_slice(fb_key!(index, ST_PRKEY, _kb), key)?; - - psm.set_kv_u64(fb_key!(index, ST_VID, _kb), self.vendor_id.into())?; - Ok(()) - } - - fn load(index: usize, psm: T) -> Result - where - T: Psm, - { - let mut _kb = heapless::String::<32>::new(); - - let mut buf = [0u8; MAX_CERT_TLV_LEN]; - - let root_ca = - heapless::Vec::from_slice(psm.get_kv_slice(fb_key!(index, ST_RCA, _kb), &mut buf)?) - .unwrap(); - - let icac = psm.get_kv_slice(fb_key!(index, ST_ICA, _kb), &mut buf)?; - let icac = if !icac.is_empty() { - Some(heapless::Vec::from_slice(icac).unwrap()) - } else { - None - }; - - let noc = - heapless::Vec::from_slice(psm.get_kv_slice(fb_key!(index, ST_NOC, _kb), &mut buf)?) - .unwrap(); - - let label = psm.get_kv_slice(fb_key!(index, ST_LBL, _kb), &mut buf)?; - let label: heapless::String<32> = core::str::from_utf8(label) - .map_err(|_| { - error!("Couldn't read label"); - Error::Invalid - })? - .into(); - - let ipk = psm.get_kv_slice(fb_key!(index, ST_IPK, _kb), &mut buf)?; - - let mut buf = [0_u8; crypto::EC_POINT_LEN_BYTES]; - let pub_key = psm.get_kv_slice(fb_key!(index, ST_PBKEY, _kb), &mut buf)?; - - let mut buf = [0_u8; crypto::BIGNUM_LEN_BYTES]; - let priv_key = psm.get_kv_slice(fb_key!(index, ST_PRKEY, _kb), &mut buf)?; - let keypair = KeyPair::new_from_components(pub_key, priv_key)?; - - let vendor_id = psm.get_kv_u64(fb_key!(index, ST_VID, _kb))?; - - Fabric::new(keypair, root_ca, icac, noc, ipk, vendor_id as u16, &label) - } - - fn remove(index: usize, mut psm: T) -> Result<(), Error> - where - T: Psm, - { - let mut _kb = heapless::String::<32>::new(); - - psm.remove(fb_key!(index, ST_RCA, _kb))?; - psm.remove(fb_key!(index, ST_ICA, _kb))?; - - psm.remove(fb_key!(index, ST_NOC, _kb))?; - - psm.remove(fb_key!(index, ST_LBL, _kb))?; - - psm.remove(fb_key!(index, ST_IPK, _kb))?; - - psm.remove(fb_key!(index, ST_PBKEY, _kb))?; - psm.remove(fb_key!(index, ST_PRKEY, _kb))?; - - psm.remove(fb_key!(index, ST_VID, _kb))?; - - Ok(()) - } - - #[cfg(feature = "nightly")] - async fn store_async(&self, index: usize, mut psm: T) -> Result<(), Error> - where - T: crate::persist::asynch::AsyncPsm, - { - let mut _kb = heapless::String::<32>::new(); - - psm.set_kv_slice(fb_key!(index, ST_RCA, _kb), &self.root_ca) - .await?; - - psm.set_kv_slice( - fb_key!(index, ST_ICA, _kb), - self.icac.as_deref().unwrap_or(&[]), - ) - .await?; - - psm.set_kv_slice(fb_key!(index, ST_NOC, _kb), &self.noc) - .await?; - psm.set_kv_slice(fb_key!(index, ST_IPK, _kb), self.ipk.epoch_key()) - .await?; - psm.set_kv_slice(fb_key!(index, ST_LBL, _kb), self.label.as_bytes()) - .await?; - - let mut buf = [0_u8; crypto::EC_POINT_LEN_BYTES]; - let len = self.key_pair.get_public_key(&mut buf)?; - let key = &buf[..len]; - psm.set_kv_slice(fb_key!(index, ST_PBKEY, _kb), key).await?; - - let mut buf = [0_u8; crypto::BIGNUM_LEN_BYTES]; - let len = self.key_pair.get_private_key(&mut buf)?; - let key = &buf[..len]; - psm.set_kv_slice(fb_key!(index, ST_PRKEY, _kb), key).await?; - - psm.set_kv_u64(fb_key!(index, ST_VID, _kb), self.vendor_id.into()) - .await?; - Ok(()) - } - - #[cfg(feature = "nightly")] - async fn load_async(index: usize, psm: T) -> Result - where - T: crate::persist::asynch::AsyncPsm, - { - let mut _kb = heapless::String::<32>::new(); - - let mut buf = [0u8; MAX_CERT_TLV_LEN]; - - let root_ca = heapless::Vec::from_slice( - psm.get_kv_slice(fb_key!(index, ST_RCA, _kb), &mut buf) - .await?, - ) - .unwrap(); - - let icac = psm - .get_kv_slice(fb_key!(index, ST_ICA, _kb), &mut buf) - .await?; - let icac = if !icac.is_empty() { - Some(heapless::Vec::from_slice(icac).unwrap()) - } else { - None - }; - - let noc = heapless::Vec::from_slice( - psm.get_kv_slice(fb_key!(index, ST_NOC, _kb), &mut buf) - .await?, - ) - .unwrap(); - - let label = psm - .get_kv_slice(fb_key!(index, ST_LBL, _kb), &mut buf) - .await?; - let label: heapless::String<32> = core::str::from_utf8(label) - .map_err(|_| { - error!("Couldn't read label"); - Error::Invalid - })? - .into(); - - let ipk = psm - .get_kv_slice(fb_key!(index, ST_IPK, _kb), &mut buf) - .await?; - - let mut buf = [0_u8; crypto::EC_POINT_LEN_BYTES]; - let pub_key = psm - .get_kv_slice(fb_key!(index, ST_PBKEY, _kb), &mut buf) - .await?; - - let mut buf = [0_u8; crypto::BIGNUM_LEN_BYTES]; - let priv_key = psm - .get_kv_slice(fb_key!(index, ST_PRKEY, _kb), &mut buf) - .await?; - let keypair = KeyPair::new_from_components(pub_key, priv_key)?; - - let vendor_id = psm.get_kv_u64(fb_key!(index, ST_VID, _kb)).await?; - - Fabric::new(keypair, root_ca, icac, noc, ipk, vendor_id as u16, &label) - } - - #[cfg(feature = "nightly")] - async fn remove_async(index: usize, mut psm: T) -> Result<(), Error> - where - T: crate::persist::asynch::AsyncPsm, - { - let mut _kb = heapless::String::<32>::new(); - - psm.remove(fb_key!(index, ST_RCA, _kb)).await?; - psm.remove(fb_key!(index, ST_ICA, _kb)).await?; - - psm.remove(fb_key!(index, ST_NOC, _kb)).await?; - - psm.remove(fb_key!(index, ST_LBL, _kb)).await?; - - psm.remove(fb_key!(index, ST_IPK, _kb)).await?; - - psm.remove(fb_key!(index, ST_PBKEY, _kb)).await?; - psm.remove(fb_key!(index, ST_PRKEY, _kb)).await?; - - psm.remove(fb_key!(index, ST_VID, _kb)).await?; - - Ok(()) - } } pub const MAX_SUPPORTED_FABRICS: usize = 3; +type FabricEntries = [Option; MAX_SUPPORTED_FABRICS]; + pub struct FabricMgr { - // The outside world expects Fabric Index to be one more than the actual one - // since 0 is not allowed. Need to handle this cleanly somehow - fabrics: [Option; MAX_SUPPORTED_FABRICS], + fabrics: FabricEntries, changed: bool, } @@ -440,41 +201,20 @@ impl FabricMgr { } } - pub fn store(&mut self, mut psm: T) -> Result<(), Error> - where - T: Psm, - { - if self.changed { - for i in 1..MAX_SUPPORTED_FABRICS { - if let Some(fabric) = self.fabrics[i].as_mut() { - info!("Storing fabric at index {}", i); - fabric.store(i, &mut psm)?; - } else { - let _ = Fabric::remove(i, &mut psm); - } + pub fn load(&mut self, data: &[u8], mdns_mgr: &mut MdnsMgr) -> Result<(), Error> { + for fabric in &self.fabrics { + if let Some(fabric) = fabric { + mdns_mgr.unpublish_service(&fabric.mdns_service_name, ServiceMode::Commissioned)?; } - - self.changed = false; } - Ok(()) - } + let root = TLVList::new(data).iter().next().ok_or(Error::Invalid)?; - pub fn load(&mut self, mut psm: T, mdns_mgr: &mut MdnsMgr) -> Result<(), Error> - where - T: Psm, - { - for i in 1..MAX_SUPPORTED_FABRICS { - let result = Fabric::load(i, &mut psm); - if let Ok(fabric) = result { - info!("Adding new fabric at index {}", i); - self.fabrics[i] = Some(fabric); - mdns_mgr.publish_service( - &self.fabrics[i].as_ref().unwrap().mdns_service_name, - ServiceMode::Commissioned, - )?; - } else { - self.fabrics[i] = None; + self.fabrics = FabricEntries::from_tlv(&root)?; + + for fabric in &self.fabrics { + if let Some(fabric) = fabric { + mdns_mgr.publish_service(&fabric.mdns_service_name, ServiceMode::Commissioned)?; } } @@ -483,67 +223,32 @@ impl FabricMgr { Ok(()) } - #[cfg(feature = "nightly")] - pub async fn store_async(&mut self, mut psm: T) -> Result<(), Error> - where - T: crate::persist::asynch::AsyncPsm, - { + pub fn store<'a>(&mut self, buf: &'a mut [u8]) -> Result, Error> { if self.changed { - for i in 1..MAX_SUPPORTED_FABRICS { - if let Some(fabric) = self.fabrics[i].as_mut() { - info!("Storing fabric at index {}", i); - fabric.store_async(i, &mut psm).await?; - } else { - let _ = Fabric::remove_async(i, &mut psm).await; - } - } + let mut wb = WriteBuf::new(buf); + let mut tw = TLVWriter::new(&mut wb); + self.fabrics.to_tlv(&mut tw, TagType::Anonymous)?; self.changed = false; + + let len = tw.get_tail(); + + Ok(Some(&buf[..len])) + } else { + Ok(None) } - - Ok(()) - } - - #[cfg(feature = "nightly")] - pub async fn load_async( - &mut self, - mut psm: T, - mdns_mgr: &mut MdnsMgr<'_>, - ) -> Result<(), Error> - where - T: crate::persist::asynch::AsyncPsm, - { - for i in 1..MAX_SUPPORTED_FABRICS { - let result = Fabric::load_async(i, &mut psm).await; - if let Ok(fabric) = result { - info!("Adding new fabric at index {}", i); - self.fabrics[i] = Some(fabric); - mdns_mgr.publish_service( - &self.fabrics[i].as_ref().unwrap().mdns_service_name, - ServiceMode::Commissioned, - )?; - } else { - self.fabrics[i] = None; - } - } - - self.changed = false; - - Ok(()) } pub fn add(&mut self, f: Fabric, mdns_mgr: &mut MdnsMgr) -> Result { - for i in 1..MAX_SUPPORTED_FABRICS { - if self.fabrics[i].is_none() { - self.fabrics[i] = Some(f); - mdns_mgr.publish_service( - &self.fabrics[i].as_ref().unwrap().mdns_service_name, - ServiceMode::Commissioned, - )?; + for (index, fabric) in self.fabrics.iter_mut().enumerate() { + if fabric.is_none() { + mdns_mgr.publish_service(&f.mdns_service_name, ServiceMode::Commissioned)?; + + *fabric = Some(f); self.changed = true; - return Ok(i as u8); + return Ok((index + 1) as u8); } } @@ -551,20 +256,24 @@ impl FabricMgr { } pub fn remove(&mut self, fab_idx: u8, mdns_mgr: &mut MdnsMgr) -> Result<(), Error> { - if let Some(f) = self.fabrics[fab_idx as usize].take() { - mdns_mgr.unpublish_service(&f.mdns_service_name, ServiceMode::Commissioned)?; - self.changed = true; - Ok(()) + if fab_idx > 0 && fab_idx as usize <= self.fabrics.len() { + if let Some(f) = self.fabrics[(fab_idx - 1) as usize].take() { + mdns_mgr.unpublish_service(&f.mdns_service_name, ServiceMode::Commissioned)?; + self.changed = true; + Ok(()) + } else { + Err(Error::NotFound) + } } else { Err(Error::NotFound) } } pub fn match_dest_id(&self, random: &[u8], target: &[u8]) -> Result { - for i in 1..MAX_SUPPORTED_FABRICS { - if let Some(fabric) = &self.fabrics[i] { + for (index, fabric) in self.fabrics.iter().enumerate() { + if let Some(fabric) = fabric { if fabric.match_dest_id(random, target).is_ok() { - return Ok(i); + return Ok(index + 1); } } } @@ -572,26 +281,19 @@ impl FabricMgr { } pub fn get_fabric(&self, idx: usize) -> Result, Error> { - Ok(self.fabrics[idx].as_ref()) + if idx == 0 { + Ok(None) + } else { + Ok(self.fabrics[idx - 1].as_ref()) + } } pub fn is_empty(&self) -> bool { - for i in 1..MAX_SUPPORTED_FABRICS { - if self.fabrics[i].is_some() { - return false; - } - } - true + !self.fabrics.iter().any(Option::is_some) } pub fn used_count(&self) -> usize { - let mut count = 0; - for i in 1..MAX_SUPPORTED_FABRICS { - if self.fabrics[i].is_some() { - count += 1; - } - } - count + self.fabrics.iter().filter(|f| f.is_some()).count() } // Parameters to T are the Fabric and its Fabric Index @@ -599,25 +301,27 @@ impl FabricMgr { where T: FnMut(&Fabric, u8) -> Result<(), Error>, { - for i in 1..MAX_SUPPORTED_FABRICS { - if let Some(fabric) = &self.fabrics[i] { - f(fabric, i as u8)?; + for (index, fabric) in self.fabrics.iter().enumerate() { + if let Some(fabric) = fabric { + f(fabric, (index + 1) as u8)?; } } Ok(()) } pub fn set_label(&mut self, index: u8, label: &str) -> Result<(), Error> { - let index = index as usize; if !label.is_empty() { - for i in 1..MAX_SUPPORTED_FABRICS { - if let Some(fabric) = &self.fabrics[i] { - if fabric.label == label { - return Err(Error::Invalid); - } - } + if self + .fabrics + .iter() + .filter_map(|f| f.as_ref()) + .any(|f| f.label == label) + { + return Err(Error::Invalid); } } + + let index = (index - 1) as usize; if let Some(fabric) = &mut self.fabrics[index] { fabric.label = label.into(); self.changed = true; diff --git a/matter/src/group_keys.rs b/matter/src/group_keys.rs index 1dc1c40..d4e9765 100644 --- a/matter/src/group_keys.rs +++ b/matter/src/group_keys.rs @@ -15,12 +15,18 @@ * limitations under the License. */ -use crate::{crypto, error::Error}; +use crate::{ + crypto::{self, SYMM_KEY_LEN_BYTES}, + error::Error, + tlv::{FromTLV, TLVElement, TLVWriter, TagType, ToTLV}, +}; -#[derive(Debug, Default)] +type KeySetKey = [u8; SYMM_KEY_LEN_BYTES]; + +#[derive(Debug, Default, FromTLV, ToTLV)] pub struct KeySet { - pub epoch_key: [u8; crypto::SYMM_KEY_LEN_BYTES], - pub op_key: [u8; crypto::SYMM_KEY_LEN_BYTES], + pub epoch_key: KeySetKey, + pub op_key: KeySetKey, } impl KeySet { diff --git a/matter/src/persist.rs b/matter/src/persist.rs index 4bc8e24..1ea494b 100644 --- a/matter/src/persist.rs +++ b/matter/src/persist.rs @@ -14,216 +14,63 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - -use crate::error::Error; - -pub trait Psm { - fn set_kv_slice(&mut self, key: &str, val: &[u8]) -> Result<(), Error>; - fn get_kv_slice<'a>(&self, key: &str, buf: &'a mut [u8]) -> Result<&'a [u8], Error>; - - fn set_kv_u64(&mut self, key: &str, val: u64) -> Result<(), Error>; - fn get_kv_u64(&self, key: &str) -> Result; - - fn remove(&mut self, key: &str) -> Result<(), Error>; -} - -impl Psm for &mut T -where - T: Psm, -{ - fn set_kv_slice(&mut self, key: &str, val: &[u8]) -> Result<(), Error> { - (**self).set_kv_slice(key, val) - } - - fn get_kv_slice<'a>(&self, key: &str, buf: &'a mut [u8]) -> Result<&'a [u8], Error> { - (**self).get_kv_slice(key, buf) - } - - fn set_kv_u64(&mut self, key: &str, val: u64) -> Result<(), Error> { - (**self).set_kv_u64(key, val) - } - - fn get_kv_u64(&self, key: &str) -> Result { - (**self).get_kv_u64(key) - } - - fn remove(&mut self, key: &str) -> Result<(), Error> { - (**self).remove(key) - } -} - -#[cfg(feature = "nightly")] -pub mod asynch { - use crate::error::Error; - - use super::Psm; - - pub trait AsyncPsm { - async fn set_kv_slice<'a>(&'a mut self, key: &'a str, val: &'a [u8]) -> Result<(), Error>; - async fn get_kv_slice<'a, 'b>( - &'a self, - key: &'a str, - buf: &'b mut [u8], - ) -> Result<&'b [u8], Error>; - - async fn set_kv_u64<'a>(&'a mut self, key: &'a str, val: u64) -> Result<(), Error>; - async fn get_kv_u64<'a>(&'a self, key: &'a str) -> Result; - - async fn remove<'a>(&'a mut self, key: &'a str) -> Result<(), Error>; - } - - impl AsyncPsm for &mut T - where - T: AsyncPsm, - { - async fn set_kv_slice<'a>(&'a mut self, key: &'a str, val: &'a [u8]) -> Result<(), Error> { - (**self).set_kv_slice(key, val).await - } - - async fn get_kv_slice<'a, 'b>( - &'a self, - key: &'a str, - buf: &'b mut [u8], - ) -> Result<&'b [u8], Error> { - (**self).get_kv_slice(key, buf).await - } - - async fn set_kv_u64<'a>(&'a mut self, key: &'a str, val: u64) -> Result<(), Error> { - (**self).set_kv_u64(key, val).await - } - - async fn get_kv_u64<'a>(&'a self, key: &'a str) -> Result { - (**self).get_kv_u64(key).await - } - - async fn remove<'a>(&'a mut self, key: &'a str) -> Result<(), Error> { - (**self).remove(key).await - } - } - - pub struct Asyncify(pub T); - - impl AsyncPsm for Asyncify - where - T: Psm, - { - async fn set_kv_slice<'a>(&'a mut self, key: &'a str, val: &'a [u8]) -> Result<(), Error> { - self.0.set_kv_slice(key, val) - } - - async fn get_kv_slice<'a, 'b>( - &'a self, - key: &'a str, - buf: &'b mut [u8], - ) -> Result<&'b [u8], Error> { - self.0.get_kv_slice(key, buf) - } - - async fn set_kv_u64<'a>(&'a mut self, key: &'a str, val: u64) -> Result<(), Error> { - self.0.set_kv_u64(key, val) - } - - async fn get_kv_u64<'a>(&'a self, key: &'a str) -> Result { - self.0.get_kv_u64(key) - } - - async fn remove<'a>(&'a mut self, key: &'a str) -> Result<(), Error> { - self.0.remove(key) - } - } -} +#[cfg(feature = "std")] +pub use file_psm::*; #[cfg(feature = "std")] -pub mod std { - use std::fs::{self, DirBuilder, File}; +mod file_psm { + use std::fs; use std::io::{Read, Write}; + use std::path::PathBuf; use crate::error::Error; - use super::Psm; - - pub struct FilePsm {} - - const PSM_DIR: &str = "/tmp/matter_psm"; - - macro_rules! psm_path { - ($key:ident) => { - format!("{}/{}", PSM_DIR, $key) - }; + pub struct FilePsm { + dir: PathBuf, } impl FilePsm { - pub fn new() -> Result { - let result = DirBuilder::new().create(PSM_DIR); - if let Err(e) = result { - if e.kind() != std::io::ErrorKind::AlreadyExists { - return Err(e.into()); + pub fn new(dir: PathBuf) -> Result { + fs::create_dir_all(&dir)?; + + Ok(Self { dir }) + } + + pub fn load<'a>(&self, key: &str, buf: &'a mut [u8]) -> Result, Error> { + let path = self.dir.join(key); + + match fs::File::open(path) { + Ok(mut file) => { + let mut offset = 0; + + loop { + if offset == buf.len() { + return Err(Error::NoSpace); + } + + let len = file.read(&mut buf[offset..])?; + + if len == 0 { + break; + } + + offset += len; + } + + Ok(Some(&buf[..offset])) } + Err(_) => Ok(None), } - - Ok(Self {}) } - pub fn set_kv_slice(&mut self, key: &str, val: &[u8]) -> Result<(), Error> { - let mut f = File::create(psm_path!(key))?; - f.write_all(val)?; + pub fn store(&self, key: &str, data: &[u8]) -> Result<(), Error> { + let path = self.dir.join(key); + + let mut file = fs::File::create(path)?; + + file.write_all(data)?; + Ok(()) } - - pub fn get_kv_slice<'a>(&self, key: &str, buf: &'a mut [u8]) -> Result<&'a [u8], Error> { - let mut f = File::open(psm_path!(key))?; - let mut offset = 0; - - loop { - let len = f.read(&mut buf[offset..])?; - offset += len; - - if len == 0 { - break; - } - } - - Ok(&buf[..offset]) - } - - pub fn set_kv_u64(&mut self, key: &str, val: u64) -> Result<(), Error> { - let mut f = File::create(psm_path!(key))?; - f.write_all(&val.to_be_bytes())?; - Ok(()) - } - - pub fn get_kv_u64(&self, key: &str) -> Result { - let mut f = File::open(psm_path!(key))?; - let mut buf = [0; 8]; - f.read_exact(&mut buf)?; - Ok(u64::from_be_bytes(buf)) - } - - pub fn remove(&self, key: &str) -> Result<(), Error> { - fs::remove_file(psm_path!(key))?; - Ok(()) - } - } - - impl Psm for FilePsm { - fn set_kv_slice(&mut self, key: &str, val: &[u8]) -> Result<(), Error> { - FilePsm::set_kv_slice(self, key, val) - } - - fn get_kv_slice<'a>(&self, key: &str, buf: &'a mut [u8]) -> Result<&'a [u8], Error> { - FilePsm::get_kv_slice(self, key, buf) - } - - fn set_kv_u64(&mut self, key: &str, val: u64) -> Result<(), Error> { - FilePsm::set_kv_u64(self, key, val) - } - - fn get_kv_u64(&self, key: &str) -> Result { - FilePsm::get_kv_u64(self, key) - } - - fn remove(&mut self, key: &str) -> Result<(), Error> { - FilePsm::remove(self, key) - } } } diff --git a/matter/src/tlv/traits.rs b/matter/src/tlv/traits.rs index 72cfab2..100eb07 100644 --- a/matter/src/tlv/traits.rs +++ b/matter/src/tlv/traits.rs @@ -35,26 +35,21 @@ pub trait FromTLV<'a> { } } -impl<'a, T: Default + FromTLV<'a> + Copy, const N: usize> FromTLV<'a> for [T; N] { +impl<'a, T: FromTLV<'a>, const N: usize> FromTLV<'a> for [T; N] { fn from_tlv(t: &TLVElement<'a>) -> Result where Self: Sized, { t.confirm_array()?; - let mut a: [T; N] = [Default::default(); N]; - let mut index = 0; + + let mut a = heapless::Vec::::new(); if let Some(tlv_iter) = t.enter() { for element in tlv_iter { - if index < N { - a[index] = T::from_tlv(&element)?; - index += 1; - } else { - error!("Received TLV Array with elements larger than current size"); - break; - } + a.push(T::from_tlv(&element)?).map_err(|_| Error::NoSpace)?; } } - Ok(a) + + a.into_array().map_err(|_| Error::Invalid) } } @@ -114,6 +109,8 @@ totlv_for!(i8 u8 u16 u32 u64 bool); // // - UtfStr, OctetStr: These are versions that map to utfstr and ostr in the TLV spec // - These only have references into the original list +// - heapless::String, Vheapless::ec: Is the owned version of utfstr and ostr, data is cloned into this +// - heapless::String is only partially implemented // // - TLVArray: Is an array of entries, with reference within the original list @@ -165,6 +162,38 @@ impl<'a> ToTLV for OctetStr<'a> { } } +/// Implements the Owned version of Octet String +impl FromTLV<'_> for heapless::Vec { + fn from_tlv(t: &TLVElement) -> Result, Error> { + heapless::Vec::from_slice(t.slice()?).map_err(|_| Error::NoSpace) + } +} + +impl ToTLV for heapless::Vec { + fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error> { + tw.str16(tag, self.as_slice()) + } +} + +/// Implements the Owned version of UTF String +impl FromTLV<'_> for heapless::String { + fn from_tlv(t: &TLVElement) -> Result, Error> { + let mut string = heapless::String::new(); + + string + .push_str(core::str::from_utf8(t.slice()?)?) + .map_err(|_| Error::NoSpace)?; + + Ok(string) + } +} + +impl ToTLV for heapless::String { + fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error> { + tw.utf16(tag, self.as_bytes()) + } +} + /// Applies to all the Option<> Processing impl<'a, T: FromTLV<'a>> FromTLV<'a> for Option { fn from_tlv(t: &TLVElement<'a>) -> Result, Error> { @@ -259,7 +288,7 @@ impl<'a, T: ToTLV> TLVArray<'a, T> { } } -impl<'a, T: ToTLV + FromTLV<'a> + Copy> TLVArray<'a, T> { +impl<'a, T: ToTLV + FromTLV<'a> + Clone> TLVArray<'a, T> { pub fn get_index(&self, index: usize) -> T { for (curr, element) in self.iter().enumerate() { if curr == index { @@ -270,12 +299,12 @@ impl<'a, T: ToTLV + FromTLV<'a> + Copy> TLVArray<'a, T> { } } -impl<'a, T: FromTLV<'a> + Copy> Iterator for TLVArrayIter<'a, T> { +impl<'a, T: FromTLV<'a> + Clone> Iterator for TLVArrayIter<'a, T> { type Item = T; /* Code for going to the next Element */ fn next(&mut self) -> Option { match self { - Self::Slice(s_iter) => s_iter.next().copied(), + Self::Slice(s_iter) => s_iter.next().cloned(), Self::Ptr(p_iter) => { if let Some(tlv_iter) = p_iter.as_mut() { let e = tlv_iter.next(); @@ -294,7 +323,7 @@ impl<'a, T: FromTLV<'a> + Copy> Iterator for TLVArrayIter<'a, T> { impl<'a, T> PartialEq<&[T]> for TLVArray<'a, T> where - T: ToTLV + FromTLV<'a> + Copy + PartialEq, + T: ToTLV + FromTLV<'a> + Clone + PartialEq, { fn eq(&self, other: &&[T]) -> bool { let mut iter1 = self.iter(); @@ -313,7 +342,7 @@ where } } -impl<'a, T: FromTLV<'a> + Copy + ToTLV> ToTLV for TLVArray<'a, T> { +impl<'a, T: FromTLV<'a> + Clone + ToTLV> ToTLV for TLVArray<'a, T> { fn to_tlv(&self, tw: &mut TLVWriter, tag_type: TagType) -> Result<(), Error> { tw.start_array(tag_type)?; for a in self.iter() { @@ -340,7 +369,7 @@ impl<'a, T> FromTLV<'a> for TLVArray<'a, T> { } } -impl<'a, T: Debug + ToTLV + FromTLV<'a> + Copy> Debug for TLVArray<'a, T> { +impl<'a, T: Debug + ToTLV + FromTLV<'a> + Clone> Debug for TLVArray<'a, T> { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { write!(f, "TLVArray [")?; let mut first = true;