diff --git a/matter/src/data_model/sdm/noc.rs b/matter/src/data_model/sdm/noc.rs index 82bf61b..c15f607 100644 --- a/matter/src/data_model/sdm/noc.rs +++ b/matter/src/data_model/sdm/noc.rs @@ -27,7 +27,7 @@ use crate::fabric::{Fabric, FabricMgr}; use crate::interaction_model::command::CommandReq; use crate::interaction_model::core::IMStatusCode; use crate::interaction_model::messages::ib; -use crate::tlv::{FromTLV, OctetStr, TLVElement, TLVWriter, TagType, ToTLV}; +use crate::tlv::{FromTLV, OctetStr, TLVElement, TLVWriter, TagType, ToTLV, UtfStr}; use crate::transport::session::SessionMode; use crate::utils::writebuf::WriteBuf; use crate::{cmd_enter, error::*}; @@ -75,6 +75,8 @@ pub enum Commands { CSRResp = 0x05, AddNOC = 0x06, NOCResp = 0x08, + UpdateFabricLabel = 0x09, + RemoveFabric = 0x0a, AddTrustedRootCert = 0x0b, } @@ -183,19 +185,55 @@ impl NocCluster { if self.failsafe.record_add_noc(fab_idx).is_err() { error!("Failed to record NoC in the FailSafe, what to do?"); } + NocCluster::create_nocresponse(cmd_req.resp, NocStatus::Ok, fab_idx, "".to_owned()); + cmd_req.trans.complete(); + Ok(()) + } + fn create_nocresponse( + tw: &mut TLVWriter, + status_code: NocStatus, + fab_idx: u8, + debug_txt: String, + ) { let cmd_data = NocResp { - status_code: NocStatus::Ok as u8, + status_code: status_code as u8, fab_idx, - debug_txt: "".to_owned(), + debug_txt, }; - let resp = ib::InvResp::cmd_new( + let invoke_resp = ib::InvResp::cmd_new( 0, ID, Commands::NOCResp as u16, EncodeValue::Value(&cmd_data), ); - let _ = resp.to_tlv(cmd_req.resp, TagType::Anonymous); + let _ = invoke_resp.to_tlv(tw, TagType::Anonymous); + } + + fn handle_command_updatefablabel( + &mut self, + cmd_req: &mut CommandReq, + ) -> Result<(), IMStatusCode> { + cmd_enter!("Update Fabric Label"); + let req = UpdateFabricLabelReq::from_tlv(&cmd_req.data) + .map_err(|_| IMStatusCode::InvalidDataType)?; + let label = req + .label + .to_string() + .map_err(|_| IMStatusCode::InvalidDataType)?; + + let (result, fab_idx) = + if let SessionMode::Case(fab_idx) = cmd_req.trans.session.get_session_mode() { + if self.fabric_mgr.set_label(fab_idx, label).is_err() { + (NocStatus::LabelConflict, fab_idx) + } else { + (NocStatus::Ok, fab_idx) + } + } else { + // Update Fabric Label not allowed + (NocStatus::InvalidFabricIndex, 0) + }; + NocCluster::create_nocresponse(cmd_req.resp, result, fab_idx, "".to_string()); cmd_req.trans.complete(); Ok(()) } @@ -203,18 +241,8 @@ impl NocCluster { fn handle_command_addnoc(&mut self, cmd_req: &mut CommandReq) -> Result<(), IMStatusCode> { cmd_enter!("AddNOC"); if let Err(e) = self._handle_command_addnoc(cmd_req) { - let cmd_data = NocResp { - status_code: e as u8, - fab_idx: 0, - debug_txt: "".to_owned(), - }; - let invoke_resp = ib::InvResp::cmd_new( - 0, - ID, - Commands::NOCResp as u16, - EncodeValue::Value(&cmd_data), - ); - let _ = invoke_resp.to_tlv(cmd_req.resp, TagType::Anonymous); + //TODO: Fab-idx 0? + NocCluster::create_nocresponse(cmd_req.resp, e, 0, "".to_owned()); cmd_req.trans.complete(); } Ok(()) @@ -401,6 +429,7 @@ impl ClusterType for NocCluster { Commands::AddTrustedRootCert => self.handle_command_addtrustedrootcert(cmd_req), Commands::AttReq => self.handle_command_attrequest(cmd_req), Commands::CertChainReq => self.handle_command_certchainrequest(cmd_req), + Commands::UpdateFabricLabel => self.handle_command_updatefablabel(cmd_req), _ => Err(IMStatusCode::UnsupportedCommand), } } @@ -515,6 +544,12 @@ struct CommonReq<'a> { str: OctetStr<'a>, } +#[derive(FromTLV)] +#[tlvargs(lifetime = "'a")] +struct UpdateFabricLabelReq<'a> { + label: UtfStr<'a>, +} + #[derive(FromTLV)] struct CertChainReq { cert_type: u8, diff --git a/matter/src/fabric.rs b/matter/src/fabric.rs index 97e07bd..1e91617 100644 --- a/matter/src/fabric.rs +++ b/matter/src/fabric.rs @@ -18,7 +18,7 @@ use std::sync::{Arc, Mutex, MutexGuard, RwLock}; use byteorder::{BigEndian, ByteOrder, LittleEndian}; -use log::info; +use log::{error, info}; use owning_ref::RwLockReadGuardRef; use crate::{ @@ -45,6 +45,7 @@ const ST_RCA: &str = "rca"; const ST_ICA: &str = "ica"; const ST_NOC: &str = "noc"; const ST_IPK: &str = "ipk"; +const ST_LBL: &str = "label"; const ST_PBKEY: &str = "pubkey"; const ST_PRKEY: &str = "privkey"; @@ -58,6 +59,7 @@ pub struct Fabric { pub icac: Option, pub noc: Cert, pub ipk: KeySet, + label: String, compressed_id: [u8; COMPRESSED_FABRIC_ID_LEN], mdns_service: Option, } @@ -97,6 +99,7 @@ impl Fabric { noc, ipk: KeySet::default(), compressed_id: [0; COMPRESSED_FABRIC_ID_LEN], + label: "".into(), mdns_service: None, }; Fabric::get_compressed_id(f.root_ca.get_pubkey(), fabric_id, &mut f.compressed_id)?; @@ -129,6 +132,7 @@ impl Fabric { icac: Some(Cert::default()), noc: Cert::default(), ipk: KeySet::default(), + label: "".into(), compressed_id: [0; COMPRESSED_FABRIC_ID_LEN], mdns_service: None, }) @@ -186,7 +190,7 @@ impl Fabric { vendor_id: self.vendor_id, fabric_id: self.fabric_id, node_id: self.node_id, - label: UtfStr::new(b""), + label: UtfStr(self.label.as_bytes()), fab_idx: Some(fab_idx), } } @@ -206,6 +210,7 @@ impl Fabric { let len = self.noc.as_tlv(&mut key)?; psm.set_kv_slice(fb_key!(index, ST_NOC), &key[..len])?; psm.set_kv_slice(fb_key!(index, ST_IPK), self.ipk.epoch_key())?; + psm.set_kv_slice(fb_key!(index, ST_LBL), self.label.as_bytes())?; let mut key = [0_u8; crypto::EC_POINT_LEN_BYTES]; let len = self.key_pair.get_public_key(&mut key)?; @@ -217,7 +222,7 @@ impl Fabric { let key = &key[..len]; psm.set_kv_slice(fb_key!(index, ST_PRKEY), key)?; - psm.set_kv_u64(ST_VID, self.vendor_id.into())?; + psm.set_kv_u64(fb_key!(index, ST_VID), self.vendor_id.into())?; Ok(()) } @@ -241,6 +246,13 @@ impl Fabric { let mut ipk = Vec::new(); psm.get_kv_slice(fb_key!(index, ST_IPK), &mut ipk)?; + let mut label = Vec::new(); + psm.get_kv_slice(fb_key!(index, ST_LBL), &mut label)?; + let label = String::from_utf8(label).map_err(|_| { + error!("Couldn't read label"); + Error::Invalid + })?; + let mut pub_key = Vec::new(); psm.get_kv_slice(fb_key!(index, ST_PBKEY), &mut pub_key)?; let mut priv_key = Vec::new(); @@ -248,16 +260,20 @@ impl Fabric { let keypair = KeyPair::new_from_components(pub_key.as_slice(), priv_key.as_slice())?; let mut vendor_id = 0; - psm.get_kv_u64(ST_VID, &mut vendor_id)?; + psm.get_kv_u64(fb_key!(index, ST_VID), &mut vendor_id)?; - Fabric::new( + let f = Fabric::new( keypair, root_ca, icac, noc, ipk.as_slice(), vendor_id as u16, - ) + ); + f.map(|mut f| { + f.label = label; + f + }) } } @@ -361,4 +377,28 @@ impl FabricMgr { } Ok(()) } + + pub fn set_label(&self, index: u8, label: String) -> Result<(), Error> { + let index = index as usize; + let mut mgr = self.inner.write()?; + if label != "" { + for i in 1..MAX_SUPPORTED_FABRICS { + if let Some(fabric) = &mgr.fabrics[i] { + if fabric.label == label { + return Err(Error::Invalid); + } + } + } + } + if let Some(fabric) = &mut mgr.fabrics[index] { + let old = fabric.label.clone(); + fabric.label = label; + let psm = self.psm.lock().unwrap(); + if fabric.store(index, &psm).is_err() { + fabric.label = old; + return Err(Error::StdIoError); + } + } + Ok(()) + } } diff --git a/matter/src/tlv/traits.rs b/matter/src/tlv/traits.rs index d9b7d10..60bf73a 100644 --- a/matter/src/tlv/traits.rs +++ b/matter/src/tlv/traits.rs @@ -119,6 +119,10 @@ impl<'a> UtfStr<'a> { pub fn new(str: &'a [u8]) -> Self { Self(str) } + + pub fn to_string(self) -> Result { + String::from_utf8(self.0.to_vec()).map_err(|_| Error::Invalid) + } } impl<'a> ToTLV for UtfStr<'a> { @@ -127,6 +131,12 @@ impl<'a> ToTLV for UtfStr<'a> { } } +impl<'a> FromTLV<'a> for UtfStr<'a> { + fn from_tlv(t: &TLVElement<'a>) -> Result, Error> { + t.slice().map(UtfStr) + } +} + /// Implements OctetString from the spec #[derive(Debug, Copy, Clone, PartialEq)] pub struct OctetStr<'a>(pub &'a [u8]);