From bcbac965cddac60fa895afad651093e659f207f4 Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Mon, 24 Apr 2023 06:10:58 +0000 Subject: [PATCH] Remove allocations from Cert handling --- matter/src/cert/asn1_writer.rs | 2 +- matter/src/cert/mod.rs | 148 ++++++++++++++++-------------- matter/src/data_model/sdm/noc.rs | 36 +++++--- matter/src/error.rs | 9 +- matter/src/fabric.rs | 119 +++++++++++++----------- matter/src/secure_channel/case.rs | 16 ++-- matter/src/tlv/traits.rs | 117 ++++++----------------- 7 files changed, 205 insertions(+), 242 deletions(-) diff --git a/matter/src/cert/asn1_writer.rs b/matter/src/cert/asn1_writer.rs index ae2ced8..675546a 100644 --- a/matter/src/cert/asn1_writer.rs +++ b/matter/src/cert/asn1_writer.rs @@ -17,7 +17,7 @@ use super::{CertConsumer, MAX_DEPTH}; use crate::error::Error; -use chrono::{Datelike, TimeZone, Utc}; +use chrono::{Datelike, TimeZone, Utc}; // TODO use core::fmt::Write; use log::warn; diff --git a/matter/src/cert/mod.rs b/matter/src/cert/mod.rs index b928329..757a9d6 100644 --- a/matter/src/cert/mod.rs +++ b/matter/src/cert/mod.rs @@ -15,23 +15,22 @@ * limitations under the License. */ -use core::fmt; - -extern crate alloc; +use core::fmt::{self, Write}; use crate::{ crypto::KeyPair, error::Error, - tlv::{self, FromTLV, TLVArrayOwned, TLVElement, TLVWriter, TagType, ToTLV}, + tlv::{self, FromTLV, OctetStr, TLVArray, TLVElement, TLVWriter, TagType, ToTLV}, utils::writebuf::WriteBuf, }; -use alloc::{format, string::String, vec::Vec}; use log::error; use num_derive::FromPrimitive; pub use self::asn1_writer::ASN1Writer; use self::printer::CertPrinter; +pub const MAX_CERT_TLV_LEN: usize = 300; // TODO + // As per https://datatracker.ietf.org/doc/html/rfc5280 const OID_PUB_KEY_ECPUBKEY: [u8; 7] = [0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x02, 0x01]; @@ -116,8 +115,10 @@ macro_rules! add_if { }; } -fn get_print_str(key_usage: u16) -> String { - format!( +fn get_print_str(key_usage: u16) -> heapless::String<256> { + let mut string = heapless::String::new(); + write!( + &mut string, "{}{}{}{}{}{}{}{}{}", add_if!(key_usage, KEY_USAGE_DIGITAL_SIGN, "digitalSignature "), add_if!(key_usage, KEY_USAGE_NON_REPUDIATION, "nonRepudiation "), @@ -129,6 +130,9 @@ fn get_print_str(key_usage: u16) -> String { add_if!(key_usage, KEY_USAGE_ENCIPHER_ONLY, "encipherOnly "), add_if!(key_usage, KEY_USAGE_DECIPHER_ONLY, "decipherOnly "), ) + .unwrap(); + + string } #[allow(unused_assignments)] @@ -140,7 +144,7 @@ fn encode_key_usage(key_usage: u16, w: &mut dyn CertConsumer) -> Result<(), Erro } fn encode_extended_key_usage( - list: &TLVArrayOwned, + list: impl Iterator, w: &mut dyn CertConsumer, ) -> Result<(), Error> { const OID_SERVER_AUTH: [u8; 8] = [0x2B, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03, 0x01]; @@ -160,19 +164,18 @@ fn encode_extended_key_usage( ]; w.start_seq("")?; - for t in list.iter() { - let t = *t as usize; + for t in list { + let t = t as usize; if t > 0 && t <= encoding.len() { w.oid(encoding[t].0, encoding[t].1)?; } else { error!("Skipping encoding key usage out of bounds"); } } - w.end_seq()?; - Ok(()) + w.end_seq() } -#[derive(FromTLV, ToTLV, Default)] +#[derive(FromTLV, ToTLV, Default, Debug)] #[tlvargs(start = 1)] struct BasicConstraints { is_ca: bool, @@ -212,18 +215,18 @@ fn encode_extension_end(w: &mut dyn CertConsumer) -> Result<(), Error> { w.end_seq() } -#[derive(FromTLV, ToTLV, Default)] -#[tlvargs(start = 1, datatype = "list")] -struct Extensions { +#[derive(FromTLV, ToTLV, Default, Debug)] +#[tlvargs(lifetime = "'a", start = 1, datatype = "list")] +struct Extensions<'a> { basic_const: Option, key_usage: Option, - ext_key_usage: Option>, - subj_key_id: Option>, - auth_key_id: Option>, - future_extensions: Option>, + ext_key_usage: Option>, + subj_key_id: Option>, + auth_key_id: Option>, + future_extensions: Option>, } -impl Extensions { +impl<'a> Extensions<'a> { fn encode(&self, w: &mut dyn CertConsumer) -> Result<(), Error> { const OID_BASIC_CONSTRAINTS: [u8; 3] = [0x55, 0x1D, 0x13]; const OID_KEY_USAGE: [u8; 3] = [0x55, 0x1D, 0x0F]; @@ -245,30 +248,29 @@ impl Extensions { } if let Some(t) = &self.ext_key_usage { encode_extension_start("X509v3 Extended Key Usage", true, &OID_EXT_KEY_USAGE, w)?; - encode_extended_key_usage(t, w)?; + encode_extended_key_usage(t.iter(), w)?; encode_extension_end(w)?; } if let Some(t) = &self.subj_key_id { encode_extension_start("Subject Key ID", false, &OID_SUBJ_KEY_IDENTIFIER, w)?; - w.ostr("", t.as_slice())?; + w.ostr("", t.0)?; encode_extension_end(w)?; } if let Some(t) = &self.auth_key_id { encode_extension_start("Auth Key ID", false, &OID_AUTH_KEY_ID, w)?; w.start_seq("")?; - w.ctx("", 0, t.as_slice())?; + w.ctx("", 0, t.0)?; w.end_seq()?; encode_extension_end(w)?; } if let Some(t) = &self.future_extensions { - error!("Future Extensions Not Yet Supported: {:x?}", t.as_slice()) + error!("Future Extensions Not Yet Supported: {:x?}", t.0); } w.end_seq()?; w.end_ctx()?; Ok(()) } } -const MAX_DN_ENTRIES: usize = 5; #[derive(FromPrimitive, Copy, Clone)] enum DnTags { @@ -296,20 +298,23 @@ enum DnTags { NocCat = 22, } -enum DistNameValue { +#[derive(Debug)] +enum DistNameValue<'a> { Uint(u64), - Utf8Str(Vec), - PrintableStr(Vec), + Utf8Str(&'a [u8]), + PrintableStr(&'a [u8]), } -#[derive(Default)] -struct DistNames { +const MAX_DN_ENTRIES: usize = 5; + +#[derive(Default, Debug)] +struct DistNames<'a> { // The order in which the DNs arrive is important, as the signing // requires that the ASN1 notation retains the same order - dn: Vec<(u8, DistNameValue)>, + dn: heapless::Vec<(u8, DistNameValue<'a>), MAX_DN_ENTRIES>, } -impl DistNames { +impl<'a> DistNames<'a> { fn u64(&self, match_id: DnTags) -> Option { self.dn .iter() @@ -339,24 +344,27 @@ impl DistNames { const PRINTABLE_STR_THRESHOLD: u8 = 0x80; -impl<'a> FromTLV<'a> for DistNames { +impl<'a> FromTLV<'a> for DistNames<'a> { fn from_tlv(t: &TLVElement<'a>) -> Result { let mut d = Self { - dn: Vec::with_capacity(MAX_DN_ENTRIES), + dn: heapless::Vec::new(), }; let iter = t.confirm_list()?.enter().ok_or(Error::Invalid)?; for t in iter { if let TagType::Context(tag) = t.get_tag() { if let Ok(value) = t.u64() { - d.dn.push((tag, DistNameValue::Uint(value))); + d.dn.push((tag, DistNameValue::Uint(value))) + .map_err(|_| Error::BufferTooSmall)?; } else if let Ok(value) = t.slice() { if tag > PRINTABLE_STR_THRESHOLD { d.dn.push(( tag - PRINTABLE_STR_THRESHOLD, - DistNameValue::PrintableStr(value.to_vec()), - )); + DistNameValue::PrintableStr(value), + )) + .map_err(|_| Error::BufferTooSmall)?; } else { - d.dn.push((tag, DistNameValue::Utf8Str(value.to_vec()))); + d.dn.push((tag, DistNameValue::Utf8Str(value))) + .map_err(|_| Error::BufferTooSmall)?; } } } @@ -365,24 +373,23 @@ impl<'a> FromTLV<'a> for DistNames { } } -impl ToTLV for DistNames { +impl<'a> ToTLV for DistNames<'a> { fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error> { tw.start_list(tag)?; for (name, value) in &self.dn { match value { DistNameValue::Uint(v) => tw.u64(TagType::Context(*name), *v)?, - DistNameValue::Utf8Str(v) => tw.utf8(TagType::Context(*name), v.as_slice())?, - DistNameValue::PrintableStr(v) => tw.utf8( - TagType::Context(*name + PRINTABLE_STR_THRESHOLD), - v.as_slice(), - )?, + DistNameValue::Utf8Str(v) => tw.utf8(TagType::Context(*name), v)?, + DistNameValue::PrintableStr(v) => { + tw.utf8(TagType::Context(*name + PRINTABLE_STR_THRESHOLD), v)? + } } } tw.end_container() } } -impl DistNames { +impl<'a> DistNames<'a> { fn encode(&self, tag: &str, w: &mut dyn CertConsumer) -> Result<(), Error> { const OID_COMMON_NAME: [u8; 3] = [0x55_u8, 0x04, 0x03]; const OID_SURNAME: [u8; 3] = [0x55_u8, 0x04, 0x04]; @@ -520,38 +527,36 @@ fn encode_dn_value( } }, DistNameValue::Utf8Str(v) => { - let str = String::from_utf8(v.to_vec())?; - w.utf8str("", &str)?; + w.utf8str("", core::str::from_utf8(v)?)?; } DistNameValue::PrintableStr(v) => { - let str = String::from_utf8(v.to_vec())?; - w.printstr("", &str)?; + w.printstr("", core::str::from_utf8(v)?)?; } } w.end_seq()?; w.end_set() } -#[derive(FromTLV, ToTLV, Default)] -#[tlvargs(start = 1)] -pub struct Cert { - serial_no: Vec, +#[derive(FromTLV, ToTLV, Default, Debug)] +#[tlvargs(lifetime = "'a", start = 1)] +pub struct Cert<'a> { + serial_no: OctetStr<'a>, sign_algo: u8, - issuer: DistNames, + issuer: DistNames<'a>, not_before: u32, not_after: u32, - subject: DistNames, + subject: DistNames<'a>, pubkey_algo: u8, ec_curve_id: u8, - pubkey: Vec, - extensions: Extensions, - signature: Vec, + pubkey: OctetStr<'a>, + extensions: Extensions<'a>, + signature: OctetStr<'a>, } // TODO: Instead of parsing the TLVs everytime, we should just cache this, but the encoding // rules in terms of sequence may get complicated. Need to look into this -impl Cert { - pub fn new(cert_bin: &[u8]) -> Result { +impl<'a> Cert<'a> { + pub fn new(cert_bin: &'a [u8]) -> Result { let root = tlv::get_root_node(cert_bin)?; Cert::from_tlv(&root) } @@ -569,17 +574,21 @@ impl Cert { } pub fn get_pubkey(&self) -> &[u8] { - self.pubkey.as_slice() + self.pubkey.0 } pub fn get_subject_key_id(&self) -> Result<&[u8], Error> { - self.extensions.subj_key_id.as_deref().ok_or(Error::Invalid) + if let Some(id) = self.extensions.subj_key_id.as_ref() { + Ok(id.0) + } else { + Err(Error::Invalid) + } } pub fn is_authority(&self, their: &Cert) -> Result { if let Some(our_auth_key) = &self.extensions.auth_key_id { let their_subject = their.get_subject_key_id()?; - if our_auth_key == their_subject { + if our_auth_key.0 == their_subject { Ok(true) } else { Ok(false) @@ -590,7 +599,7 @@ impl Cert { } pub fn get_signature(&self) -> &[u8] { - self.signature.as_slice() + self.signature.0 } pub fn as_tlv(&self, buf: &mut [u8]) -> Result { @@ -617,7 +626,7 @@ impl Cert { w.integer("", &[2])?; w.end_ctx()?; - w.integer("Serial Num:", self.serial_no.as_slice())?; + w.integer("Serial Num:", self.serial_no.0)?; w.start_seq("Signature Algorithm:")?; let (str, oid) = match get_sign_algo(self.sign_algo).ok_or(Error::Invalid)? { @@ -647,7 +656,7 @@ impl Cert { w.oid(str, &curve_id)?; w.end_seq()?; - w.bitstr("Public-Key:", false, self.pubkey.as_slice())?; + w.bitstr("Public-Key:", false, self.pubkey.0)?; w.end_seq()?; self.extensions.encode(w)?; @@ -658,7 +667,7 @@ impl Cert { } } -impl fmt::Display for Cert { +impl<'a> fmt::Display for Cert<'a> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let mut printer = CertPrinter::new(f); let _ = self @@ -670,7 +679,7 @@ impl fmt::Display for Cert { } pub struct CertVerifier<'a> { - cert: &'a Cert, + cert: &'a Cert<'a>, } impl<'a> CertVerifier<'a> { @@ -809,6 +818,7 @@ mod tests { #[test] fn test_tlv_conversions() { + let _ = env_logger::try_init(); let test_input: [&[u8]; 3] = [ &test_vectors::NOC1_SUCCESS, &test_vectors::ICAC1_SUCCESS, diff --git a/matter/src/data_model/sdm/noc.rs b/matter/src/data_model/sdm/noc.rs index acaea50..b2dcee2 100644 --- a/matter/src/data_model/sdm/noc.rs +++ b/matter/src/data_model/sdm/noc.rs @@ -19,7 +19,7 @@ use core::cell::RefCell; use core::convert::TryInto; use crate::acl::{AclEntry, AclMgr, AuthMode}; -use crate::cert::Cert; +use crate::cert::{Cert, MAX_CERT_TLV_LEN}; use crate::crypto::{self, KeyPair}; use crate::data_model::objects::*; use crate::data_model::sdm::dev_att; @@ -158,14 +158,14 @@ pub const CLUSTER: Cluster<'static> = Cluster { pub struct NocData { pub key_pair: KeyPair, - pub root_ca: Cert, + pub root_ca: heapless::Vec, } impl NocData { pub fn new(key_pair: KeyPair) -> Self { Self { key_pair, - root_ca: Cert::default(), + root_ca: heapless::Vec::new(), } } } @@ -259,8 +259,10 @@ impl<'a> NocCluster<'a> { writer.start_array(AttrDataWriter::TAG)?; self.fabric_mgr.borrow().for_each(|entry, fab_idx| { if !attr.fab_filter || attr.fab_idx == fab_idx { + let root_ca_cert = entry.get_root_ca()?; + entry - .get_fabric_desc(fab_idx) + .get_fabric_desc(fab_idx, &root_ca_cert)? .to_tlv(&mut writer, TagType::Anonymous)?; } @@ -351,12 +353,18 @@ impl<'a> NocCluster<'a> { let r = AddNocReq::from_tlv(data).map_err(|_| NocStatus::InvalidNOC)?; - let noc_value = Cert::new(r.noc_value.0).map_err(|_| NocStatus::InvalidNOC)?; - info!("Received NOC as: {}", noc_value); - let icac_value = if !r.icac_value.0.is_empty() { - let cert = Cert::new(r.icac_value.0).map_err(|_| NocStatus::InvalidNOC)?; - info!("Received ICAC as: {}", cert); - Some(cert) + let noc_cert = Cert::new(r.noc_value.0).map_err(|_| NocStatus::InvalidNOC)?; + info!("Received NOC as: {}", noc_cert); + + let noc = heapless::Vec::from_slice(r.noc_value.0).map_err(|_| NocStatus::InvalidNOC)?; + + let icac = if !r.icac_value.0.is_empty() { + let icac_cert = Cert::new(r.icac_value.0).map_err(|_| NocStatus::InvalidNOC)?; + info!("Received ICAC as: {}", icac_cert); + + let icac = + heapless::Vec::from_slice(r.icac_value.0).map_err(|_| NocStatus::InvalidNOC)?; + Some(icac) } else { None }; @@ -364,8 +372,8 @@ impl<'a> NocCluster<'a> { let fabric = Fabric::new( noc_data.key_pair, noc_data.root_ca, - icac_value, - noc_value, + icac, + noc, r.ipk_value.0, r.vendor_id, "", @@ -592,7 +600,9 @@ impl<'a> NocCluster<'a> { let req = CommonReq::from_tlv(data).map_err(Error::map_invalid_command)?; info!("Received Trusted Cert:{:x?}", req.str); - noc_data.root_ca = Cert::new(req.str.0)?; + noc_data.root_ca = + heapless::Vec::from_slice(req.str.0).map_err(|_| Error::BufferTooSmall)?; + // TODO } _ => (), } diff --git a/matter/src/error.rs b/matter/src/error.rs index 3a54b2c..22a04e4 100644 --- a/matter/src/error.rs +++ b/matter/src/error.rs @@ -15,14 +15,11 @@ * limitations under the License. */ -use alloc::string::FromUtf8Error; -use core::{array::TryFromSliceError, fmt}; +use core::{array::TryFromSliceError, fmt, str::Utf8Error}; use async_channel::{SendError, TryRecvError}; use log::error; -extern crate alloc; - #[derive(Debug, PartialEq, Clone, Copy)] pub enum Error { AttributeNotFound, @@ -166,8 +163,8 @@ impl From> for Error { } } -impl From for Error { - fn from(_e: FromUtf8Error) -> Self { +impl From for Error { + fn from(_e: Utf8Error) -> Self { Self::Utf8Fail } } diff --git a/matter/src/fabric.rs b/matter/src/fabric.rs index b7e2425..6c9d389 100644 --- a/matter/src/fabric.rs +++ b/matter/src/fabric.rs @@ -21,7 +21,7 @@ use byteorder::{BigEndian, ByteOrder, LittleEndian}; use log::{error, info}; use crate::{ - cert::Cert, + cert::{Cert, MAX_CERT_TLV_LEN}, crypto::{self, hkdf_sha256, HmacSha256, KeyPair}, error::Error, group_keys::KeySet, @@ -30,7 +30,6 @@ use crate::{ tlv::{OctetStr, TLVWriter, TagType, ToTLV, UtfStr}, }; -const MAX_CERT_TLV_LEN: usize = 300; const COMPRESSED_FABRIC_ID_LEN: usize = 8; macro_rules! fb_key { @@ -72,9 +71,9 @@ pub struct Fabric { fabric_id: u64, vendor_id: u16, key_pair: KeyPair, - pub root_ca: Cert, - pub icac: Option, - pub noc: Cert, + pub root_ca: heapless::Vec, + pub icac: Option>, + pub noc: heapless::Vec, pub ipk: KeySet, label: heapless::String<32>, mdns_service_name: heapless::String<33>, @@ -83,20 +82,25 @@ pub struct Fabric { impl Fabric { pub fn new( key_pair: KeyPair, - root_ca: Cert, - icac: Option, - noc: Cert, + root_ca: heapless::Vec, + icac: Option>, + noc: heapless::Vec, ipk: &[u8], vendor_id: u16, label: &str, ) -> Result { - let node_id = noc.get_node_id()?; - let fabric_id = noc.get_fabric_id()?; + let (node_id, fabric_id) = { + let noc_p = Cert::new(&noc)?; + (noc_p.get_node_id()?, noc_p.get_fabric_id()?) + }; let mut compressed_id = [0_u8; COMPRESSED_FABRIC_ID_LEN]; - Fabric::get_compressed_id(root_ca.get_pubkey(), fabric_id, &mut compressed_id)?; - let ipk = KeySet::new(ipk, &compressed_id)?; + let ipk = { + let root_ca_p = Cert::new(&root_ca)?; + Fabric::get_compressed_id(root_ca_p.get_pubkey(), fabric_id, &mut compressed_id)?; + KeySet::new(ipk, &compressed_id)? + }; let mut mdns_service_name = heapless::String::<33>::new(); for c in compressed_id { @@ -144,7 +148,7 @@ impl Fabric { let mut mac = HmacSha256::new(self.ipk.op_key())?; mac.update(random)?; - mac.update(self.root_ca.get_pubkey())?; + mac.update(self.get_root_ca()?.get_pubkey())?; let mut buf: [u8; 8] = [0; 8]; LittleEndian::write_u64(&mut buf, self.fabric_id); @@ -174,15 +178,25 @@ impl Fabric { self.fabric_id } - pub fn get_fabric_desc(&self, fab_idx: u8) -> FabricDescriptor { - FabricDescriptor { - root_public_key: OctetStr::new(self.root_ca.get_pubkey()), + pub fn get_root_ca(&self) -> Result, Error> { + Cert::new(&self.root_ca) + } + + pub fn get_fabric_desc<'a>( + &'a self, + fab_idx: u8, + root_ca_cert: &'a Cert, + ) -> Result, Error> { + let desc = FabricDescriptor { + root_public_key: OctetStr::new(root_ca_cert.get_pubkey()), vendor_id: self.vendor_id, fabric_id: self.fabric_id, node_id: self.node_id, label: UtfStr(self.label.as_bytes()), fab_idx: Some(fab_idx), - } + }; + + Ok(desc) } fn store(&self, index: usize, mut psm: T) -> Result<(), Error> @@ -191,19 +205,13 @@ impl Fabric { { let mut _kb = heapless::String::<32>::new(); - let mut buf = [0u8; MAX_CERT_TLV_LEN]; - let len = self.root_ca.as_tlv(&mut buf)?; - psm.set_kv_slice(fb_key!(index, ST_RCA, _kb), &buf[..len])?; + psm.set_kv_slice(fb_key!(index, ST_RCA, _kb), &self.root_ca)?; + psm.set_kv_slice( + fb_key!(index, ST_ICA, _kb), + self.icac.as_deref().unwrap_or(&[]), + )?; - let len = if let Some(icac) = &self.icac { - icac.as_tlv(&mut buf)? - } else { - 0 - }; - psm.set_kv_slice(fb_key!(index, ST_ICA, _kb), &buf[..len])?; - - let len = self.noc.as_tlv(&mut buf)?; - psm.set_kv_slice(fb_key!(index, ST_NOC, _kb), &buf[..len])?; + psm.set_kv_slice(fb_key!(index, ST_NOC, _kb), &self.noc)?; psm.set_kv_slice(fb_key!(index, ST_IPK, _kb), self.ipk.epoch_key())?; psm.set_kv_slice(fb_key!(index, ST_LBL, _kb), self.label.as_bytes())?; @@ -228,18 +236,21 @@ impl Fabric { let mut _kb = heapless::String::<32>::new(); let mut buf = [0u8; MAX_CERT_TLV_LEN]; - let root_ca = psm.get_kv_slice(fb_key!(index, ST_RCA, _kb), &mut buf)?; - let root_ca = Cert::new(root_ca)?; + + let root_ca = + heapless::Vec::from_slice(psm.get_kv_slice(fb_key!(index, ST_RCA, _kb), &mut buf)?) + .unwrap(); let icac = psm.get_kv_slice(fb_key!(index, ST_ICA, _kb), &mut buf)?; let icac = if !icac.is_empty() { - Some(Cert::new(icac)?) + Some(heapless::Vec::from_slice(icac).unwrap()) } else { None }; - let noc = psm.get_kv_slice(fb_key!(index, ST_NOC, _kb), &mut buf)?; - let noc = Cert::new(noc)?; + let noc = + heapless::Vec::from_slice(psm.get_kv_slice(fb_key!(index, ST_NOC, _kb), &mut buf)?) + .unwrap(); let label = psm.get_kv_slice(fb_key!(index, ST_LBL, _kb), &mut buf)?; let label: heapless::String<32> = core::str::from_utf8(label) @@ -293,21 +304,16 @@ impl Fabric { { let mut _kb = heapless::String::<32>::new(); - let mut buf = [0u8; MAX_CERT_TLV_LEN]; - let len = self.root_ca.as_tlv(&mut buf)?; - psm.set_kv_slice(fb_key!(index, ST_RCA, _kb), &buf[..len]) + psm.set_kv_slice(fb_key!(index, ST_RCA, _kb), &self.root_ca) .await?; - let len = if let Some(icac) = &self.icac { - icac.as_tlv(&mut buf)? - } else { - 0 - }; - psm.set_kv_slice(fb_key!(index, ST_ICA, _kb), &buf[..len]) - .await?; + psm.set_kv_slice( + fb_key!(index, ST_ICA, _kb), + self.icac.as_deref().unwrap_or(&[]), + ) + .await?; - let len = self.noc.as_tlv(&mut buf)?; - psm.set_kv_slice(fb_key!(index, ST_NOC, _kb), &buf[..len]) + psm.set_kv_slice(fb_key!(index, ST_NOC, _kb), &self.noc) .await?; psm.set_kv_slice(fb_key!(index, ST_IPK, _kb), self.ipk.epoch_key()) .await?; @@ -337,24 +343,27 @@ impl Fabric { let mut _kb = heapless::String::<32>::new(); let mut buf = [0u8; MAX_CERT_TLV_LEN]; - let root_ca = psm - .get_kv_slice(fb_key!(index, ST_RCA, _kb), &mut buf) - .await?; - let root_ca = Cert::new(root_ca)?; + + let root_ca = heapless::Vec::from_slice( + psm.get_kv_slice(fb_key!(index, ST_RCA, _kb), &mut buf) + .await?, + ) + .unwrap(); let icac = psm .get_kv_slice(fb_key!(index, ST_ICA, _kb), &mut buf) .await?; let icac = if !icac.is_empty() { - Some(Cert::new(icac)?) + Some(heapless::Vec::from_slice(icac).unwrap()) } else { None }; - let noc = psm - .get_kv_slice(fb_key!(index, ST_NOC, _kb), &mut buf) - .await?; - let noc = Cert::new(noc)?; + let noc = heapless::Vec::from_slice( + psm.get_kv_slice(fb_key!(index, ST_NOC, _kb), &mut buf) + .await?, + ) + .unwrap(); let label = psm .get_kv_slice(fb_key!(index, ST_LBL, _kb), &mut buf) diff --git a/matter/src/secure_channel/case.rs b/matter/src/secure_channel/case.rs index f5b9cb0..a722dae 100644 --- a/matter/src/secure_channel/case.rs +++ b/matter/src/secure_channel/case.rs @@ -349,7 +349,9 @@ impl<'a> Case<'a> { verifier = verifier.add_cert(icac)?; } - verifier.add_cert(&fabric.root_ca)?.finalise()?; + verifier + .add_cert(&Cert::new(&fabric.root_ca)?)? + .finalise()?; Ok(()) } @@ -481,9 +483,9 @@ impl<'a> Case<'a> { let mut write_buf = WriteBuf::new(out); let mut tw = TLVWriter::new(&mut write_buf); tw.start_struct(TagType::Anonymous)?; - tw.str16_as(TagType::Context(1), |buf| fabric.noc.as_tlv(buf))?; - if let Some(icac_cert) = &fabric.icac { - tw.str16_as(TagType::Context(2), |buf| icac_cert.as_tlv(buf))? + tw.str16(TagType::Context(1), &fabric.noc)?; + if let Some(icac_cert) = fabric.icac.as_ref() { + tw.str16(TagType::Context(2), icac_cert)? }; tw.str8(TagType::Context(3), signature)?; @@ -523,9 +525,9 @@ impl<'a> Case<'a> { let mut write_buf = WriteBuf::new(&mut buf); let mut tw = TLVWriter::new(&mut write_buf); tw.start_struct(TagType::Anonymous)?; - tw.str16_as(TagType::Context(1), |buf| fabric.noc.as_tlv(buf))?; - if let Some(icac_cert) = &fabric.icac { - tw.str16_as(TagType::Context(2), |buf| icac_cert.as_tlv(buf))?; + tw.str16(TagType::Context(1), &fabric.noc)?; + if let Some(icac_cert) = fabric.icac.as_deref() { + tw.str16(TagType::Context(2), icac_cert)?; } tw.str8(TagType::Context(3), our_pub_key)?; tw.str8(TagType::Context(4), peer_pub_key)?; diff --git a/matter/src/tlv/traits.rs b/matter/src/tlv/traits.rs index c7b5e35..72cfab2 100644 --- a/matter/src/tlv/traits.rs +++ b/matter/src/tlv/traits.rs @@ -17,14 +17,10 @@ use super::{ElementType, TLVContainerIterator, TLVElement, TLVWriter, TagType}; use crate::error::Error; -use alloc::borrow::ToOwned; -use alloc::{string::String, vec::Vec}; use core::fmt::Debug; use core::slice::Iter; use log::error; -extern crate alloc; - pub trait FromTLV<'a> { fn from_tlv(t: &TLVElement<'a>) -> Result where @@ -118,14 +114,11 @@ totlv_for!(i8 u8 u16 u32 u64 bool); // // - UtfStr, OctetStr: These are versions that map to utfstr and ostr in the TLV spec // - These only have references into the original list -// - String, Vec: Is the owned version of utfstr and ostr, data is cloned into this -// - String is only partially implemented // // - TLVArray: Is an array of entries, with reference within the original list -// - TLVArrayOwned: Is the owned version of this, data is cloned into this /// Implements UTFString from the spec -#[derive(Debug, Copy, Clone, PartialEq)] +#[derive(Debug, Copy, Clone, PartialEq, Default)] pub struct UtfStr<'a>(pub &'a [u8]); impl<'a> UtfStr<'a> { @@ -136,10 +129,6 @@ impl<'a> UtfStr<'a> { pub fn as_str(&self) -> Result<&str, Error> { core::str::from_utf8(self.0).map_err(|_| Error::Invalid) } - - pub fn to_string(self) -> Result { - String::from_utf8(self.0.to_vec()).map_err(|_| Error::Invalid) - } } impl<'a> ToTLV for UtfStr<'a> { @@ -155,7 +144,7 @@ impl<'a> FromTLV<'a> for UtfStr<'a> { } /// Implements OctetString from the spec -#[derive(Debug, Copy, Clone, PartialEq)] +#[derive(Debug, Copy, Clone, PartialEq, Default)] pub struct OctetStr<'a>(pub &'a [u8]); impl<'a> OctetStr<'a> { @@ -176,41 +165,6 @@ impl<'a> ToTLV for OctetStr<'a> { } } -/// Implements the Owned version of Octet String -impl FromTLV<'_> for Vec { - fn from_tlv(t: &TLVElement) -> Result, Error> { - t.slice().map(|x| x.to_owned()) - } -} - -impl ToTLV for Vec { - fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error> { - tw.str16(tag, self.as_slice()) - } -} - -/// Implements the Owned version of UTF String -impl FromTLV<'_> for String { - fn from_tlv(t: &TLVElement) -> Result { - match t.slice() { - Ok(x) => { - if let Ok(s) = String::from_utf8(x.to_vec()) { - Ok(s) - } else { - Err(Error::Invalid) - } - } - Err(e) => Err(e), - } - } -} - -impl ToTLV for String { - fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error> { - tw.utf16(tag, self.as_bytes()) - } -} - /// Applies to all the Option<> Processing impl<'a, T: FromTLV<'a>> FromTLV<'a> for Option { fn from_tlv(t: &TLVElement<'a>) -> Result, Error> { @@ -279,37 +233,6 @@ impl ToTLV for Nullable { } } -/// Owned version of a TLVArray -pub struct TLVArrayOwned(Vec); -impl<'a, T: FromTLV<'a>> FromTLV<'a> for TLVArrayOwned { - fn from_tlv(t: &TLVElement<'a>) -> Result { - t.confirm_array()?; - let mut vec = Vec::::new(); - if let Some(tlv_iter) = t.enter() { - for element in tlv_iter { - vec.push(T::from_tlv(&element)?); - } - } - Ok(Self(vec)) - } -} - -impl ToTLV for TLVArrayOwned { - fn to_tlv(&self, tw: &mut TLVWriter, tag_type: TagType) -> Result<(), Error> { - tw.start_array(tag_type)?; - for t in &self.0 { - t.to_tlv(tw, TagType::Anonymous)?; - } - tw.end_container() - } -} - -impl TLVArrayOwned { - pub fn iter(&self) -> Iter { - self.0.iter() - } -} - #[derive(Copy, Clone)] pub enum TLVArray<'a, T> { // This is used for the to-tlv path @@ -390,18 +313,23 @@ where } } -impl<'a, T: ToTLV> ToTLV for TLVArray<'a, T> { +impl<'a, T: FromTLV<'a> + Copy + ToTLV> ToTLV for TLVArray<'a, T> { fn to_tlv(&self, tw: &mut TLVWriter, tag_type: TagType) -> Result<(), Error> { - match *self { - Self::Slice(s) => { - tw.start_array(tag_type)?; - for a in s { - a.to_tlv(tw, TagType::Anonymous)?; - } - tw.end_container() - } - Self::Ptr(t) => t.to_tlv(tw, tag_type), + tw.start_array(tag_type)?; + for a in self.iter() { + a.to_tlv(tw, TagType::Anonymous)?; } + tw.end_container() + // match *self { + // Self::Slice(s) => { + // tw.start_array(tag_type)?; + // for a in s { + // a.to_tlv(tw, TagType::Anonymous)?; + // } + // tw.end_container() + // } + // Self::Ptr(t) => t.to_tlv(tw, tag_type), <-- TODO: this fails the unit tests of Cert from/to TLV + // } } } @@ -414,10 +342,17 @@ impl<'a, T> FromTLV<'a> for TLVArray<'a, T> { impl<'a, T: Debug + ToTLV + FromTLV<'a> + Copy> Debug for TLVArray<'a, T> { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!(f, "TLVArray [")?; + let mut first = true; for i in self.iter() { - writeln!(f, "{:?}", i)?; + if !first { + write!(f, ", ")?; + } + + write!(f, "{:?}", i)?; + first = false; } - writeln!(f) + write!(f, "]") } }