From 974ac4d1d80c28ec5473e25abe25f691e27a13a2 Mon Sep 17 00:00:00 2001 From: imarkov Date: Sat, 29 Apr 2023 19:38:01 +0300 Subject: [PATCH] Optional feature to capture stacktrace on error --- examples/onoff_light/src/dev_att.rs | 4 +- examples/onoff_light/src/main.rs | 73 ++++++++++------- examples/speaker/src/dev_att.rs | 2 +- matter/Cargo.toml | 5 +- matter/src/acl.rs | 18 ++--- matter/src/cert/asn1_writer.rs | 12 +-- matter/src/cert/mod.rs | 51 +++++++----- matter/src/codec/base38.rs | 16 ++-- matter/src/crypto/crypto_dummy.rs | 10 +-- matter/src/crypto/crypto_esp_mbedtls.rs | 10 +-- matter/src/crypto/crypto_mbedtls.rs | 34 ++++---- matter/src/crypto/crypto_openssl.rs | 23 +++--- matter/src/crypto/crypto_rustcrypto.rs | 14 ++-- matter/src/crypto/mod.rs | 15 ++-- matter/src/data_model/cluster_template.rs | 4 +- matter/src/data_model/objects/cluster.rs | 4 +- matter/src/data_model/objects/encoder.rs | 27 +++++-- matter/src/data_model/objects/handler.rs | 20 +++-- matter/src/data_model/objects/privilege.rs | 6 +- .../src/data_model/sdm/admin_commissioning.rs | 2 +- matter/src/data_model/sdm/failsafe.rs | 19 +++-- .../data_model/sdm/general_commissioning.rs | 4 +- matter/src/data_model/sdm/noc.rs | 12 +-- matter/src/data_model/sdm/nw_commissioning.rs | 4 +- .../data_model/system_model/access_control.rs | 8 +- matter/src/error.rs | 80 +++++++++++++++---- matter/src/fabric.rs | 21 ++--- matter/src/group_keys.rs | 7 +- matter/src/interaction_model/core.rs | 36 +++++---- matter/src/interaction_model/messages.rs | 14 ++-- matter/src/mdns.rs | 10 +-- matter/src/pairing/qr.rs | 19 ++--- matter/src/persist.rs | 4 +- matter/src/secure_channel/case.rs | 30 +++---- matter/src/secure_channel/core.rs | 4 +- matter/src/secure_channel/crypto_dummy.rs | 18 ++--- matter/src/secure_channel/crypto_mbedtls.rs | 4 +- matter/src/secure_channel/crypto_openssl.rs | 4 +- matter/src/secure_channel/pake.rs | 14 ++-- matter/src/secure_channel/spake2p.rs | 8 +- matter/src/secure_channel/status_report.rs | 1 + matter/src/tlv/parser.rs | 50 +++++++----- matter/src/tlv/traits.rs | 21 ++--- matter/src/tlv/writer.rs | 2 +- matter/src/transport/exchange.rs | 32 +++++--- matter/src/transport/mgr.rs | 14 ++-- matter/src/transport/mrp.rs | 8 +- matter/src/transport/packet.rs | 16 ++-- matter/src/transport/plain_hdr.rs | 2 +- matter/src/transport/proto_hdr.rs | 8 +- matter/src/transport/session.rs | 10 +-- matter/src/transport/udp.rs | 4 +- matter/src/utils/parsebuf.rs | 4 +- matter/src/utils/writebuf.rs | 12 +-- matter/tests/common/echo_cluster.rs | 10 +-- matter/tests/common/im_engine.rs | 2 +- matter_macro_derive/Cargo.toml | 1 + matter_macro_derive/src/lib.rs | 48 +++++++---- 58 files changed, 531 insertions(+), 384 deletions(-) diff --git a/examples/onoff_light/src/dev_att.rs b/examples/onoff_light/src/dev_att.rs index a16d53f..93fcbd3 100644 --- a/examples/onoff_light/src/dev_att.rs +++ b/examples/onoff_light/src/dev_att.rs @@ -16,7 +16,7 @@ */ use matter::data_model::sdm::dev_att::{DataType, DevAttDataFetcher}; -use matter::error::Error; +use matter::error::{Error, ErrorCode}; pub struct HardCodedDevAtt {} @@ -159,7 +159,7 @@ impl DevAttDataFetcher for HardCodedDevAtt { data.copy_from_slice(src); Ok(src.len()) } else { - Err(Error::NoSpace) + Err(ErrorCode::NoSpace.into()) } } } diff --git a/examples/onoff_light/src/main.rs b/examples/onoff_light/src/main.rs index b6d2588..604ffce 100644 --- a/examples/onoff_light/src/main.rs +++ b/examples/onoff_light/src/main.rs @@ -16,7 +16,9 @@ */ use std::borrow::Borrow; +use std::error::Error; +use log::info; use matter::core::{CommissioningData, Matter}; use matter::data_model::cluster_basic_information::BasicInfoConfig; use matter::data_model::cluster_on_off; @@ -36,8 +38,10 @@ use matter::transport::{ mod dev_att; -fn main() { - env_logger::init(); +fn main() -> Result<(), impl Error> { + env_logger::init_from_env( + env_logger::Env::default().filter_or(env_logger::DEFAULT_FILTER_ENV, "info"), + ); // vid/pid should match those in the DAC let dev_info = BasicInfoConfig { @@ -50,35 +54,37 @@ fn main() { device_name: "OnOff Light", }; - //let mut mdns = matter::mdns::astro::AstroMdns::new().unwrap(); - let mut mdns = matter::mdns::libmdns::LibMdns::new().unwrap(); + let mut mdns = matter::mdns::astro::AstroMdns::new()?; + //let mut mdns = matter::mdns::libmdns::LibMdns::new()?; + //let mut mdns = matter::mdns::DummyMdns {}; let matter = Matter::new_default(&dev_info, &mut mdns, matter::transport::udp::MATTER_PORT); let dev_att = dev_att::HardCodedDevAtt::new(); - let psm = persist::FilePsm::new(std::env::temp_dir().join("matter-iot")).unwrap(); + let psm_path = std::env::temp_dir().join("matter-iot"); + info!("Persisting from/to {}", psm_path.display()); + + let psm = persist::FilePsm::new(psm_path)?; let mut buf = [0; 4096]; - if let Some(data) = psm.load("fabrics", &mut buf).unwrap() { - matter.load_fabrics(data).unwrap(); + if let Some(data) = psm.load("acls", &mut buf)? { + matter.load_acls(data)?; } - if let Some(data) = psm.load("acls", &mut buf).unwrap() { - matter.load_acls(data).unwrap(); + if let Some(data) = psm.load("fabrics", &mut buf)? { + matter.load_fabrics(data)?; } - matter - .start::<4096>( - CommissioningData { - // TODO: Hard-coded for now - verifier: VerifierData::new_with_pw(123456, *matter.borrow()), - discriminator: 250, - }, - &mut buf, - ) - .unwrap(); + matter.start::<4096>( + CommissioningData { + // TODO: Hard-coded for now + verifier: VerifierData::new_with_pw(123456, *matter.borrow()), + discriminator: 250, + }, + &mut buf, + )?; let matter = &matter; let dev_att = &dev_att; @@ -86,20 +92,20 @@ fn main() { let mut transport = TransportMgr::new(matter); smol::block_on(async move { - let udp = UdpListener::new().await.unwrap(); + let udp = UdpListener::new().await?; loop { let mut rx_buf = [0; MAX_RX_BUF_SIZE]; let mut tx_buf = [0; MAX_TX_BUF_SIZE]; - let (len, addr) = udp.recv(&mut rx_buf).await.unwrap(); + let (len, addr) = udp.recv(&mut rx_buf).await?; let mut completion = transport.recv(addr, &mut rx_buf[..len], &mut tx_buf); - while let Some(action) = completion.next_action().unwrap() { + while let Some(action) = completion.next_action()? { match action { RecvAction::Send(addr, buf) => { - udp.send(addr, buf).await.unwrap(); + udp.send(addr, buf).await?; } RecvAction::Interact(mut ctx) => { let node = Node { @@ -119,24 +125,29 @@ fn main() { let mut im = InteractionModel(DataModel::new(matter.borrow(), &node, &mut handler)); - if im.handle(&mut ctx).unwrap() { - if ctx.send().unwrap() { - udp.send(ctx.tx.peer, ctx.tx.as_slice()).await.unwrap(); + if im.handle(&mut ctx)? { + if ctx.send()? { + udp.send(ctx.tx.peer, ctx.tx.as_slice()).await?; } } } } } - if let Some(data) = matter.store_fabrics(&mut buf).unwrap() { - psm.store("fabrics", data).unwrap(); + if let Some(data) = matter.store_fabrics(&mut buf)? { + psm.store("fabrics", data)?; } - if let Some(data) = matter.store_acls(&mut buf).unwrap() { - psm.store("acls", data).unwrap(); + if let Some(data) = matter.store_acls(&mut buf)? { + psm.store("acls", data)?; } } - }); + + #[allow(unreachable_code)] + Ok::<_, matter::error::Error>(()) + })?; + + Ok::<_, matter::error::Error>(()) } fn handler<'a>(matter: &'a Matter<'a>, dev_att: &'a dyn DevAttDataFetcher) -> impl Handler + 'a { diff --git a/examples/speaker/src/dev_att.rs b/examples/speaker/src/dev_att.rs index a16d53f..c0c1030 100644 --- a/examples/speaker/src/dev_att.rs +++ b/examples/speaker/src/dev_att.rs @@ -159,7 +159,7 @@ impl DevAttDataFetcher for HardCodedDevAtt { data.copy_from_slice(src); Ok(src.len()) } else { - Err(Error::NoSpace) + Err(ErrorCode::NoSpace.into()) } } } diff --git a/matter/Cargo.toml b/matter/Cargo.toml index 9f5503b..78e0993 100644 --- a/matter/Cargo.toml +++ b/matter/Cargo.toml @@ -15,8 +15,9 @@ name = "matter" path = "src/lib.rs" [features] -default = ["std", "crypto_mbedtls"] +default = ["std", "crypto_mbedtls", "backtrace"] std = ["alloc", "env_logger", "chrono", "rand", "qrcode", "libmdns", "simple-mdns", "simple-dns", "smol"] +backtrace = [] alloc = [] nightly = [] crypto_openssl = ["openssl", "foreign-types", "hmac", "sha2"] @@ -32,7 +33,7 @@ heapless = "0.7.16" num = "0.4" num-derive = "0.3.3" num-traits = "0.2.15" -strum = { version = "0.24", features = ["derive"], default-features = false, no-default-feature = true } +strum = { version = "0.24", features = ["derive"], default-features = false } log = { version = "0.4.17", features = ["max_level_debug", "release_max_level_debug"] } no-std-net = "0.6" subtle = "2.4.1" diff --git a/matter/src/acl.rs b/matter/src/acl.rs index dea592b..77b8e5b 100644 --- a/matter/src/acl.rs +++ b/matter/src/acl.rs @@ -19,7 +19,7 @@ use core::{cell::RefCell, fmt::Display}; use crate::{ data_model::objects::{Access, ClusterId, EndptId, Privilege}, - error::Error, + error::{Error, ErrorCode}, fabric, interaction_model::messages::GenericPath, tlv::{FromTLV, TLVElement, TLVList, TLVWriter, TagType, ToTLV}, @@ -50,7 +50,7 @@ impl FromTLV<'_> for AuthMode { { num::FromPrimitive::from_u32(t.u32()?) .filter(|a| *a != AuthMode::Invalid) - .ok_or(Error::Invalid) + .ok_or_else(|| ErrorCode::Invalid.into()) } } @@ -112,7 +112,7 @@ impl AccessorSubjects { return Ok(()); } } - Err(Error::NoSpace) + Err(ErrorCode::NoSpace.into()) } /// Match the match_subject with any of the current subjects @@ -314,7 +314,7 @@ impl AclEntry { .subjects .iter() .position(|s| s.is_none()) - .ok_or(Error::NoSpace)?; + .ok_or(ErrorCode::NoSpace)?; self.subjects[index] = Some(subject); Ok(()) } @@ -328,7 +328,7 @@ impl AclEntry { .targets .iter() .position(|s| s.is_none()) - .ok_or(Error::NoSpace)?; + .ok_or(ErrorCode::NoSpace)?; self.targets[index] = Some(target); Ok(()) } @@ -425,13 +425,13 @@ impl AclMgr { .filter(|a| a.fab_idx == entry.fab_idx) .count(); if cnt >= ENTRIES_PER_FABRIC { - return Err(Error::NoSpace); + Err(ErrorCode::NoSpace)?; } let index = self .entries .iter() .position(|a| a.is_none()) - .ok_or(Error::NoSpace)?; + .ok_or(ErrorCode::NoSpace)?; self.entries[index] = Some(entry); self.changed = true; @@ -503,7 +503,7 @@ impl AclMgr { } pub fn load(&mut self, data: &[u8]) -> Result<(), Error> { - let root = TLVList::new(data).iter().next().ok_or(Error::Invalid)?; + let root = TLVList::new(data).iter().next().ok_or(ErrorCode::Invalid)?; self.entries = AclEntries::from_tlv(&root)?; self.changed = false; @@ -547,7 +547,7 @@ impl AclMgr { return Ok(entry); } } - Err(Error::NotFound) + Err(ErrorCode::NotFound.into()) } } diff --git a/matter/src/cert/asn1_writer.rs b/matter/src/cert/asn1_writer.rs index b6f4ab7..4afd6b6 100644 --- a/matter/src/cert/asn1_writer.rs +++ b/matter/src/cert/asn1_writer.rs @@ -17,7 +17,7 @@ use super::{CertConsumer, MAX_DEPTH}; use crate::{ - error::Error, + error::{Error, ErrorCode}, utils::epoch::{UtcCalendar, MATTER_EPOCH_SECS}, }; use core::{fmt::Write, time::Duration}; @@ -54,7 +54,7 @@ impl<'a> ASN1Writer<'a> { self.offset += size; return Ok(()); } - Err(Error::NoSpace) + Err(ErrorCode::NoSpace.into()) } pub fn append_tlv(&mut self, tag: u8, len: usize, f: F) -> Result<(), Error> @@ -70,7 +70,7 @@ impl<'a> ASN1Writer<'a> { self.offset += len; return Ok(()); } - Err(Error::NoSpace) + Err(ErrorCode::NoSpace.into()) } fn add_compound(&mut self, val: u8) -> Result<(), Error> { @@ -80,7 +80,7 @@ impl<'a> ASN1Writer<'a> { self.depth[self.current_depth] = self.offset; self.current_depth += 1; if self.current_depth >= MAX_DEPTH { - Err(Error::NoSpace) + Err(ErrorCode::NoSpace.into()) } else { Ok(()) } @@ -113,7 +113,7 @@ impl<'a> ASN1Writer<'a> { fn end_compound(&mut self) -> Result<(), Error> { if self.current_depth == 0 { - return Err(Error::Invalid); + Err(ErrorCode::Invalid)?; } let seq_len = self.get_compound_len(); let write_offset = self.get_length_encoding_offset(); @@ -148,7 +148,7 @@ impl<'a> ASN1Writer<'a> { // This is done with an 0xA2 followed by 2 bytes of actual len 3 } else { - return Err(Error::NoSpace); + Err(ErrorCode::NoSpace)? }; Ok(len) } diff --git a/matter/src/cert/mod.rs b/matter/src/cert/mod.rs index 621b28d..d750db5 100644 --- a/matter/src/cert/mod.rs +++ b/matter/src/cert/mod.rs @@ -19,7 +19,7 @@ use core::fmt::{self, Write}; use crate::{ crypto::KeyPair, - error::Error, + error::{Error, ErrorCode}, tlv::{self, FromTLV, OctetStr, TLVArray, TLVElement, TLVWriter, TagType, ToTLV}, utils::{epoch::UtcCalendar, writebuf::WriteBuf}, }; @@ -349,22 +349,22 @@ impl<'a> FromTLV<'a> for DistNames<'a> { let mut d = Self { dn: heapless::Vec::new(), }; - let iter = t.confirm_list()?.enter().ok_or(Error::Invalid)?; + let iter = t.confirm_list()?.enter().ok_or(ErrorCode::Invalid)?; for t in iter { if let TagType::Context(tag) = t.get_tag() { if let Ok(value) = t.u64() { d.dn.push((tag, DistNameValue::Uint(value))) - .map_err(|_| Error::BufferTooSmall)?; + .map_err(|_| ErrorCode::BufferTooSmall)?; } else if let Ok(value) = t.slice() { if tag > PRINTABLE_STR_THRESHOLD { d.dn.push(( tag - PRINTABLE_STR_THRESHOLD, DistNameValue::PrintableStr(value), )) - .map_err(|_| Error::BufferTooSmall)?; + .map_err(|_| ErrorCode::BufferTooSmall)?; } else { d.dn.push((tag, DistNameValue::Utf8Str(value))) - .map_err(|_| Error::BufferTooSmall)?; + .map_err(|_| ErrorCode::BufferTooSmall)?; } } } @@ -531,7 +531,7 @@ fn encode_dn_value( } _ => { error!("Invalid encoding"); - return Err(Error::Invalid); + Err(ErrorCode::Invalid)? } }, DistNameValue::Utf8Str(v) => { @@ -570,7 +570,9 @@ impl<'a> Cert<'a> { } pub fn get_node_id(&self) -> Result { - self.subject.u64(DnTags::NodeId).ok_or(Error::NoNodeId) + self.subject + .u64(DnTags::NodeId) + .ok_or_else(|| Error::from(ErrorCode::NoNodeId)) } pub fn get_cat_ids(&self, output: &mut [u32]) { @@ -578,7 +580,9 @@ impl<'a> Cert<'a> { } pub fn get_fabric_id(&self) -> Result { - self.subject.u64(DnTags::FabricId).ok_or(Error::NoFabricId) + self.subject + .u64(DnTags::FabricId) + .ok_or_else(|| Error::from(ErrorCode::NoFabricId)) } pub fn get_pubkey(&self) -> &[u8] { @@ -589,7 +593,7 @@ impl<'a> Cert<'a> { if let Some(id) = self.extensions.subj_key_id.as_ref() { Ok(id.0) } else { - Err(Error::Invalid) + Err(ErrorCode::Invalid.into()) } } @@ -641,7 +645,7 @@ impl<'a> Cert<'a> { w.integer("Serial Num:", self.serial_no.0)?; w.start_seq("Signature Algorithm:")?; - let (str, oid) = match get_sign_algo(self.sign_algo).ok_or(Error::Invalid)? { + let (str, oid) = match get_sign_algo(self.sign_algo).ok_or(ErrorCode::Invalid)? { SignAlgoValue::ECDSAWithSHA256 => ("ECDSA with SHA256", OID_ECDSA_WITH_SHA256), }; w.oid(str, &oid)?; @@ -660,11 +664,11 @@ impl<'a> Cert<'a> { w.start_seq("")?; w.start_seq("Public Key Algorithm")?; - let (str, pub_key) = match get_pubkey_algo(self.pubkey_algo).ok_or(Error::Invalid)? { + let (str, pub_key) = match get_pubkey_algo(self.pubkey_algo).ok_or(ErrorCode::Invalid)? { PubKeyAlgoValue::EcPubKey => ("ECPubKey", OID_PUB_KEY_ECPUBKEY), }; w.oid(str, &pub_key)?; - let (str, curve_id) = match get_ec_curve_id(self.ec_curve_id).ok_or(Error::Invalid)? { + let (str, curve_id) = match get_ec_curve_id(self.ec_curve_id).ok_or(ErrorCode::Invalid)? { EcCurveIdValue::Prime256V1 => ("Prime256v1", OID_EC_TYPE_PRIME256V1), }; w.oid(str, &curve_id)?; @@ -704,7 +708,7 @@ impl<'a> CertVerifier<'a> { pub fn add_cert(self, parent: &'a Cert) -> Result, Error> { if !self.cert.is_authority(parent)? { - return Err(Error::InvalidAuthKey); + Err(ErrorCode::InvalidAuthKey)?; } let mut asn1 = [0u8; MAX_ASN1_CERT_SIZE]; let len = self.cert.as_asn1(&mut asn1, self.utc_calendar)?; @@ -761,7 +765,6 @@ mod tests { use log::info; use crate::cert::Cert; - use crate::error::Error; use crate::tlv::{self, FromTLV, TLVWriter, TagType, ToTLV}; use crate::utils::writebuf::WriteBuf; @@ -815,31 +818,43 @@ mod tests { #[test] fn test_verify_chain_incomplete() { // The chain doesn't lead up to a self-signed certificate + + use crate::error::ErrorCode; let noc = Cert::new(&test_vectors::NOC1_SUCCESS).unwrap(); let icac = Cert::new(&test_vectors::ICAC1_SUCCESS).unwrap(); let a = noc.verify_chain_start(crate::utils::epoch::sys_utc_calendar); assert_eq!( - Err(Error::InvalidAuthKey), - a.add_cert(&icac).unwrap().finalise() + Err(ErrorCode::InvalidAuthKey), + a.add_cert(&icac).unwrap().finalise().map_err(|e| e.code()) ); } #[cfg(feature = "std")] #[test] fn test_auth_key_chain_incorrect() { + use crate::error::ErrorCode; + let noc = Cert::new(&test_vectors::NOC1_AUTH_KEY_FAIL).unwrap(); let icac = Cert::new(&test_vectors::ICAC1_SUCCESS).unwrap(); let a = noc.verify_chain_start(crate::utils::epoch::sys_utc_calendar); - assert_eq!(Err(Error::InvalidAuthKey), a.add_cert(&icac).map(|_| ())); + assert_eq!( + Err(ErrorCode::InvalidAuthKey), + a.add_cert(&icac).map(|_| ()).map_err(|e| e.code()) + ); } #[cfg(feature = "std")] #[test] fn test_cert_corrupted() { + use crate::error::ErrorCode; + let noc = Cert::new(&test_vectors::NOC1_CORRUPT_CERT).unwrap(); let icac = Cert::new(&test_vectors::ICAC1_SUCCESS).unwrap(); let a = noc.verify_chain_start(crate::utils::epoch::sys_utc_calendar); - assert_eq!(Err(Error::InvalidSignature), a.add_cert(&icac).map(|_| ())); + assert_eq!( + Err(ErrorCode::InvalidSignature), + a.add_cert(&icac).map(|_| ()).map_err(|e| e.code()) + ); } #[test] diff --git a/matter/src/codec/base38.rs b/matter/src/codec/base38.rs index 14114e6..954f862 100644 --- a/matter/src/codec/base38.rs +++ b/matter/src/codec/base38.rs @@ -17,7 +17,7 @@ //! Base38 encoding and decoding functions. -use crate::error::Error; +use crate::error::{Error, ErrorCode}; const BASE38_CHARS: [char; 38] = [ '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', @@ -86,7 +86,7 @@ const RADIX: u32 = BASE38_CHARS.len() as u32; pub fn encode_string(bytes: &[u8]) -> Result, Error> { let mut string = heapless::String::new(); for c in encode(bytes) { - string.push(c).map_err(|_| Error::NoSpace)?; + string.push(c).map_err(|_| ErrorCode::NoSpace)?; } Ok(string) @@ -135,7 +135,7 @@ pub fn decode_vec(base38_str: &str) -> Result impl Iterator> { match decode_char(*c) { Ok(v) => value = value * RADIX + v as u32, Err(err) => { - cerr = Some(err); + cerr = Some(err.code()); break; } } } } else { - cerr = Some(Error::InvalidData) + cerr = Some(ErrorCode::InvalidData) } (0..repeat) .map(move |_| { if let Some(err) = cerr { - Err(err) + Err(err.into()) } else { let byte = (value & 0xff) as u8; @@ -205,12 +205,12 @@ fn decode_base38(chars: &[u8]) -> impl Iterator> { fn decode_char(c: u8) -> Result { if !(45..=90).contains(&c) { - return Err(Error::InvalidData); + Err(ErrorCode::InvalidData)?; } let c = DECODE_BASE38[c as usize - 45]; if c == UNUSED { - return Err(Error::InvalidData); + Err(ErrorCode::InvalidData)?; } Ok(c) diff --git a/matter/src/crypto/crypto_dummy.rs b/matter/src/crypto/crypto_dummy.rs index acdae09..f00cefd 100644 --- a/matter/src/crypto/crypto_dummy.rs +++ b/matter/src/crypto/crypto_dummy.rs @@ -17,7 +17,7 @@ use log::error; -use crate::error::Error; +use crate::error::{Error, ErrorCode}; pub fn hkdf_sha256(_salt: &[u8], _ikm: &[u8], _info: &[u8], _key: &mut [u8]) -> Result<(), Error> { error!("This API should never get called"); @@ -79,7 +79,7 @@ impl KeyPair { pub fn get_csr<'a>(&self, _out_csr: &'a mut [u8]) -> Result<&'a [u8], Error> { error!("This API should never get called"); - Err(Error::Invalid) + Err(ErrorCode::Invalid.into()) } pub fn get_public_key(&self, _pub_key: &mut [u8]) -> Result { @@ -92,17 +92,17 @@ impl KeyPair { pub fn derive_secret(self, _peer_pub_key: &[u8], _secret: &mut [u8]) -> Result { error!("This API should never get called"); - Err(Error::Invalid) + Err(ErrorCode::Invalid.into()) } pub fn sign_msg(&self, _msg: &[u8], _signature: &mut [u8]) -> Result { error!("This API should never get called"); - Err(Error::Invalid) + Err(ErrorCode::Invalid.into()) } pub fn verify_msg(&self, _msg: &[u8], _signature: &[u8]) -> Result<(), Error> { error!("This API should never get called"); - Err(Error::Invalid) + Err(ErrorCode::Invalid.into()) } } diff --git a/matter/src/crypto/crypto_esp_mbedtls.rs b/matter/src/crypto/crypto_esp_mbedtls.rs index 4eee8a7..cad046b 100644 --- a/matter/src/crypto/crypto_esp_mbedtls.rs +++ b/matter/src/crypto/crypto_esp_mbedtls.rs @@ -17,7 +17,7 @@ use log::error; -use crate::error::Error; +use crate::error::{Error, ErrorCode}; pub fn hkdf_sha256(_salt: &[u8], _ikm: &[u8], _info: &[u8], _key: &mut [u8]) -> Result<(), Error> { error!("This API should never get called"); @@ -81,7 +81,7 @@ impl KeyPair { pub fn get_csr<'a>(&self, _out_csr: &'a mut [u8]) -> Result<&'a [u8], Error> { error!("This API should never get called"); - Err(Error::Invalid) + Err(ErrorCode::Invalid.into()) } pub fn get_public_key(&self, _pub_key: &mut [u8]) -> Result { @@ -94,17 +94,17 @@ impl KeyPair { pub fn derive_secret(self, _peer_pub_key: &[u8], _secret: &mut [u8]) -> Result { error!("This API should never get called"); - Err(Error::Invalid) + Err(ErrorCode::Invalid.into()) } pub fn sign_msg(&self, _msg: &[u8], _signature: &mut [u8]) -> Result { error!("This API should never get called"); - Err(Error::Invalid) + Err(ErrorCode::Invalid.into()) } pub fn verify_msg(&self, _msg: &[u8], _signature: &[u8]) -> Result<(), Error> { error!("This API should never get called"); - Err(Error::Invalid) + Err(ErrorCode::Invalid.into()) } } diff --git a/matter/src/crypto/crypto_mbedtls.rs b/matter/src/crypto/crypto_mbedtls.rs index c87e669..3f95d04 100644 --- a/matter/src/crypto/crypto_mbedtls.rs +++ b/matter/src/crypto/crypto_mbedtls.rs @@ -34,7 +34,7 @@ use crate::{ // TODO: We should move ASN1Writer out of Cert, // so Crypto doesn't have to depend on Cert cert::{ASN1Writer, CertConsumer}, - error::Error, + error::{Error, ErrorCode}, }; pub struct HmacSha256 { @@ -49,11 +49,13 @@ impl HmacSha256 { } pub fn update(&mut self, data: &[u8]) -> Result<(), Error> { - self.inner.update(data).map_err(|_| Error::TLSStack) + self.inner + .update(data) + .map_err(|_| ErrorCode::TLSStack.into()) } pub fn finish(self, out: &mut [u8]) -> Result<(), Error> { - self.inner.finish(out).map_err(|_| Error::TLSStack)?; + self.inner.finish(out).map_err(|_| ErrorCode::TLSStack)?; Ok(()) } } @@ -102,11 +104,11 @@ impl KeyPair { Ok(Some(a)) => Ok(a), Ok(None) => { error!("Error in writing CSR: None received"); - Err(Error::Invalid) + Err(ErrorCode::Invalid.into()) } Err(e) => { error!("Error in writing CSR {}", e); - Err(Error::TLSStack) + Err(ErrorCode::TLSStack.into()) } } } @@ -161,7 +163,7 @@ impl KeyPair { let mut ctr_drbg = CtrDrbg::new(Arc::new(OsEntropy::new()), None)?; if signature.len() < super::EC_SIGNATURE_LEN_BYTES { - return Err(Error::NoSpace); + Err(ErrorCode::NoSpace)?; } safemem::write_bytes(signature, 0); @@ -192,7 +194,7 @@ impl KeyPair { if let Err(e) = tmp_key.verify(hash::Type::Sha256, &msg_hash, mbedtls_sign) { info!("The error is {}", e); - Err(Error::InvalidSignature) + Err(ErrorCode::InvalidSignature.into()) } else { Ok(()) } @@ -229,7 +231,7 @@ fn convert_asn1_sign_to_r_s(signature: &mut [u8]) -> Result { // Type 0x2 is Integer (first integer is r) if signature[offset] != 2 { - return Err(Error::Invalid); + Err(ErrorCode::Invalid)?; } offset += 1; @@ -254,7 +256,7 @@ fn convert_asn1_sign_to_r_s(signature: &mut [u8]) -> Result { // Type 0x2 is Integer (this integer is s) if signature[offset] != 2 { - return Err(Error::Invalid); + Err(ErrorCode::Invalid)?; } offset += 1; @@ -273,17 +275,17 @@ fn convert_asn1_sign_to_r_s(signature: &mut [u8]) -> Result { Ok(64) } else { - Err(Error::Invalid) + Err(ErrorCode::Invalid.into()) } } pub fn pbkdf2_hmac(pass: &[u8], iter: usize, salt: &[u8], key: &mut [u8]) -> Result<(), Error> { mbedtls::hash::pbkdf2_hmac(Type::Sha256, pass, salt, iter as u32, key) - .map_err(|_e| Error::TLSStack) + .map_err(|_e| ErrorCode::TLSStack.into()) } pub fn hkdf_sha256(salt: &[u8], ikm: &[u8], info: &[u8], key: &mut [u8]) -> Result<(), Error> { - Hkdf::hkdf(Type::Sha256, salt, ikm, info, key).map_err(|_e| Error::TLSStack) + Hkdf::hkdf(Type::Sha256, salt, ikm, info, key).map_err(|_e| ErrorCode::TLSStack.into()) } pub fn encrypt_in_place( @@ -304,7 +306,7 @@ pub fn encrypt_in_place( cipher .encrypt_auth_inplace(ad, data, tag) .map(|(len, _)| len) - .map_err(|_e| Error::TLSStack) + .map_err(|_e| ErrorCode::TLSStack.into()) } pub fn decrypt_in_place( @@ -326,7 +328,7 @@ pub fn decrypt_in_place( .map(|(len, _)| len) .map_err(|e| { error!("Error during decryption: {:?}", e); - Error::TLSStack + ErrorCode::TLSStack.into() }) } @@ -343,12 +345,12 @@ impl Sha256 { } pub fn update(&mut self, data: &[u8]) -> Result<(), Error> { - self.ctx.update(data).map_err(|_| Error::TLSStack)?; + self.ctx.update(data).map_err(|_| ErrorCode::TLSStack)?; Ok(()) } pub fn finish(self, digest: &mut [u8]) -> Result<(), Error> { - self.ctx.finish(digest).map_err(|_| Error::TLSStack)?; + self.ctx.finish(digest).map_err(|_| ErrorCode::TLSStack)?; Ok(()) } } diff --git a/matter/src/crypto/crypto_openssl.rs b/matter/src/crypto/crypto_openssl.rs index e448619..5343c52 100644 --- a/matter/src/crypto/crypto_openssl.rs +++ b/matter/src/crypto/crypto_openssl.rs @@ -15,7 +15,7 @@ * limitations under the License. */ -use crate::error::Error; +use crate::error::{Error, ErrorCode}; use foreign_types::ForeignTypeRef; use log::error; @@ -46,7 +46,8 @@ pub struct HmacSha256 { impl HmacSha256 { pub fn new(key: &[u8]) -> Result { Ok(Self { - ctx: Hmac::::new_from_slice(key).map_err(|_x| Error::InvalidKeyLength)?, + ctx: Hmac::::new_from_slice(key) + .map_err(|_x| ErrorCode::InvalidKeyLength)?, }) } @@ -107,7 +108,7 @@ impl KeyPair { fn private_key(&self) -> Result<&EcKey, Error> { match &self.key { - KeyType::Public(_) => Err(Error::Invalid), + KeyType::Public(_) => Err(ErrorCode::Invalid.into()), KeyType::Private(k) => Ok(&k), } } @@ -167,7 +168,7 @@ impl KeyPair { a.copy_from_slice(csr); Ok(a) } else { - Err(Error::NoSpace) + Err(ErrorCode::NoSpace.into()) } } @@ -178,7 +179,7 @@ impl KeyPair { let msg = h.finish()?; if signature.len() < super::EC_SIGNATURE_LEN_BYTES { - return Err(Error::NoSpace); + Err(ErrorCode::NoSpace)?; } safemem::write_bytes(signature, 0); @@ -205,11 +206,11 @@ impl KeyPair { KeyType::Public(key) => key, _ => { error!("Not yet supported"); - return Err(Error::Invalid); + Err(ErrorCode::Invalid)?; } }; if !sig.verify(&msg, k)? { - Err(Error::InvalidSignature) + Err(ErrorCode::InvalidSignature.into()) } else { Ok(()) } @@ -220,7 +221,7 @@ const P256_KEY_LEN: usize = 256 / 8; pub fn pubkey_from_der<'a>(der: &'a [u8], out_key: &mut [u8]) -> Result<(), Error> { if out_key.len() != P256_KEY_LEN { error!("Insufficient length"); - Err(Error::NoSpace) + Err(ErrorCode::NoSpace.into()) } else { let key = X509::from_der(der)?.public_key()?.public_key_to_der()?; let len = key.len(); @@ -232,7 +233,7 @@ pub fn pubkey_from_der<'a>(der: &'a [u8], out_key: &mut [u8]) -> Result<(), Erro pub fn pbkdf2_hmac(pass: &[u8], iter: usize, salt: &[u8], key: &mut [u8]) -> Result<(), Error> { openssl::pkcs5::pbkdf2_hmac(pass, salt, iter, MessageDigest::sha256(), key) - .map_err(|_e| Error::TLSStack) + .map_err(|_e| ErrorCode::TLSStack.into()) } pub fn hkdf_sha256(salt: &[u8], ikm: &[u8], info: &[u8], key: &mut [u8]) -> Result<(), Error> { @@ -372,7 +373,9 @@ impl Sha256 { } pub fn update(&mut self, data: &[u8]) -> Result<(), Error> { - self.hasher.update(data).map_err(|_| Error::TLSStack) + self.hasher + .update(data) + .map_err(|_| ErrorCode::TLSStack.into()) } pub fn finish(mut self, data: &mut [u8]) -> Result<(), Error> { diff --git a/matter/src/crypto/crypto_rustcrypto.rs b/matter/src/crypto/crypto_rustcrypto.rs index f64cbc4..b9aa310 100644 --- a/matter/src/crypto/crypto_rustcrypto.rs +++ b/matter/src/crypto/crypto_rustcrypto.rs @@ -39,7 +39,7 @@ use x509_cert::{ spki::{AlgorithmIdentifier, SubjectPublicKeyInfoOwned}, }; -use crate::error::Error; +use crate::error::{Error, ErrorCode}; use super::CryptoKeyPair; @@ -79,7 +79,7 @@ impl HmacSha256 { Ok(Self { inner: HmacSha256I::new_from_slice(key).map_err(|e| { error!("Error creating HmacSha256 {:?}", e); - Error::TLSStack + ErrorCode::TLSStack })?, }) } @@ -143,7 +143,7 @@ impl KeyPair { fn private_key(&self) -> Result<&SecretKey, Error> { match &self.key { KeyType::Private(key) => Ok(key), - KeyType::Public(_) => Err(Error::Crypto), + KeyType::Public(_) => Err(ErrorCode::Crypto.into()), } } } @@ -158,7 +158,7 @@ impl CryptoKeyPair for KeyPair { priv_key[..slice.len()].copy_from_slice(slice); Ok(len) } - KeyType::Public(_) => Err(Error::Crypto), + KeyType::Public(_) => Err(ErrorCode::Crypto.into()), } } fn get_csr<'a>(&self, out_csr: &'a mut [u8]) -> Result<&'a [u8], Error> { @@ -251,7 +251,7 @@ impl CryptoKeyPair for KeyPair { use p256::ecdsa::signature::Signer; if signature.len() < super::EC_SIGNATURE_LEN_BYTES { - return Err(Error::NoSpace); + return Err(ErrorCode::NoSpace.into()); } match &self.key { @@ -274,7 +274,7 @@ impl CryptoKeyPair for KeyPair { verifying_key .verify(msg, &signature) - .map_err(|_| Error::InvalidSignature)?; + .map_err(|_| ErrorCode::InvalidSignature)?; Ok(()) } @@ -291,7 +291,7 @@ pub fn hkdf_sha256(salt: &[u8], ikm: &[u8], info: &[u8], key: &mut [u8]) -> Resu .expand(info, key) .map_err(|e| { error!("Error with hkdf_sha256 {:?}", e); - Error::TLSStack + ErrorCode::TLSStack.into() }) } diff --git a/matter/src/crypto/mod.rs b/matter/src/crypto/mod.rs index 3b7b4c4..47d49b7 100644 --- a/matter/src/crypto/mod.rs +++ b/matter/src/crypto/mod.rs @@ -15,7 +15,7 @@ * limitations under the License. */ use crate::{ - error::Error, + error::{Error, ErrorCode}, tlv::{FromTLV, TLVWriter, TagType, ToTLV}, }; @@ -80,12 +80,12 @@ impl<'a> FromTLV<'a> for KeyPair { t.confirm_array()?.enter(); if let Some(mut array) = t.enter() { - let pub_key = array.next().ok_or(Error::Invalid)?.slice()?; - let priv_key = array.next().ok_or(Error::Invalid)?.slice()?; + let pub_key = array.next().ok_or(ErrorCode::Invalid)?.slice()?; + let priv_key = array.next().ok_or(ErrorCode::Invalid)?.slice()?; KeyPair::new_from_components(pub_key, priv_key) } else { - Err(Error::Invalid) + Err(ErrorCode::Invalid.into()) } } } @@ -108,7 +108,7 @@ impl ToTLV for KeyPair { #[cfg(test)] mod tests { - use crate::error::Error; + use crate::error::ErrorCode; use super::KeyPair; @@ -122,8 +122,9 @@ mod tests { fn test_verify_msg_fail() { let key = KeyPair::new_from_public(&test_vectors::PUB_KEY1).unwrap(); assert_eq!( - key.verify_msg(&test_vectors::MSG1_FAIL, &test_vectors::SIGNATURE1), - Err(Error::InvalidSignature) + key.verify_msg(&test_vectors::MSG1_FAIL, &test_vectors::SIGNATURE1) + .map_err(|e| e.code()), + Err(ErrorCode::InvalidSignature) ); } diff --git a/matter/src/data_model/cluster_template.rs b/matter/src/data_model/cluster_template.rs index c103812..1e6adb8 100644 --- a/matter/src/data_model/cluster_template.rs +++ b/matter/src/data_model/cluster_template.rs @@ -17,7 +17,7 @@ use crate::{ data_model::objects::{Cluster, Handler}, - error::Error, + error::{Error, ErrorCode}, utils::rand::Rand, }; @@ -51,7 +51,7 @@ impl TemplateCluster { if attr.is_system() { CLUSTER.read(attr.attr_id, writer) } else { - Err(Error::AttributeNotFound) + Err(ErrorCode::AttributeNotFound.into()) } } else { Ok(()) diff --git a/matter/src/data_model/objects/cluster.rs b/matter/src/data_model/objects/cluster.rs index 3818f93..f9f4c5c 100644 --- a/matter/src/data_model/objects/cluster.rs +++ b/matter/src/data_model/objects/cluster.rs @@ -22,7 +22,7 @@ use crate::{ acl::{AccessReq, Accessor}, attribute_enum, data_model::objects::*, - error::Error, + error::{Error, ErrorCode}, interaction_model::{ core::IMStatusCode, messages::{ @@ -320,7 +320,7 @@ impl<'a> Cluster<'a> { GlobalElements::FeatureMap => writer.set(self.feature_map), other => { error!("This attribute is not yet handled {:?}", other); - Err(Error::AttributeNotFound) + Err(ErrorCode::AttributeNotFound.into()) } } } diff --git a/matter/src/data_model/objects/encoder.rs b/matter/src/data_model/objects/encoder.rs index d068ce7..e97eea0 100644 --- a/matter/src/data_model/objects/encoder.rs +++ b/matter/src/data_model/objects/encoder.rs @@ -26,7 +26,7 @@ use crate::interaction_model::messages::ib::{ use crate::interaction_model::messages::GenericPath; use crate::tlv::UtfStr; use crate::{ - error::Error, + error::{Error, ErrorCode}, interaction_model::messages::ib::{AttrDataTag, AttrRespTag}, tlv::{FromTLV, TLVElement, TLVWriter, TagType, ToTLV}, }; @@ -135,8 +135,13 @@ impl<'a, 'b, 'c> AttrDataEncoder<'a, 'b, 'c> { match handler.read(&attr, encoder) { Ok(()) => None, - Err(Error::NoSpace) => return Ok(Some(attr.path().to_gp())), - Err(error) => attr.status(error.into())?, + Err(e) => { + if e.code() == ErrorCode::NoSpace { + return Ok(Some(attr.path().to_gp())); + } else { + attr.status(e.into())? + } + } } } Err(status) => Some(status), @@ -181,8 +186,13 @@ impl<'a, 'b, 'c> AttrDataEncoder<'a, 'b, 'c> { match handler.read(&attr, encoder).await { Ok(()) => None, - Err(Error::NoSpace) => return Ok(Some(attr.path().to_gp())), - Err(error) => attr.status(error.into())?, + Err(e) => { + if e.code() == ErrorCode::NoSpace { + return Ok(Some(attr.path().to_gp())); + } else { + attr.status(e.into())? + } + } } } Err(status) => Some(status), @@ -321,7 +331,7 @@ impl<'a> AttrData<'a> { pub fn with_dataver(self, dataver: u32) -> Result<&'a TLVElement<'a>, Error> { if let Some(req_dataver) = self.for_dataver { if req_dataver != dataver { - return Err(Error::DataVersionMismatch); + Err(ErrorCode::DataVersionMismatch)?; } } @@ -557,7 +567,8 @@ macro_rules! attribute_enum { type Error = $crate::error::Error; fn try_from(id: $crate::data_model::objects::AttrId) -> Result { - <$en>::from_repr(id).ok_or($crate::error::Error::AttributeNotFound) + <$en>::from_repr(id) + .ok_or_else(|| $crate::error::ErrorCode::AttributeNotFound.into()) } } }; @@ -571,7 +582,7 @@ macro_rules! command_enum { type Error = $crate::error::Error; fn try_from(id: $crate::data_model::objects::CmdId) -> Result { - <$en>::from_repr(id).ok_or($crate::error::Error::CommandNotFound) + <$en>::from_repr(id).ok_or_else(|| $crate::error::ErrorCode::CommandNotFound.into()) } } }; diff --git a/matter/src/data_model/objects/handler.rs b/matter/src/data_model/objects/handler.rs index 7758427..a5e2b9c 100644 --- a/matter/src/data_model/objects/handler.rs +++ b/matter/src/data_model/objects/handler.rs @@ -15,7 +15,11 @@ * limitations under the License. */ -use crate::{error::Error, interaction_model::core::Transaction, tlv::TLVElement}; +use crate::{ + error::{Error, ErrorCode}, + interaction_model::core::Transaction, + tlv::TLVElement, +}; use super::{AttrData, AttrDataEncoder, AttrDetails, CmdDataEncoder, CmdDetails}; @@ -27,7 +31,7 @@ pub trait Handler { fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error>; fn write(&mut self, _attr: &AttrDetails, _data: AttrData) -> Result<(), Error> { - Err(Error::AttributeNotFound) + Err(ErrorCode::AttributeNotFound.into()) } fn invoke( @@ -37,7 +41,7 @@ pub trait Handler { _data: &TLVElement, _encoder: CmdDataEncoder, ) -> Result<(), Error> { - Err(Error::CommandNotFound) + Err(ErrorCode::CommandNotFound.into()) } } @@ -88,7 +92,7 @@ impl EmptyHandler { impl Handler for EmptyHandler { fn read(&self, _attr: &AttrDetails, _encoder: AttrDataEncoder) -> Result<(), Error> { - Err(Error::AttributeNotFound) + Err(ErrorCode::AttributeNotFound.into()) } } @@ -202,7 +206,7 @@ macro_rules! handler_chain_type { pub mod asynch { use crate::{ data_model::objects::{AttrData, AttrDataEncoder, AttrDetails, CmdDataEncoder, CmdDetails}, - error::Error, + error::{Error, ErrorCode}, interaction_model::core::Transaction, tlv::TLVElement, }; @@ -221,7 +225,7 @@ pub mod asynch { _attr: &'a AttrDetails<'_>, _data: AttrData<'a>, ) -> Result<(), Error> { - Err(Error::AttributeNotFound) + Err(ErrorCode::AttributeNotFound.into()) } async fn invoke<'a>( @@ -231,7 +235,7 @@ pub mod asynch { _data: &'a TLVElement<'_>, _encoder: CmdDataEncoder<'a, '_, '_>, ) -> Result<(), Error> { - Err(Error::CommandNotFound) + Err(ErrorCode::CommandNotFound.into()) } } @@ -305,7 +309,7 @@ pub mod asynch { _attr: &'a AttrDetails<'_>, _encoder: AttrDataEncoder<'a, '_, '_>, ) -> Result<(), Error> { - Err(Error::AttributeNotFound) + Err(ErrorCode::AttributeNotFound.into()) } } diff --git a/matter/src/data_model/objects/privilege.rs b/matter/src/data_model/objects/privilege.rs index 6b4e3a5..1032a45 100644 --- a/matter/src/data_model/objects/privilege.rs +++ b/matter/src/data_model/objects/privilege.rs @@ -16,7 +16,7 @@ */ use crate::{ - error::Error, + error::{Error, ErrorCode}, tlv::{FromTLV, TLVElement, ToTLV}, }; use log::error; @@ -47,12 +47,12 @@ impl FromTLV<'_> for Privilege { 1 => Ok(Privilege::VIEW), 2 => { error!("ProxyView privilege not yet supporteds"); - Err(Error::Invalid) + Err(ErrorCode::Invalid.into()) } 3 => Ok(Privilege::OPERATE), 4 => Ok(Privilege::MANAGE), 5 => Ok(Privilege::ADMIN), - _ => Err(Error::Invalid), + _ => Err(ErrorCode::Invalid.into()), } } } diff --git a/matter/src/data_model/sdm/admin_commissioning.rs b/matter/src/data_model/sdm/admin_commissioning.rs index 5497426..b63aa2e 100644 --- a/matter/src/data_model/sdm/admin_commissioning.rs +++ b/matter/src/data_model/sdm/admin_commissioning.rs @@ -144,7 +144,7 @@ impl<'a> AdminCommCluster<'a> { ) -> Result<(), Error> { match cmd.cmd_id.try_into()? { Commands::OpenCommWindow => self.handle_command_opencomm_win(data)?, - _ => Err(Error::CommandNotFound)?, + _ => Err(ErrorCode::CommandNotFound)?, } self.data_ver.changed(); diff --git a/matter/src/data_model/sdm/failsafe.rs b/matter/src/data_model/sdm/failsafe.rs index 5008c9f..301baf9 100644 --- a/matter/src/data_model/sdm/failsafe.rs +++ b/matter/src/data_model/sdm/failsafe.rs @@ -15,7 +15,10 @@ * limitations under the License. */ -use crate::{error::Error, transport::session::SessionMode}; +use crate::{ + error::{Error, ErrorCode}, + transport::session::SessionMode, +}; use log::error; #[derive(PartialEq)] @@ -62,7 +65,7 @@ impl FailSafe { State::Armed(c) => { if c.session_mode != session_mode { error!("Received Fail-Safe Arm with different session modes; current {:?}, incoming {:?}", c.session_mode, session_mode); - return Err(Error::Invalid); + Err(ErrorCode::Invalid)?; } // re-arm c.timeout = timeout; @@ -75,22 +78,22 @@ impl FailSafe { match &mut self.state { State::Idle => { error!("Received Fail-Safe Disarm without it being armed"); - return Err(Error::Invalid); + Err(ErrorCode::Invalid)?; } State::Armed(c) => { match c.noc_state { - NocState::NocNotRecvd => return Err(Error::Invalid), + NocState::NocNotRecvd => Err(ErrorCode::Invalid)?, NocState::AddNocRecvd(idx) | NocState::UpdateNocRecvd(idx) => { if let SessionMode::Case(c) = session_mode { if c.fab_idx != idx { error!( "Received disarm in separate session from previous Add/Update NOC" ); - return Err(Error::Invalid); + Err(ErrorCode::Invalid)?; } } else { error!("Received disarm in a non-CASE session"); - return Err(Error::Invalid); + Err(ErrorCode::Invalid)?; } } } @@ -106,13 +109,13 @@ impl FailSafe { pub fn record_add_noc(&mut self, fabric_index: u8) -> Result<(), Error> { match &mut self.state { - State::Idle => Err(Error::Invalid), + State::Idle => Err(ErrorCode::Invalid.into()), State::Armed(c) => { if c.noc_state == NocState::NocNotRecvd { c.noc_state = NocState::AddNocRecvd(fabric_index); Ok(()) } else { - Err(Error::Invalid) + Err(ErrorCode::Invalid.into()) } } } diff --git a/matter/src/data_model/sdm/general_commissioning.rs b/matter/src/data_model/sdm/general_commissioning.rs index d4d4329..f2487ef 100644 --- a/matter/src/data_model/sdm/general_commissioning.rs +++ b/matter/src/data_model/sdm/general_commissioning.rs @@ -235,9 +235,9 @@ impl<'a> GenCommCluster<'a> { cmd_enter!("Set Regulatory Config"); let country_code = data .find_tag(1) - .map_err(|_| Error::InvalidCommand)? + .map_err(|_| ErrorCode::InvalidCommand)? .slice() - .map_err(|_| Error::InvalidCommand)?; + .map_err(|_| ErrorCode::InvalidCommand)?; info!("Received country code: {:?}", country_code); let cmd_data = CommonResponse { diff --git a/matter/src/data_model/sdm/noc.rs b/matter/src/data_model/sdm/noc.rs index 634ba85..6182c0c 100644 --- a/matter/src/data_model/sdm/noc.rs +++ b/matter/src/data_model/sdm/noc.rs @@ -277,7 +277,7 @@ impl<'a> NocCluster<'a> { } _ => { error!("Attribute not supported: this shouldn't happen"); - Err(Error::AttributeNotFound) + Err(ErrorCode::AttributeNotFound.into()) } } } @@ -563,7 +563,7 @@ impl<'a> NocCluster<'a> { info!("Received CSR Nonce:{:?}", req.str); if !self.failsafe.borrow().is_armed() { - return Err(Error::UnsupportedAccess); + Err(ErrorCode::UnsupportedAccess)?; } let noc_keypair = KeyPair::new()?; @@ -602,7 +602,7 @@ impl<'a> NocCluster<'a> { ) -> Result<(), Error> { cmd_enter!("AddTrustedRootCert"); if !self.failsafe.borrow().is_armed() { - return Err(Error::UnsupportedAccess); + Err(ErrorCode::UnsupportedAccess)?; } // This may happen on CASE or PASE. For PASE, the existence of NOC Data is necessary @@ -612,13 +612,13 @@ impl<'a> NocCluster<'a> { let noc_data = transaction .session_mut() .get_noc_data::() - .ok_or(Error::NoSession)?; + .ok_or(ErrorCode::NoSession)?; let req = CommonReq::from_tlv(data).map_err(Error::map_invalid_command)?; info!("Received Trusted Cert:{:x?}", req.str); noc_data.root_ca = - heapless::Vec::from_slice(req.str.0).map_err(|_| Error::BufferTooSmall)?; + heapless::Vec::from_slice(req.str.0).map_err(|_| ErrorCode::BufferTooSmall)?; // TODO } _ => (), @@ -720,6 +720,6 @@ fn get_certchainrequest_params(data: &TLVElement) -> Result { match cert_type { CERT_TYPE_DAC => Ok(dev_att::DataType::DAC), CERT_TYPE_PAI => Ok(dev_att::DataType::PAI), - _ => Err(Error::Invalid), + _ => Err(ErrorCode::Invalid.into()), } } diff --git a/matter/src/data_model/sdm/nw_commissioning.rs b/matter/src/data_model/sdm/nw_commissioning.rs index 47ffe6e..5abf809 100644 --- a/matter/src/data_model/sdm/nw_commissioning.rs +++ b/matter/src/data_model/sdm/nw_commissioning.rs @@ -20,7 +20,7 @@ use crate::{ AttrDataEncoder, AttrDetails, ChangeNotifier, Cluster, Dataver, Handler, NonBlockingHandler, ATTRIBUTE_LIST, FEATURE_MAP, }, - error::Error, + error::{Error, ErrorCode}, utils::rand::Rand, }; @@ -57,7 +57,7 @@ impl Handler for NwCommCluster { if attr.is_system() { CLUSTER.read(attr.attr_id, writer) } else { - Err(Error::AttributeNotFound) + Err(ErrorCode::AttributeNotFound.into()) } } else { Ok(()) diff --git a/matter/src/data_model/system_model/access_control.rs b/matter/src/data_model/system_model/access_control.rs index ffba5e6..c57c0df 100644 --- a/matter/src/data_model/system_model/access_control.rs +++ b/matter/src/data_model/system_model/access_control.rs @@ -141,7 +141,7 @@ impl<'a> AccessControlCluster<'a> { } _ => { error!("Attribute not yet supported: this shouldn't happen"); - Err(Error::AttributeNotFound) + Err(ErrorCode::AttributeNotFound.into()) } } } @@ -229,7 +229,7 @@ mod tests { // Test, ACL has fabric index 2, but the accessing fabric is 1 // the fabric index in the TLV should be ignored and the ACL should be created with entry 1 let result = acl.write_acl_attr(&ListOperation::AddItem, &data, 1); - assert_eq!(result, Ok(())); + assert!(result.is_ok()); let verifier = AclEntry::new(1, Privilege::VIEW, AuthMode::Case); acl_mgr @@ -268,7 +268,7 @@ mod tests { let result = acl.write_acl_attr(&ListOperation::EditItem(1), &data, 2); // Fabric 2's index 1, is actually our index 2, update the verifier verifier[2] = new; - assert_eq!(result, Ok(())); + assert!(result.is_ok()); // Also validate in the acl_mgr that the entries are in the right order let mut index = 0; @@ -301,7 +301,7 @@ mod tests { // Test , Delete Fabric 1's index 0 let result = acl.write_acl_attr(&ListOperation::DeleteItem(0), &data, 1); - assert_eq!(result, Ok(())); + assert!(result.is_ok()); let verifier = [input[0], input[2]]; // Also validate in the acl_mgr that the entries are in the right order diff --git a/matter/src/error.rs b/matter/src/error.rs index d2053d1..507ce4b 100644 --- a/matter/src/error.rs +++ b/matter/src/error.rs @@ -17,8 +17,8 @@ use core::{array::TryFromSliceError, fmt, str::Utf8Error}; -#[derive(Debug, PartialEq, Clone, Copy)] -pub enum Error { +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub enum ErrorCode { AttributeNotFound, AttributeIsCustom, BufferTooSmall, @@ -73,7 +73,36 @@ pub enum Error { Utf8Fail, } +impl From for Error { + fn from(code: ErrorCode) -> Self { + Self::new(code) + } +} + +pub struct Error { + code: ErrorCode, + #[cfg(all(feature = "std", feature = "backtrace"))] + backtrace: std::backtrace::Backtrace, +} + impl Error { + pub fn new(code: ErrorCode) -> Self { + Self { + code, + #[cfg(all(feature = "std", feature = "backtrace"))] + backtrace: std::backtrace::Backtrace::capture(), + } + } + + pub const fn code(&self) -> ErrorCode { + self.code + } + + #[cfg(all(feature = "std", feature = "backtrace"))] + pub const fn backtrace(&self) -> &std::backtrace::Backtrace { + &self.backtrace + } + pub fn remap(self, matcher: F, to: Self) -> Self where F: FnOnce(&Self) -> bool, @@ -86,19 +115,22 @@ impl Error { } pub fn map_invalid(self, to: Self) -> Self { - self.remap(|e| matches!(e, Self::Invalid | Self::InvalidData), to) + self.remap( + |e| matches!(e.code(), ErrorCode::Invalid | ErrorCode::InvalidData), + to, + ) } pub fn map_invalid_command(self) -> Self { - self.map_invalid(Error::InvalidCommand) + self.map_invalid(Error::new(ErrorCode::InvalidCommand)) } pub fn map_invalid_action(self) -> Self { - self.map_invalid(Error::InvalidAction) + self.map_invalid(Error::new(ErrorCode::InvalidAction)) } pub fn map_invalid_data_type(self) -> Self { - self.map_invalid(Error::InvalidDataType) + self.map_invalid(Error::new(ErrorCode::InvalidDataType)) } } @@ -106,14 +138,14 @@ impl Error { impl From for Error { fn from(_e: std::io::Error) -> Self { // Keep things simple for now - Self::StdIoError + Self::new(ErrorCode::StdIoError) } } #[cfg(feature = "std")] impl From> for Error { fn from(_e: std::sync::PoisonError) -> Self { - Self::RwLock + Self::new(ErrorCode::RwLock) } } @@ -121,7 +153,7 @@ impl From> for Error { impl From for Error { fn from(e: openssl::error::ErrorStack) -> Self { ::log::error!("Error in TLS: {}", e); - Self::TLSStack + Self::new(ErrorCode::TLSStack) } } @@ -129,39 +161,57 @@ impl From for Error { impl From for Error { fn from(e: mbedtls::Error) -> Self { ::log::error!("Error in TLS: {}", e); - Self::TLSStack + Self::new(ErrorCode::TLSStack) } } #[cfg(feature = "crypto_rustcrypto")] impl From for Error { fn from(_e: ccm::aead::Error) -> Self { - Self::Crypto + Self::new(ErrorCode::Crypto) } } #[cfg(feature = "std")] impl From for Error { fn from(_e: std::time::SystemTimeError) -> Self { - Self::SysTimeFail + Error::new(ErrorCode::SysTimeFail) } } impl From for Error { fn from(_e: TryFromSliceError) -> Self { - Self::Invalid + Self::new(ErrorCode::Invalid) } } impl From for Error { fn from(_e: Utf8Error) -> Self { - Self::Utf8Fail + Self::new(ErrorCode::Utf8Fail) + } +} + +impl fmt::Debug for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + #[cfg(not(all(feature = "std", feature = "backtrace")))] + { + write!(f, "Error::{}", self)?; + } + + #[cfg(all(feature = "std", feature = "backtrace"))] + { + write!(f, "Error::{} {{\n", self)?; + write!(f, "{}", self.backtrace())?; + write!(f, "}}\n")?; + } + + Ok(()) } } impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{:?}", self) + write!(f, "{:?}", self.code()) } } diff --git a/matter/src/fabric.rs b/matter/src/fabric.rs index 6f3ff0e..5658d4c 100644 --- a/matter/src/fabric.rs +++ b/matter/src/fabric.rs @@ -24,10 +24,10 @@ use log::info; use crate::{ cert::{Cert, MAX_CERT_TLV_LEN}, crypto::{self, hkdf_sha256, HmacSha256, KeyPair}, - error::Error, + error::{Error, ErrorCode}, group_keys::KeySet, mdns::{MdnsMgr, ServiceMode}, - tlv::{FromTLV, OctetStr, TLVElement, TLVList, TLVWriter, TagType, ToTLV, UtfStr}, + tlv::{FromTLV, OctetStr, TLVList, TLVWriter, TagType, ToTLV, UtfStr}, utils::writebuf::WriteBuf, }; @@ -123,7 +123,7 @@ impl Fabric { 0x69, 0x63, ]; hkdf_sha256(&fabric_id_be, root_pubkey, &COMPRESSED_FABRIC_ID_INFO, out) - .map_err(|_| Error::NoSpace) + .map_err(|_| Error::from(ErrorCode::NoSpace)) } pub fn match_dest_id(&self, random: &[u8], target: &[u8]) -> Result<(), Error> { @@ -144,7 +144,7 @@ impl Fabric { if id.as_slice() == target { Ok(()) } else { - Err(Error::NotFound) + Err(ErrorCode::NotFound.into()) } } @@ -208,7 +208,7 @@ impl FabricMgr { } } - let root = TLVList::new(data).iter().next().ok_or(Error::Invalid)?; + let root = TLVList::new(data).iter().next().ok_or(ErrorCode::Invalid)?; self.fabrics = FabricEntries::from_tlv(&root)?; @@ -227,6 +227,7 @@ impl FabricMgr { if self.changed { let mut wb = WriteBuf::new(buf); let mut tw = TLVWriter::new(&mut wb); + self.fabrics.to_tlv(&mut tw, TagType::Anonymous)?; self.changed = false; @@ -252,7 +253,7 @@ impl FabricMgr { } } - Err(Error::NoSpace) + Err(ErrorCode::NoSpace.into()) } pub fn remove(&mut self, fab_idx: u8, mdns_mgr: &mut MdnsMgr) -> Result<(), Error> { @@ -262,10 +263,10 @@ impl FabricMgr { self.changed = true; Ok(()) } else { - Err(Error::NotFound) + Err(ErrorCode::NotFound.into()) } } else { - Err(Error::NotFound) + Err(ErrorCode::NotFound.into()) } } @@ -277,7 +278,7 @@ impl FabricMgr { } } } - Err(Error::NotFound) + Err(ErrorCode::NotFound.into()) } pub fn get_fabric(&self, idx: usize) -> Result, Error> { @@ -317,7 +318,7 @@ impl FabricMgr { .filter_map(|f| f.as_ref()) .any(|f| f.label == label) { - return Err(Error::Invalid); + return Err(ErrorCode::Invalid.into()); } } diff --git a/matter/src/group_keys.rs b/matter/src/group_keys.rs index d4e9765..7b584e1 100644 --- a/matter/src/group_keys.rs +++ b/matter/src/group_keys.rs @@ -17,8 +17,8 @@ use crate::{ crypto::{self, SYMM_KEY_LEN_BYTES}, - error::Error, - tlv::{FromTLV, TLVElement, TLVWriter, TagType, ToTLV}, + error::{Error, ErrorCode}, + tlv::{FromTLV, TLVWriter, TagType, ToTLV}, }; type KeySetKey = [u8; SYMM_KEY_LEN_BYTES]; @@ -42,7 +42,8 @@ impl KeySet { 0x47, 0x72, 0x6f, 0x75, 0x70, 0x4b, 0x65, 0x79, 0x20, 0x76, 0x31, 0x2e, 0x30, ]; - crypto::hkdf_sha256(compressed_id, ipk, &GRP_KEY_INFO, opkey).map_err(|_| Error::NoSpace) + crypto::hkdf_sha256(compressed_id, ipk, &GRP_KEY_INFO, opkey) + .map_err(|_| ErrorCode::NoSpace.into()) } pub fn op_key(&self) -> &[u8] { diff --git a/matter/src/interaction_model/core.rs b/matter/src/interaction_model/core.rs index 2ce9c82..9e29bac 100644 --- a/matter/src/interaction_model/core.rs +++ b/matter/src/interaction_model/core.rs @@ -78,27 +78,33 @@ pub enum IMStatusCode { FailSafeRequired = 0xca, } -impl From for IMStatusCode { - fn from(e: Error) -> Self { +impl From for IMStatusCode { + fn from(e: ErrorCode) -> Self { match e { - Error::EndpointNotFound => IMStatusCode::UnsupportedEndpoint, - Error::ClusterNotFound => IMStatusCode::UnsupportedCluster, - Error::AttributeNotFound => IMStatusCode::UnsupportedAttribute, - Error::CommandNotFound => IMStatusCode::UnsupportedCommand, - Error::InvalidAction => IMStatusCode::InvalidAction, - Error::InvalidCommand => IMStatusCode::InvalidCommand, - Error::UnsupportedAccess => IMStatusCode::UnsupportedAccess, - Error::Busy => IMStatusCode::Busy, - Error::DataVersionMismatch => IMStatusCode::DataVersionMismatch, - Error::ResourceExhausted => IMStatusCode::ResourceExhausted, + ErrorCode::EndpointNotFound => IMStatusCode::UnsupportedEndpoint, + ErrorCode::ClusterNotFound => IMStatusCode::UnsupportedCluster, + ErrorCode::AttributeNotFound => IMStatusCode::UnsupportedAttribute, + ErrorCode::CommandNotFound => IMStatusCode::UnsupportedCommand, + ErrorCode::InvalidAction => IMStatusCode::InvalidAction, + ErrorCode::InvalidCommand => IMStatusCode::InvalidCommand, + ErrorCode::UnsupportedAccess => IMStatusCode::UnsupportedAccess, + ErrorCode::Busy => IMStatusCode::Busy, + ErrorCode::DataVersionMismatch => IMStatusCode::DataVersionMismatch, + ErrorCode::ResourceExhausted => IMStatusCode::ResourceExhausted, _ => IMStatusCode::Failure, } } } +impl From for IMStatusCode { + fn from(value: Error) -> Self { + Self::from(value.code()) + } +} + impl FromTLV<'_> for IMStatusCode { fn from_tlv(t: &TLVElement) -> Result { - num::FromPrimitive::from_u16(t.u16()?).ok_or(Error::Invalid) + num::FromPrimitive::from_u16(t.u16()?).ok_or_else(|| ErrorCode::Invalid.into()) } } @@ -223,7 +229,7 @@ pub enum Interaction<'a> { impl<'a> Interaction<'a> { fn new(rx: &'a Packet, transaction: &mut Transaction) -> Result, Error> { let opcode: OpCode = - num::FromPrimitive::from_u8(rx.get_proto_opcode()).ok_or(Error::Invalid)?; + num::FromPrimitive::from_u8(rx.get_proto_opcode()).ok_or(ErrorCode::Invalid)?; let rx_data = rx.as_slice(); @@ -264,7 +270,7 @@ impl<'a> Interaction<'a> { )?))), _ => { error!("Opcode not handled: {:?}", opcode); - Err(Error::InvalidOpcode) + Err(ErrorCode::InvalidOpcode.into()) } } } diff --git a/matter/src/interaction_model/messages.rs b/matter/src/interaction_model/messages.rs index 19e29f1..edf65db 100644 --- a/matter/src/interaction_model/messages.rs +++ b/matter/src/interaction_model/messages.rs @@ -17,8 +17,8 @@ use crate::{ data_model::objects::{ClusterId, EndptId}, - error::Error, - tlv::{FromTLV, TLVElement, TLVWriter, TagType, ToTLV}, + error::{Error, ErrorCode}, + tlv::{FromTLV, TLVWriter, TagType, ToTLV}, }; // A generic path with endpoint, clusters, and a leaf @@ -48,7 +48,7 @@ impl GenericPath { cluster: Some(c), leaf: Some(l), } => Ok((e, c, l)), - _ => Err(Error::Invalid), + _ => Err(ErrorCode::Invalid.into()), } } /// Returns true, if the path is wildcard @@ -69,7 +69,7 @@ pub mod msg { use crate::{ error::Error, interaction_model::core::IMStatusCode, - tlv::{FromTLV, TLVArray, TLVElement, TLVWriter, TagType, ToTLV}, + tlv::{FromTLV, TLVArray, TLVWriter, TagType, ToTLV}, }; use super::ib::{ @@ -259,7 +259,7 @@ pub mod ib { use crate::{ data_model::objects::{AttrDetails, AttrId, ClusterId, CmdId, EncodeValue, EndptId}, - error::Error, + error::{Error, ErrorCode}, interaction_model::core::IMStatusCode, tlv::{FromTLV, Nullable, TLVElement, TLVWriter, TagType, ToTLV}, }; @@ -447,7 +447,7 @@ pub mod ib { f(ListOperation::DeleteList, data)?; // Now the data must be a list, that should be added item by item - let container = data.enter().ok_or(Error::Invalid)?; + let container = data.enter().ok_or(ErrorCode::Invalid)?; for d in container { f(ListOperation::AddItem, &d)?; } @@ -544,7 +544,7 @@ pub mod ib { if c.path.leaf.is_none() { error!("Wildcard command parameter not supported"); - Err(Error::CommandNotFound) + Err(ErrorCode::CommandNotFound.into()) } else { Ok(c) } diff --git a/matter/src/mdns.rs b/matter/src/mdns.rs index 1b29618..eb19a9f 100644 --- a/matter/src/mdns.rs +++ b/matter/src/mdns.rs @@ -217,7 +217,7 @@ pub mod astro { use std::collections::HashMap; use super::Mdns; - use crate::error::Error; + use crate::error::{Error, ErrorCode}; use astro_dnssd::{DNSServiceBuilder, RegisteredDnsService}; use log::info; @@ -269,7 +269,7 @@ pub mod astro { builder = builder.with_key_value(kvs.0.to_string(), kvs.1.to_string()); } - let svc = builder.register().map_err(|_| Error::MdnsError)?; + let svc = builder.register().map_err(|_| ErrorCode::MdnsError)?; self.services.insert( ServiceId { @@ -348,7 +348,7 @@ pub mod astro { // use std::collections::HashMap; // use super::Mdns; -// use crate::error::Error; +// use crate::error::{Error, ErrorCode}; // use log::info; // use zeroconf::prelude::*; // use zeroconf::{MdnsService, ServiceType, TxtRecord}; @@ -402,7 +402,7 @@ pub mod astro { // svc.set_txt_record(txt); -// //let event_loop = svc.register().map_err(|_| Error::MdnsError)?; +// //let event_loop = svc.register().map_err(|_| ErrorCode::MdnsError)?; // self.services.insert( // ServiceId { @@ -604,7 +604,7 @@ pub mod libmdns { // pub mod simplemdns { // use std::net::Ipv4Addr; -// use crate::error::Error; +// use crate::error::{Error, ErrorCode}; // use super::Mdns; // use log::info; // use simple_dns::{ diff --git a/matter/src/pairing/qr.rs b/matter/src/pairing/qr.rs index e99d909..aa26481 100644 --- a/matter/src/pairing/qr.rs +++ b/matter/src/pairing/qr.rs @@ -16,6 +16,7 @@ */ use crate::{ + error::ErrorCode, tlv::{TLVWriter, TagType}, utils::writebuf::WriteBuf, }; @@ -134,7 +135,7 @@ impl<'data> QrSetupPayload<'data> { if is_vendor_tag(tag) { self.add_optional_data(tag, data) } else { - Err(Error::InvalidArgument) + Err(ErrorCode::InvalidArgument.into()) } } @@ -150,7 +151,7 @@ impl<'data> QrSetupPayload<'data> { if is_common_tag(tag) { self.add_optional_data(tag, data) } else { - Err(Error::InvalidArgument) + Err(ErrorCode::InvalidArgument.into()) } } @@ -163,7 +164,7 @@ impl<'data> QrSetupPayload<'data> { } else { self.optional_data.push(item) } - .map_err(|_| Error::NoSpace) + .map_err(|_| ErrorCode::NoSpace.into()) } pub fn get_all_optional_data(&self) -> &[OptionalQRCodeInfo] { @@ -267,7 +268,7 @@ pub(super) fn payload_base38_representation( payload_base38_representation_with_tlv(payload, bits_buf, tlv_buf) } else { - Err(Error::InvalidArgument) + Err(ErrorCode::InvalidArgument.into()) } } @@ -299,7 +300,7 @@ pub fn estimate_buffer_size(payload: &QrSetupPayload) -> Result { estimate = estimate_struct_overhead(estimate); if estimate > u32::MAX as usize { - return Err(Error::NoMemory); + Err(ErrorCode::NoMemory)?; } Ok(estimate) @@ -352,11 +353,11 @@ fn populate_bits( total_payload_data_size_in_bits: usize, ) -> Result<(), Error> { if *offset + number_of_bits > total_payload_data_size_in_bits { - return Err(Error::InvalidArgument); + Err(ErrorCode::InvalidArgument)?; } if input >= 1u64 << number_of_bits { - return Err(Error::InvalidArgument); + Err(ErrorCode::InvalidArgument)?; } let mut index = *offset; @@ -390,7 +391,7 @@ fn payload_base38_representation_with_tlv( let mut base38_encoded: heapless::String = "MT:".into(); for c in base38::encode(bits) { - base38_encoded.push(c).map_err(|_| Error::NoSpace)?; + base38_encoded.push(c).map_err(|_| ErrorCode::NoSpace)?; } Ok(base38_encoded) @@ -431,7 +432,7 @@ fn generate_bit_set<'a>( TOTAL_PAYLOAD_DATA_SIZE_IN_BITS + tlv_data.map(|tlv_data| tlv_data.len() * 8).unwrap_or(0); if bits_buf.len() * 8 < total_payload_size_in_bits { - return Err(Error::BufferTooSmall); + Err(ErrorCode::BufferTooSmall)?; }; let passwd = passwd_from_comm_data(payload.comm_data); diff --git a/matter/src/persist.rs b/matter/src/persist.rs index 53e413e..d9a2733 100644 --- a/matter/src/persist.rs +++ b/matter/src/persist.rs @@ -25,7 +25,7 @@ mod file_psm { use log::info; - use crate::error::Error; + use crate::error::{Error, ErrorCode}; pub struct FilePsm { dir: PathBuf, @@ -47,7 +47,7 @@ mod file_psm { loop { if offset == buf.len() { - return Err(Error::NoSpace); + Err(ErrorCode::NoSpace)?; } let len = file.read(&mut buf[offset..])?; diff --git a/matter/src/secure_channel/case.rs b/matter/src/secure_channel/case.rs index e681ec9..18011c9 100644 --- a/matter/src/secure_channel/case.rs +++ b/matter/src/secure_channel/case.rs @@ -22,11 +22,11 @@ use log::{error, trace}; use crate::{ cert::Cert, crypto::{self, KeyPair, Sha256}, - error::Error, + error::{Error, ErrorCode}, fabric::{Fabric, FabricMgr}, secure_channel::common::SCStatusCodes, secure_channel::common::{self, OpCode}, - tlv::{get_root_node_struct, FromTLV, OctetStr, TLVElement, TLVWriter, TagType}, + tlv::{get_root_node_struct, FromTLV, OctetStr, TLVWriter, TagType}, transport::{ network::Address, proto_ctx::ProtoCtx, @@ -90,9 +90,9 @@ impl<'a> Case<'a> { .exch_ctx .exch .take_case_session::() - .ok_or(Error::InvalidState)?; + .ok_or(ErrorCode::InvalidState)?; if case_session.state != State::Sigma1Rx { - return Err(Error::Invalid); + Err(ErrorCode::Invalid)?; } case_session.state = State::Sigma3Rx; @@ -117,7 +117,7 @@ impl<'a> Case<'a> { let mut decrypted: [u8; 800] = [0; 800]; if encrypted.len() > decrypted.len() { error!("Data too large"); - return Err(Error::NoSpace); + Err(ErrorCode::NoSpace)?; } let decrypted = &mut decrypted[..encrypted.len()]; decrypted.copy_from_slice(encrypted); @@ -204,7 +204,7 @@ impl<'a> Case<'a> { case_session.local_fabric_idx = local_fabric_idx?; if r.peer_pub_key.0.len() != crypto::EC_POINT_LEN_BYTES { error!("Invalid public key length"); - return Err(Error::Invalid); + Err(ErrorCode::Invalid)?; } case_session.peer_pub_key.copy_from_slice(r.peer_pub_key.0); trace!( @@ -220,7 +220,7 @@ impl<'a> Case<'a> { let len = key_pair.derive_secret(r.peer_pub_key.0, &mut case_session.shared_secret)?; if len != 32 { error!("Derived secret length incorrect"); - return Err(Error::Invalid); + Err(ErrorCode::Invalid)?; } // println!("Derived secret: {:x?} len: {}", secret, len); @@ -348,14 +348,14 @@ impl<'a> Case<'a> { let mut verifier = noc.verify_chain_start(utc_calendar); if fabric.get_fabric_id() != noc.get_fabric_id()? { - return Err(Error::Invalid); + Err(ErrorCode::Invalid)?; } if let Some(icac) = icac { // If ICAC is present handle it if let Ok(fid) = icac.get_fabric_id() { if fid != fabric.get_fabric_id() { - return Err(Error::Invalid); + Err(ErrorCode::Invalid)?; } } verifier = verifier.add_cert(icac)?; @@ -377,7 +377,7 @@ impl<'a> Case<'a> { 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x4b, 0x65, 0x79, 0x73, ]; if key.len() < 48 { - return Err(Error::NoSpace); + Err(ErrorCode::NoSpace)?; } let mut salt = heapless::Vec::::new(); salt.extend_from_slice(ipk).unwrap(); @@ -388,7 +388,7 @@ impl<'a> Case<'a> { // println!("Session Key: salt: {:x?}, len: {}", salt, salt.len()); crypto::hkdf_sha256(salt.as_slice(), shared_secret, &SEKEYS_INFO, key) - .map_err(|_x| Error::NoSpace)?; + .map_err(|_x| ErrorCode::NoSpace)?; // println!("Session Key: key: {:x?}", key); Ok(()) @@ -425,7 +425,7 @@ impl<'a> Case<'a> { ) -> Result<(), Error> { const S3K_INFO: [u8; 6] = [0x53, 0x69, 0x67, 0x6d, 0x61, 0x33]; if key.len() < 16 { - return Err(Error::NoSpace); + Err(ErrorCode::NoSpace)?; } let mut salt = heapless::Vec::::new(); salt.extend_from_slice(ipk).unwrap(); @@ -438,7 +438,7 @@ impl<'a> Case<'a> { // println!("Sigma3Key: salt: {:x?}, len: {}", salt, salt.len()); crypto::hkdf_sha256(salt.as_slice(), shared_secret, &S3K_INFO, key) - .map_err(|_x| Error::NoSpace)?; + .map_err(|_x| ErrorCode::NoSpace)?; // println!("Sigma3Key: key: {:x?}", key); Ok(()) @@ -452,7 +452,7 @@ impl<'a> Case<'a> { ) -> Result<(), Error> { const S2K_INFO: [u8; 6] = [0x53, 0x69, 0x67, 0x6d, 0x61, 0x32]; if key.len() < 16 { - return Err(Error::NoSpace); + Err(ErrorCode::NoSpace)?; } let mut salt = heapless::Vec::::new(); salt.extend_from_slice(ipk).unwrap(); @@ -467,7 +467,7 @@ impl<'a> Case<'a> { // println!("Sigma2Key: salt: {:x?}, len: {}", salt, salt.len()); crypto::hkdf_sha256(salt.as_slice(), &case_session.shared_secret, &S2K_INFO, key) - .map_err(|_x| Error::NoSpace)?; + .map_err(|_x| ErrorCode::NoSpace)?; // println!("Sigma2Key: key: {:x?}", key); Ok(()) diff --git a/matter/src/secure_channel/core.rs b/matter/src/secure_channel/core.rs index fd13206..653ad74 100644 --- a/matter/src/secure_channel/core.rs +++ b/matter/src/secure_channel/core.rs @@ -57,7 +57,7 @@ impl<'a> SecureChannel<'a> { pub fn handle(&mut self, ctx: &mut ProtoCtx) -> Result<(bool, Option), Error> { let proto_opcode: OpCode = - num::FromPrimitive::from_u8(ctx.rx.get_proto_opcode()).ok_or(Error::Invalid)?; + num::FromPrimitive::from_u8(ctx.rx.get_proto_opcode()).ok_or(ErrorCode::Invalid)?; ctx.tx.set_proto_id(PROTO_ID_SECURE_CHANNEL); info!("Received Opcode: {:?}", proto_opcode); info!("Received Data:"); @@ -82,7 +82,7 @@ impl<'a> SecureChannel<'a> { OpCode::CASESigma3 => self.case.casesigma3_handler(ctx), _ => { error!("OpCode Not Handled: {:?}", proto_opcode); - Err(Error::InvalidOpcode) + Err(ErrorCode::InvalidOpcode.into()) } }?; diff --git a/matter/src/secure_channel/crypto_dummy.rs b/matter/src/secure_channel/crypto_dummy.rs index 11ec852..3933e79 100644 --- a/matter/src/secure_channel/crypto_dummy.rs +++ b/matter/src/secure_channel/crypto_dummy.rs @@ -15,7 +15,7 @@ * limitations under the License. */ -use crate::error::Error; +use crate::error::{Error, ErrorCode}; #[allow(non_snake_case)] @@ -29,35 +29,35 @@ impl CryptoSpake2 { // Computes w0 from w0s respectively pub fn set_w0_from_w0s(&mut self, _w0s: &[u8]) -> Result<(), Error> { - Err(Error::Invalid) + Err(ErrorCode::Invalid.into()) } pub fn set_w1_from_w1s(&mut self, _w1s: &[u8]) -> Result<(), Error> { - Err(Error::Invalid) + Err(ErrorCode::Invalid.into()) } pub fn set_w0(&mut self, _w0: &[u8]) -> Result<(), Error> { - Err(Error::Invalid) + Err(ErrorCode::Invalid.into()) } pub fn set_w1(&mut self, _w1: &[u8]) -> Result<(), Error> { - Err(Error::Invalid) + Err(ErrorCode::Invalid.into()) } #[allow(non_snake_case)] pub fn set_L(&mut self, _l: &[u8]) -> Result<(), Error> { - Err(Error::Invalid) + Err(ErrorCode::Invalid.into()) } #[allow(non_snake_case)] #[allow(dead_code)] pub fn set_L_from_w1s(&mut self, _w1s: &[u8]) -> Result<(), Error> { - Err(Error::Invalid) + Err(ErrorCode::Invalid.into()) } #[allow(non_snake_case)] pub fn get_pB(&mut self, _pB: &mut [u8]) -> Result<(), Error> { - Err(Error::Invalid) + Err(ErrorCode::Invalid.into()) } #[allow(non_snake_case)] @@ -68,6 +68,6 @@ impl CryptoSpake2 { _pB: &[u8], _out: &mut [u8], ) -> Result<(), Error> { - Err(Error::Invalid) + Err(ErrorCode::Invalid.into()) } } diff --git a/matter/src/secure_channel/crypto_mbedtls.rs b/matter/src/secure_channel/crypto_mbedtls.rs index 27c9fc6..de7ea48 100644 --- a/matter/src/secure_channel/crypto_mbedtls.rs +++ b/matter/src/secure_channel/crypto_mbedtls.rs @@ -18,7 +18,7 @@ use alloc::sync::Arc; use core::ops::{Mul, Sub}; -use crate::error::Error; +use crate::error::{Error, ErrorCode}; use byteorder::{ByteOrder, LittleEndian}; use log::error; @@ -150,7 +150,7 @@ impl CryptoSpake2 { let pB_internal = pB_internal.as_slice(); if pB_internal.len() != pB.len() { error!("pB length mismatch"); - return Err(Error::Invalid); + Err(ErrorCode::Invalid)?; } pB.copy_from_slice(pB_internal); Ok(()) diff --git a/matter/src/secure_channel/crypto_openssl.rs b/matter/src/secure_channel/crypto_openssl.rs index 631cb6b..de60fff 100644 --- a/matter/src/secure_channel/crypto_openssl.rs +++ b/matter/src/secure_channel/crypto_openssl.rs @@ -15,7 +15,7 @@ * limitations under the License. */ -use crate::error::Error; +use crate::error::{Error, ErrorCode}; use byteorder::{ByteOrder, LittleEndian}; use log::error; @@ -158,7 +158,7 @@ impl CryptoSpake2 { let pB_internal = pB_internal.as_slice(); if pB_internal.len() != pB.len() { error!("pB length mismatch"); - return Err(Error::Invalid); + Err(ErrorCode::Invalid)?; } pB.copy_from_slice(pB_internal); Ok(()) diff --git a/matter/src/secure_channel/pake.rs b/matter/src/secure_channel/pake.rs index 1901686..84c5ba0 100644 --- a/matter/src/secure_channel/pake.rs +++ b/matter/src/secure_channel/pake.rs @@ -23,10 +23,10 @@ use super::{ }; use crate::{ crypto, - error::Error, + error::{Error, ErrorCode}, mdns::{MdnsMgr, ServiceMode}, secure_channel::common::OpCode, - tlv::{self, get_root_node_struct, FromTLV, OctetStr, TLVElement, TLVWriter, TagType, ToTLV}, + tlv::{self, get_root_node_struct, FromTLV, OctetStr, TLVWriter, TagType, ToTLV}, transport::{ exchange::ExchangeCtx, network::Address, @@ -176,7 +176,7 @@ impl PakeState { if let PakeState::InProgress(s) = new { Ok(s) } else { - Err(Error::InvalidSignature) + Err(ErrorCode::InvalidSignature.into()) } } @@ -187,7 +187,7 @@ impl PakeState { fn take_sess_data(&mut self, exch_ctx: &ExchangeCtx) -> Result { let sd = self.take()?; if sd.exch_id != exch_ctx.exch.get_id() || sd.peer_addr != exch_ctx.sess.get_peer_addr() { - Err(Error::InvalidState) + Err(ErrorCode::InvalidState.into()) } else { Ok(sd) } @@ -240,10 +240,10 @@ impl Pake { let clone_data = if status_code == SCStatusCodes::SessionEstablishmentSuccess { // Get the keys - let ke = ke.ok_or(Error::Invalid)?; + let ke = ke.ok_or(ErrorCode::Invalid)?; let mut session_keys: [u8; 48] = [0; 48]; crypto::hkdf_sha256(&[], ke, &SPAKE2_SESSION_KEYS_INFO, &mut session_keys) - .map_err(|_x| Error::NoSpace)?; + .map_err(|_x| ErrorCode::NoSpace)?; // Create a session let data = sd.spake2p.get_app_data(); @@ -314,7 +314,7 @@ impl Pake { let a = PBKDFParamReq::from_tlv(&root)?; if a.passcode_id != 0 { error!("Can't yet handle passcode_id != 0"); - return Err(Error::Invalid); + Err(ErrorCode::Invalid)?; } let mut our_random: [u8; 32] = [0; 32]; diff --git a/matter/src/secure_channel/spake2p.rs b/matter/src/secure_channel/spake2p.rs index 8a4b794..9be2d4d 100644 --- a/matter/src/secure_channel/spake2p.rs +++ b/matter/src/secure_channel/spake2p.rs @@ -25,7 +25,7 @@ use subtle::ConstantTimeEq; use crate::{ crypto::{pbkdf2_hmac, Sha256}, - error::Error, + error::{Error, ErrorCode}, }; use super::{common::SCStatusCodes, crypto::CryptoSpake2}; @@ -198,7 +198,7 @@ impl Spake2P { #[allow(non_snake_case)] pub fn handle_pA(&mut self, pA: &[u8], pB: &mut [u8], cB: &mut [u8]) -> Result<(), Error> { if self.mode != Spake2Mode::Verifier(Spake2VerifierState::Init) { - return Err(Error::InvalidState); + Err(ErrorCode::InvalidState)?; } if let Some(crypto_spake2) = &mut self.crypto_spake2 { @@ -251,13 +251,13 @@ impl Spake2P { if ke_internal.len() == Ke.len() { Ke.copy_from_slice(ke_internal); } else { - return Err(Error::NoSpace); + Err(ErrorCode::NoSpace)?; } // Step 2: KcA || KcB = KDF(nil, Ka, "ConfirmationKeys") let mut KcAKcB: [u8; 32] = [0; 32]; crypto::hkdf_sha256(&[], Ka, &SPAKE2P_KEY_CONFIRM_INFO, &mut KcAKcB) - .map_err(|_x| Error::NoSpace)?; + .map_err(|_x| ErrorCode::NoSpace)?; let KcA = &KcAKcB[0..(KcAKcB.len() / 2)]; let KcB = &KcAKcB[(KcAKcB.len() / 2)..]; diff --git a/matter/src/secure_channel/status_report.rs b/matter/src/secure_channel/status_report.rs index 477bcfa..2f6aed1 100644 --- a/matter/src/secure_channel/status_report.rs +++ b/matter/src/secure_channel/status_report.rs @@ -46,6 +46,7 @@ pub fn create_status_report( proto_code: u16, proto_data: Option<&[u8]>, ) -> Result<(), Error> { + proto_tx.reset(); proto_tx.set_proto_id(PROTO_ID_SECURE_CHANNEL); proto_tx.set_proto_opcode(OpCode::StatusReport as u8); let wb = proto_tx.get_writebuf()?; diff --git a/matter/src/tlv/parser.rs b/matter/src/tlv/parser.rs index f8b9716..b740f5d 100644 --- a/matter/src/tlv/parser.rs +++ b/matter/src/tlv/parser.rs @@ -15,7 +15,7 @@ * limitations under the License. */ -use crate::error::Error; +use crate::error::{Error, ErrorCode}; use byteorder::{ByteOrder, LittleEndian}; use core::fmt; @@ -284,7 +284,7 @@ fn read_length_value<'a>( // We'll consume the current offset (len) + the entire string if length + size_of_length_field > t.left { // Return Error - Err(Error::NoSpace) + Err(ErrorCode::NoSpace.into()) } else { Ok(( // return the additional size only @@ -390,14 +390,14 @@ impl<'a> TLVElement<'a> { pub fn i8(&self) -> Result { match self.element_type { ElementType::S8(a) => Ok(a), - _ => Err(Error::TLVTypeMismatch), + _ => Err(ErrorCode::TLVTypeMismatch.into()), } } pub fn u8(&self) -> Result { match self.element_type { ElementType::U8(a) => Ok(a), - _ => Err(Error::TLVTypeMismatch), + _ => Err(ErrorCode::TLVTypeMismatch.into()), } } @@ -405,7 +405,7 @@ impl<'a> TLVElement<'a> { match self.element_type { ElementType::U8(a) => Ok(a.into()), ElementType::U16(a) => Ok(a), - _ => Err(Error::TLVTypeMismatch), + _ => Err(ErrorCode::TLVTypeMismatch.into()), } } @@ -414,7 +414,7 @@ impl<'a> TLVElement<'a> { ElementType::U8(a) => Ok(a.into()), ElementType::U16(a) => Ok(a.into()), ElementType::U32(a) => Ok(a), - _ => Err(Error::TLVTypeMismatch), + _ => Err(ErrorCode::TLVTypeMismatch.into()), } } @@ -424,7 +424,7 @@ impl<'a> TLVElement<'a> { ElementType::U16(a) => Ok(a.into()), ElementType::U32(a) => Ok(a.into()), ElementType::U64(a) => Ok(a), - _ => Err(Error::TLVTypeMismatch), + _ => Err(ErrorCode::TLVTypeMismatch.into()), } } @@ -434,7 +434,7 @@ impl<'a> TLVElement<'a> { | ElementType::Utf8l(s) | ElementType::Str16l(s) | ElementType::Utf16l(s) => Ok(s), - _ => Err(Error::TLVTypeMismatch), + _ => Err(ErrorCode::TLVTypeMismatch.into()), } } @@ -444,9 +444,9 @@ impl<'a> TLVElement<'a> { | ElementType::Utf8l(s) | ElementType::Str16l(s) | ElementType::Utf16l(s) => { - Ok(core::str::from_utf8(s).map_err(|_| Error::InvalidData)?) + Ok(core::str::from_utf8(s).map_err(|_| Error::from(ErrorCode::InvalidData))?) } - _ => Err(Error::TLVTypeMismatch), + _ => Err(ErrorCode::TLVTypeMismatch.into()), } } @@ -454,48 +454,48 @@ impl<'a> TLVElement<'a> { match self.element_type { ElementType::False => Ok(false), ElementType::True => Ok(true), - _ => Err(Error::TLVTypeMismatch), + _ => Err(ErrorCode::TLVTypeMismatch.into()), } } pub fn null(&self) -> Result<(), Error> { match self.element_type { ElementType::Null => Ok(()), - _ => Err(Error::TLVTypeMismatch), + _ => Err(ErrorCode::TLVTypeMismatch.into()), } } pub fn confirm_struct(&self) -> Result, Error> { match self.element_type { ElementType::Struct(_) => Ok(*self), - _ => Err(Error::TLVTypeMismatch), + _ => Err(ErrorCode::TLVTypeMismatch.into()), } } pub fn confirm_array(&self) -> Result, Error> { match self.element_type { ElementType::Array(_) => Ok(*self), - _ => Err(Error::TLVTypeMismatch), + _ => Err(ErrorCode::TLVTypeMismatch.into()), } } pub fn confirm_list(&self) -> Result, Error> { match self.element_type { ElementType::List(_) => Ok(*self), - _ => Err(Error::TLVTypeMismatch), + _ => Err(ErrorCode::TLVTypeMismatch.into()), } } pub fn find_tag(&self, tag: u32) -> Result, Error> { let match_tag: TagType = TagType::Context(tag as u8); - let iter = self.enter().ok_or(Error::TLVTypeMismatch)?; + let iter = self.enter().ok_or(ErrorCode::TLVTypeMismatch)?; for a in iter { if match_tag == a.tag_type { return Ok(a); } } - Err(Error::NoTagFound) + Err(ErrorCode::NoTagFound.into()) } pub fn get_tag(&self) -> TagType { @@ -721,14 +721,17 @@ impl<'a> Iterator for TLVContainerIterator<'a> { } pub fn get_root_node(b: &[u8]) -> Result { - TLVList::new(b).iter().next().ok_or(Error::InvalidData) + Ok(TLVList::new(b) + .iter() + .next() + .ok_or(ErrorCode::InvalidData)?) } pub fn get_root_node_struct(b: &[u8]) -> Result { TLVList::new(b) .iter() .next() - .ok_or(Error::InvalidData)? + .ok_or(ErrorCode::InvalidData)? .confirm_struct() } @@ -736,7 +739,7 @@ pub fn get_root_node_list(b: &[u8]) -> Result { TLVList::new(b) .iter() .next() - .ok_or(Error::InvalidData)? + .ok_or(ErrorCode::InvalidData)? .confirm_list() } @@ -802,7 +805,7 @@ mod tests { get_root_node_list, get_root_node_struct, ElementType, Pointer, TLVElement, TLVList, TagType, }; - use crate::error::Error; + use crate::error::ErrorCode; #[test] fn test_short_length_tag() { @@ -1146,7 +1149,10 @@ mod tests { element_type: ElementType::U32(1), } ); - assert_eq!(cmd_path.find_tag(3), Err(Error::NoTagFound)); + assert_eq!( + cmd_path.find_tag(3).map_err(|e| e.code()), + Err(ErrorCode::NoTagFound) + ); // This is the variable of the invoke command assert_eq!( diff --git a/matter/src/tlv/traits.rs b/matter/src/tlv/traits.rs index 0311cb3..3fced12 100644 --- a/matter/src/tlv/traits.rs +++ b/matter/src/tlv/traits.rs @@ -16,7 +16,7 @@ */ use super::{ElementType, TLVContainerIterator, TLVElement, TLVWriter, TagType}; -use crate::error::Error; +use crate::error::{Error, ErrorCode}; use core::fmt::Debug; use core::slice::Iter; use log::error; @@ -31,7 +31,7 @@ pub trait FromTLV<'a> { where Self: Sized, { - Err(Error::TLVNotFound) + Err(ErrorCode::TLVNotFound.into()) } } @@ -45,7 +45,8 @@ impl<'a, T: FromTLV<'a> + Default, const N: usize> FromTLV<'a> for [T; N] { let mut a = heapless::Vec::::new(); if let Some(tlv_iter) = t.enter() { for element in tlv_iter { - a.push(T::from_tlv(&element)?).map_err(|_| Error::NoSpace)?; + a.push(T::from_tlv(&element)?) + .map_err(|_| ErrorCode::NoSpace)?; } } @@ -53,10 +54,10 @@ impl<'a, T: FromTLV<'a> + Default, const N: usize> FromTLV<'a> for [T; N] { // implementation on top of heapless::Vec (to avoid requiring Copy) // Not sure why we actually need that yet, but without it unit tests fail while a.len() < N { - a.push(Default::default()).map_err(|_| Error::NoSpace)?; + a.push(Default::default()).map_err(|_| ErrorCode::NoSpace)?; } - a.into_array().map_err(|_| Error::Invalid) + a.into_array().map_err(|_| ErrorCode::Invalid.into()) } } @@ -131,7 +132,7 @@ impl<'a> UtfStr<'a> { } pub fn as_str(&self) -> Result<&str, Error> { - core::str::from_utf8(self.0).map_err(|_| Error::Invalid) + core::str::from_utf8(self.0).map_err(|_| ErrorCode::Invalid.into()) } } @@ -172,7 +173,7 @@ impl<'a> ToTLV for OctetStr<'a> { /// Implements the Owned version of Octet String impl FromTLV<'_> for heapless::Vec { fn from_tlv(t: &TLVElement) -> Result, Error> { - heapless::Vec::from_slice(t.slice()?).map_err(|_| Error::NoSpace) + heapless::Vec::from_slice(t.slice()?).map_err(|_| ErrorCode::NoSpace.into()) } } @@ -189,7 +190,7 @@ impl FromTLV<'_> for heapless::String { string .push_str(core::str::from_utf8(t.slice()?)?) - .map_err(|_| Error::NoSpace)?; + .map_err(|_| ErrorCode::NoSpace)?; Ok(string) } @@ -411,7 +412,7 @@ impl<'a> ToTLV for TLVElement<'a> { ElementType::EndCnt => tw.end_container(), _ => { error!("ToTLV Not supported"); - Err(Error::Invalid) + Err(ErrorCode::Invalid.into()) } } } @@ -419,7 +420,7 @@ impl<'a> ToTLV for TLVElement<'a> { #[cfg(test)] mod tests { - use super::{FromTLV, OctetStr, TLVElement, TLVWriter, TagType, ToTLV}; + use super::{FromTLV, OctetStr, TLVWriter, TagType, ToTLV}; use crate::{error::Error, tlv::TLVList, utils::writebuf::WriteBuf}; use matter_macro_derive::{FromTLV, ToTLV}; diff --git a/matter/src/tlv/writer.rs b/matter/src/tlv/writer.rs index 1db8421..45c60c9 100644 --- a/matter/src/tlv/writer.rs +++ b/matter/src/tlv/writer.rs @@ -164,7 +164,7 @@ impl<'a, 'b> TLVWriter<'a, 'b> { pub fn str8(&mut self, tag_type: TagType, data: &[u8]) -> Result<(), Error> { if data.len() > 256 { error!("use str16() instead"); - return Err(Error::Invalid); + return Err(ErrorCode::Invalid.into()); } self.put_control_tag(tag_type, WriteElementType::Str8l)?; self.buf.le_u8(data.len() as u8)?; diff --git a/matter/src/transport/exchange.rs b/matter/src/transport/exchange.rs index 5d7a79c..5a9bbcf 100644 --- a/matter/src/transport/exchange.rs +++ b/matter/src/transport/exchange.rs @@ -21,11 +21,10 @@ use core::fmt; use core::time::Duration; use log::{error, info, trace}; -use crate::error::Error; +use crate::error::{Error, ErrorCode}; use crate::interaction_model::core::{ResumeReadReq, ResumeSubscribeReq}; use crate::secure_channel; use crate::secure_channel::case::CaseSession; -use crate::tlv::print_tlv_list; use crate::utils::epoch::Epoch; use crate::utils::rand::Rand; @@ -227,7 +226,8 @@ impl Exchange { tx.get_proto_id(), tx.get_proto_opcode(), ); - print_tlv_list(tx.as_slice()); + + //print_tlv_list(tx.as_slice()); tx.proto.exch_id = self.id; if self.role == Role::Initiator { @@ -317,10 +317,10 @@ impl ExchangeMgr { info!("Creating new exchange"); let e = Exchange::new(id, sess_idx, role); if exchanges.insert(id, e).is_err() { - return Err(Error::NoSpace); + Err(ErrorCode::NoSpace)?; } } else { - return Err(Error::NoSpace); + Err(ErrorCode::NoSpace)?; } } @@ -330,11 +330,11 @@ impl ExchangeMgr { if result.get_role() == role && sess_idx == result.sess_idx { Ok(result) } else { - Err(Error::NoExchange) + Err(ErrorCode::NoExchange.into()) } } else { error!("This should never happen"); - Err(Error::NoSpace) + Err(ErrorCode::NoSpace.into()) } } @@ -375,7 +375,7 @@ impl ExchangeMgr { pub fn send(&mut self, exch_id: u16, tx: &mut Packet) -> Result { let exchange = - ExchangeMgr::_get_with_id(&mut self.exchanges, exch_id).ok_or(Error::NoExchange)?; + ExchangeMgr::_get_with_id(&mut self.exchanges, exch_id).ok_or(ErrorCode::NoExchange)?; let mut session = self.sess_mgr.get_session_handle(exchange.sess_idx); exchange.send(tx, &mut session) } @@ -474,7 +474,7 @@ impl fmt::Display for ExchangeMgr { #[allow(clippy::bool_assert_comparison)] mod tests { use crate::{ - error::Error, + error::ErrorCode, transport::{ network::Address, session::{CloneData, SessionMode}, @@ -532,9 +532,12 @@ mod tests { let clone_data = get_clone_data(peer_sess_id, local_sess_id); match mgr.add_session(&clone_data) { Ok(s) => assert_eq!(peer_sess_id, s.get_peer_sess_id()), - Err(Error::NoSpace) => break, - _ => { - panic!("Couldn't, create session"); + Err(e) => { + if e.code() == ErrorCode::NoSpace { + break; + } else { + panic!("Could not create sessions"); + } } } local_sess_id += 1; @@ -576,7 +579,10 @@ mod tests { for i in 1..(MAX_SESSIONS + 1) { // Now purposefully overflow the sessions by adding another session let result = mgr.add_session(&get_clone_data(new_peer_sess_id, new_local_sess_id)); - assert!(matches!(result, Err(Error::NoSpace))); + assert!(matches!( + result.map_err(|e| e.code()), + Err(ErrorCode::NoSpace) + )); let mut buf = [0; MAX_TX_BUF_SIZE]; let tx = &mut Packet::new_tx(&mut buf); diff --git a/matter/src/transport/mgr.rs b/matter/src/transport/mgr.rs index 0db6390..331c362 100644 --- a/matter/src/transport/mgr.rs +++ b/matter/src/transport/mgr.rs @@ -109,14 +109,18 @@ impl<'r, 'a, 'p> RecvCompletion<'r, 'a, 'p> { } } Ok(None) => (RecvState::Ack, None), - Err(Error::Duplicate) => (RecvState::Ack, None), - Err(Error::NoSpace) => (RecvState::EvictSession, None), - Err(err) => Err(err)?, + Err(e) => match e.code() { + ErrorCode::Duplicate => (RecvState::Ack, None), + ErrorCode::NoSpace => (RecvState::EvictSession, None), + _ => Err(e)?, + }, }, RecvState::AddSession(clone_data) => match self.mgr.exch_mgr.add_session(&clone_data) { Ok(_) => (RecvState::Ack, None), - Err(Error::NoSpace) => (RecvState::EvictSession2(clone_data), None), - Err(err) => Err(err)?, + Err(e) => match e.code() { + ErrorCode::NoSpace => (RecvState::EvictSession2(clone_data), None), + _ => Err(e)?, + }, }, RecvState::EvictSession => { if self.mgr.exch_mgr.evict_session(&mut self.tx)? { diff --git a/matter/src/transport/mrp.rs b/matter/src/transport/mrp.rs index 2213a52..2d046bf 100644 --- a/matter/src/transport/mrp.rs +++ b/matter/src/transport/mrp.rs @@ -59,7 +59,7 @@ impl AckEntry { ack_timeout, }) } else { - Err(Error::Invalid) + Err(ErrorCode::Invalid.into()) } } @@ -120,7 +120,7 @@ impl ReliableMessage { if self.retrans.is_some() { // This indicates there was some existing entry for same sess-id/exch-id, which shouldnt happen error!("Previous retrans entry for this exchange already exists"); - return Err(Error::Invalid); + Err(ErrorCode::Invalid)?; } self.retrans = Some(RetransEntry::new(proto_tx.plain.ctr)); @@ -135,7 +135,7 @@ impl ReliableMessage { pub fn recv(&mut self, proto_rx: &Packet, epoch: Epoch) -> Result<(), Error> { if proto_rx.proto.is_ack() { // Handle received Acks - let ack_msg_ctr = proto_rx.proto.get_ack_msg_ctr().ok_or(Error::Invalid)?; + let ack_msg_ctr = proto_rx.proto.get_ack_msg_ctr().ok_or(ErrorCode::Invalid)?; if let Some(entry) = &self.retrans { if entry.get_msg_ctr() != ack_msg_ctr { // TODO: XXX Fix this @@ -150,7 +150,7 @@ impl ReliableMessage { // This indicates there was some existing entry for same sess-id/exch-id, which shouldnt happen // TODO: As per the spec if this happens, we need to send out the previous ACK and note this new ACK error!("Previous ACK entry for this exchange already exists"); - return Err(Error::Invalid); + Err(ErrorCode::Invalid)?; } self.ack = Some(AckEntry::new(proto_rx.plain.ctr, epoch)?); diff --git a/matter/src/transport/packet.rs b/matter/src/transport/packet.rs index d56485f..3e7e9c7 100644 --- a/matter/src/transport/packet.rs +++ b/matter/src/transport/packet.rs @@ -18,7 +18,7 @@ use log::error; use crate::{ - error::Error, + error::{Error, ErrorCode}, utils::{parsebuf::ParseBuf, writebuf::WriteBuf}, }; @@ -109,7 +109,7 @@ impl<'a> Packet<'a> { if let Direction::Rx(pbuf, _) = &mut self.data { Ok(pbuf) } else { - Err(Error::Invalid) + Err(ErrorCode::Invalid.into()) } } @@ -117,7 +117,7 @@ impl<'a> Packet<'a> { if let Direction::Tx(wbuf) = &mut self.data { Ok(wbuf) } else { - Err(Error::Invalid) + Err(ErrorCode::Invalid.into()) } } @@ -158,10 +158,10 @@ impl<'a> Packet<'a> { .decrypt_and_decode(&self.plain, pb, peer_nodeid, dec_key) } else { error!("Invalid state for proto_decode"); - Err(Error::InvalidState) + Err(ErrorCode::InvalidState.into()) } } - _ => Err(Error::InvalidState), + _ => Err(ErrorCode::InvalidState.into()), } } @@ -171,7 +171,7 @@ impl<'a> Packet<'a> { RxState::Uninit => Ok(false), _ => Ok(true), }, - _ => Err(Error::InvalidState), + _ => Err(ErrorCode::InvalidState.into()), } } @@ -183,10 +183,10 @@ impl<'a> Packet<'a> { self.plain.decode(pb) } else { error!("Invalid state for plain_decode"); - Err(Error::InvalidState) + Err(ErrorCode::InvalidState.into()) } } - _ => Err(Error::InvalidState), + _ => Err(ErrorCode::InvalidState.into()), } } } diff --git a/matter/src/transport/plain_hdr.rs b/matter/src/transport/plain_hdr.rs index e51ddaf..e5a9b24 100644 --- a/matter/src/transport/plain_hdr.rs +++ b/matter/src/transport/plain_hdr.rs @@ -65,7 +65,7 @@ impl PlainHdr { impl PlainHdr { // it will have an additional 'message length' field first pub fn decode(&mut self, msg: &mut ParseBuf) -> Result<(), Error> { - self.flags = MsgFlags::from_bits(msg.le_u8()?).ok_or(Error::Invalid)?; + self.flags = MsgFlags::from_bits(msg.le_u8()?).ok_or(ErrorCode::Invalid)?; self.sess_id = msg.le_u16()?; let _sec_flags = msg.le_u8()?; self.sess_type = if self.sess_id != 0 { diff --git a/matter/src/transport/proto_hdr.rs b/matter/src/transport/proto_hdr.rs index fd392bd..d7f92fb 100644 --- a/matter/src/transport/proto_hdr.rs +++ b/matter/src/transport/proto_hdr.rs @@ -105,7 +105,7 @@ impl ProtoHdr { decrypt_in_place(plain_hdr.ctr, peer_nodeid, parsebuf, d)?; } - self.exch_flags = ExchFlags::from_bits(parsebuf.le_u8()?).ok_or(Error::Invalid)?; + self.exch_flags = ExchFlags::from_bits(parsebuf.le_u8()?).ok_or(ErrorCode::Invalid)?; self.proto_opcode = parsebuf.le_u8()?; self.exch_id = parsebuf.le_u16()?; self.proto_id = parsebuf.le_u16()?; @@ -128,10 +128,10 @@ impl ProtoHdr { resp_buf.le_u16(self.exch_id)?; resp_buf.le_u16(self.proto_id)?; if self.is_vendor() { - resp_buf.le_u16(self.proto_vendor_id.ok_or(Error::Invalid)?)?; + resp_buf.le_u16(self.proto_vendor_id.ok_or(ErrorCode::Invalid)?)?; } if self.is_ack() { - resp_buf.le_u32(self.ack_msg_ctr.ok_or(Error::Invalid)?)?; + resp_buf.le_u32(self.ack_msg_ctr.ok_or(ErrorCode::Invalid)?)?; } Ok(()) } @@ -216,7 +216,7 @@ fn decrypt_in_place( // If so, we need to handle it cleanly here. aad.copy_from_slice(parsed_slice); } else { - return Err(Error::InvalidAAD); + Err(ErrorCode::InvalidAAD)?; } // IV: diff --git a/matter/src/transport/session.rs b/matter/src/transport/session.rs index 95597e2..1135f05 100644 --- a/matter/src/transport/session.rs +++ b/matter/src/transport/session.rs @@ -410,7 +410,7 @@ impl SessionMgr { self.sessions[index] = Some(session); Ok(index) } else { - Err(Error::NoSpace) + Err(ErrorCode::NoSpace.into()) } } @@ -465,7 +465,7 @@ impl SessionMgr { info!("Creating new session"); self.add(peer_addr, peer_nodeid) } else { - Err(Error::NotFound) + Err(ErrorCode::NotFound.into()) } } @@ -484,14 +484,14 @@ impl SessionMgr { let duplicate = session.rx_ctr_state.recv(rx.plain.ctr, is_encrypted); if duplicate { info!("Dropping duplicate packet"); - Err(Error::Duplicate) + Err(ErrorCode::Duplicate.into()) } else { Ok(sess_index) } } pub fn decode(&mut self, rx: &mut Packet) -> Result<(), Error> { - // let network = self.network.as_ref().ok_or(Error::NoNetworkInterface)?; + // let network = self.network.as_ref().ok_or(ErrorCode::NoNetworkInterface)?; // let (len, src) = network.recv(rx.as_borrow_slice()).await?; // rx.get_parsebuf()?.set_len(len); @@ -507,7 +507,7 @@ impl SessionMgr { pub fn send(&mut self, sess_idx: usize, tx: &mut Packet) -> Result<(), Error> { self.sessions[sess_idx] .as_mut() - .ok_or(Error::NoSession)? + .ok_or(ErrorCode::NoSession)? .do_send(self.epoch, tx)?; // let network = self.network.as_ref().ok_or(Error::NoNetworkInterface)?; diff --git a/matter/src/transport/udp.rs b/matter/src/transport/udp.rs index b3c4c48..b29ca05 100644 --- a/matter/src/transport/udp.rs +++ b/matter/src/transport/udp.rs @@ -53,7 +53,7 @@ impl UdpListener { let (size, addr) = self.socket.recv_from(in_buf).await.map_err(|e| { warn!("Error on the network: {:?}", e); - Error::Network + ErrorCode::Network })?; info!("Got packet: {:?} from addr {:?}", &in_buf[..size], addr); @@ -66,7 +66,7 @@ impl UdpListener { Address::Udp(addr) => { let len = self.socket.send_to(out_buf, addr).await.map_err(|e| { warn!("Error on the network: {:?}", e); - Error::Network + ErrorCode::Network })?; info!( diff --git a/matter/src/utils/parsebuf.rs b/matter/src/utils/parsebuf.rs index d6a8b9a..549e022 100644 --- a/matter/src/utils/parsebuf.rs +++ b/matter/src/utils/parsebuf.rs @@ -65,7 +65,7 @@ impl<'a> ParseBuf<'a> { self.left -= size; return Ok(tail); } - Err(Error::TruncatedPacket) + Err(ErrorCode::TruncatedPacket.into()) } fn advance(&mut self, len: usize) { @@ -82,7 +82,7 @@ impl<'a> ParseBuf<'a> { self.advance(size); return Ok(data); } - Err(Error::TruncatedPacket) + Err(ErrorCode::TruncatedPacket.into()) } pub fn le_u8(&mut self) -> Result { diff --git a/matter/src/utils/writebuf.rs b/matter/src/utils/writebuf.rs index 3adafe2..21a51e2 100644 --- a/matter/src/utils/writebuf.rs +++ b/matter/src/utils/writebuf.rs @@ -70,9 +70,9 @@ impl<'a> WriteBuf<'a> { pub fn reserve(&mut self, reserve: usize) -> Result<(), Error> { if self.end != 0 || self.start != 0 || self.buf_size != self.buf.len() { - Err(Error::Invalid) + Err(ErrorCode::Invalid.into()) } else if reserve > self.buf_size { - Err(Error::NoSpace) + Err(ErrorCode::NoSpace.into()) } else { self.start = reserve; self.end = reserve; @@ -85,7 +85,7 @@ impl<'a> WriteBuf<'a> { self.buf_size -= with; Ok(()) } else { - Err(Error::NoSpace) + Err(ErrorCode::NoSpace.into()) } } @@ -94,7 +94,7 @@ impl<'a> WriteBuf<'a> { self.buf_size += by; Ok(()) } else { - Err(Error::NoSpace) + Err(ErrorCode::NoSpace.into()) } } @@ -107,7 +107,7 @@ impl<'a> WriteBuf<'a> { self.start -= size; return Ok(()); } - Err(Error::NoSpace) + Err(ErrorCode::NoSpace.into()) } pub fn prepend(&mut self, src: &[u8]) -> Result<(), Error> { @@ -126,7 +126,7 @@ impl<'a> WriteBuf<'a> { self.end += size; return Ok(()); } - Err(Error::NoSpace) + Err(ErrorCode::NoSpace.into()) } pub fn append(&mut self, src: &[u8]) -> Result<(), Error> { diff --git a/matter/tests/common/echo_cluster.rs b/matter/tests/common/echo_cluster.rs index dd61a0e..e5caca7 100644 --- a/matter/tests/common/echo_cluster.rs +++ b/matter/tests/common/echo_cluster.rs @@ -27,7 +27,7 @@ use matter::{ Cluster, CmdDataEncoder, CmdDataWriter, CmdDetails, Dataver, Handler, Quality, ATTRIBUTE_LIST, FEATURE_MAP, }, - error::Error, + error::{Error, ErrorCode}, interaction_model::{ core::Transaction, messages::ib::{attr_list_write, ListOperation}, @@ -122,7 +122,7 @@ impl TestChecker { INIT.call_once(|| { G_TEST_CHECKER = Some(Arc::new(Mutex::new(Self::new()))); }); - Ok(G_TEST_CHECKER.as_ref().ok_or(Error::Invalid)?.clone()) + Ok(G_TEST_CHECKER.as_ref().ok_or(ErrorCode::Invalid)?.clone()) } } } @@ -235,7 +235,7 @@ impl EchoCluster { } } - Err(Error::ResourceExhausted) + Err(ErrorCode::ResourceExhausted.into()) } ListOperation::EditItem(index) => { let data = data.u16()?; @@ -243,7 +243,7 @@ impl EchoCluster { tc.write_list[*index as usize] = Some(data); Ok(()) } else { - Err(Error::InvalidAction) + Err(ErrorCode::InvalidAction.into()) } } ListOperation::DeleteItem(index) => { @@ -251,7 +251,7 @@ impl EchoCluster { tc.write_list[*index as usize] = None; Ok(()) } else { - Err(Error::InvalidAction) + Err(ErrorCode::InvalidAction.into()) } } ListOperation::DeleteList => { diff --git a/matter/tests/common/im_engine.rs b/matter/tests/common/im_engine.rs index 86674ec..e140282 100644 --- a/matter/tests/common/im_engine.rs +++ b/matter/tests/common/im_engine.rs @@ -216,7 +216,7 @@ impl<'a> ImEngine<'a> { epoch: *self.matter.borrow(), }; let mut rx_buf = [0; MAX_RX_BUF_SIZE]; - let mut tx_buf = [0; 1450]; // For the long read tests to run unchanged + let mut tx_buf = [0; 1440]; // For the long read tests to run unchanged let mut rx = Packet::new_rx(&mut rx_buf); let mut tx = Packet::new_tx(&mut tx_buf); // Create fake rx packet diff --git a/matter_macro_derive/Cargo.toml b/matter_macro_derive/Cargo.toml index f0f38ba..163ff50 100644 --- a/matter_macro_derive/Cargo.toml +++ b/matter_macro_derive/Cargo.toml @@ -11,3 +11,4 @@ proc-macro = true syn = { version = "1", features = ["extra-traits"]} quote = "1" proc-macro2 = "1" +proc-macro-crate = "1.3" diff --git a/matter_macro_derive/src/lib.rs b/matter_macro_derive/src/lib.rs index a1fc553..c63eddc 100644 --- a/matter_macro_derive/src/lib.rs +++ b/matter_macro_derive/src/lib.rs @@ -16,7 +16,7 @@ */ use proc_macro::TokenStream; -use proc_macro2::Span; +use proc_macro2::{Ident, Span}; use quote::{format_ident, quote}; use syn::Lit::{Int, Str}; use syn::NestedMeta::{Lit, Meta}; @@ -106,6 +106,18 @@ fn parse_tag_val(field: &syn::Field) -> Option { None } +fn get_crate_name() -> String { + let found_crate = proc_macro_crate::crate_name("matter-iot").unwrap_or_else(|err| { + eprintln!("Warning: defaulting to `crate` {err}"); + proc_macro_crate::FoundCrate::Itself + }); + + match found_crate { + proc_macro_crate::FoundCrate::Itself => String::from("crate"), + proc_macro_crate::FoundCrate::Name(name) => name, + } +} + /// Generate a ToTlv implementation for a structure fn gen_totlv_for_struct( fields: &syn::FieldsNamed, @@ -187,16 +199,18 @@ fn gen_totlv_for_enum( tag_start += 1; } + let krate = Ident::new(&get_crate_name(), Span::call_site()); + let expanded = quote! { - impl #generics ToTLV for #enum_name #generics { - fn to_tlv(&self, tw: &mut TLVWriter, tag_type: TagType) -> Result<(), Error> { + impl #generics #krate::tlv::ToTLV for #enum_name #generics { + fn to_tlv(&self, tw: &mut #krate::tlv::TLVWriter, tag_type: #krate::tlv::TagType) -> Result<(), #krate::error::Error> { let anchor = tw.get_tail(); if let Err(err) = (|| { tw.start_struct(tag_type)?; match self { #( - Self::#variant_names(c) => { c.to_tlv(tw, TagType::Context(#tags))?; }, + Self::#variant_names(c) => { c.to_tlv(tw, #krate::tlv::TagType::Context(#tags))?; }, )* } tw.end_container() @@ -297,14 +311,16 @@ fn gen_fromtlv_for_struct( } } + let krate = Ident::new(&get_crate_name(), Span::call_site()); + // Currently we don't use find_tag() because the tags come in sequential // order. If ever the tags start coming out of order, we can use find_tag() // instead let expanded = if !tlvargs.unordered { quote! { - impl #generics FromTLV <#lifetime> for #struct_name #generics { - fn from_tlv(t: &TLVElement<#lifetime>) -> Result { - let mut t_iter = t.#datatype ()?.enter().ok_or(Error::Invalid)?; + impl #generics #krate::tlv::FromTLV <#lifetime> for #struct_name #generics { + fn from_tlv(t: &#krate::tlv::TLVElement<#lifetime>) -> Result { + let mut t_iter = t.#datatype ()?.enter().ok_or_else(|| #krate::error::Error::new(#krate::error::ErrorCode::Invalid))?; let mut item = t_iter.next(); #( let #idents = if Some(true) == item.map(|x| x.check_ctx_tag(#tags)) { @@ -324,8 +340,8 @@ fn gen_fromtlv_for_struct( } } else { quote! { - impl #generics FromTLV <#lifetime> for #struct_name #generics { - fn from_tlv(t: &TLVElement<#lifetime>) -> Result { + impl #generics #krate::tlv::FromTLV <#lifetime> for #struct_name #generics { + fn from_tlv(t: &#krate::tlv::TLVElement<#lifetime>) -> Result { #( let #idents = if let Ok(s) = t.find_tag(#tags as u32) { #types::from_tlv(&s) @@ -375,20 +391,22 @@ fn gen_fromtlv_for_enum( tag_start += 1; } + let krate = Ident::new(&get_crate_name(), Span::call_site()); + let expanded = quote! { - impl #generics FromTLV <#lifetime> for #enum_name #generics { - fn from_tlv(t: &TLVElement<#lifetime>) -> Result { - let mut t_iter = t.confirm_struct()?.enter().ok_or(Error::Invalid)?; - let mut item = t_iter.next().ok_or(Error::Invalid)?; + impl #generics #krate::tlv::FromTLV <#lifetime> for #enum_name #generics { + fn from_tlv(t: &#krate::tlv::TLVElement<#lifetime>) -> Result { + let mut t_iter = t.confirm_struct()?.enter().ok_or_else(|| #krate::error::Error::new(#krate::error::ErrorCode::Invalid))?; + let mut item = t_iter.next().ok_or_else(|| Error::new(#krate::error::ErrorCode::Invalid))?; if let TagType::Context(tag) = item.get_tag() { match tag { #( #tags => Ok(Self::#variant_names(#types::from_tlv(&item)?)), )* - _ => Err(Error::Invalid), + _ => Err(#krate::error::Error::new(#krate::error::ErrorCode::Invalid)), } } else { - Err(Error::TLVTypeMismatch) + Err(#krate::error::Error::new(#krate::error::ErrorCode::TLVTypeMismatch)) } } }