diff --git a/examples/onoff_light/src/main.rs b/examples/onoff_light/src/main.rs index 000bc0e..aadf5f7 100644 --- a/examples/onoff_light/src/main.rs +++ b/examples/onoff_light/src/main.rs @@ -19,17 +19,15 @@ mod dev_att; use matter::core::{self, CommissioningData}; use matter::data_model::cluster_basic_information::BasicInfoConfig; use matter::data_model::device_types::device_type_add_on_off_light; -use rand::prelude::*; +use matter::secure_channel::spake2p::VerifierData; fn main() { env_logger::init(); - let mut comm_data = CommissioningData { + let comm_data = CommissioningData { // TODO: Hard-coded for now - passwd: 123456, + verifier: VerifierData::new_with_pw(123456), discriminator: 250, - ..Default::default() }; - rand::thread_rng().fill_bytes(&mut comm_data.salt); // vid/pid should match those in the DAC let dev_info = BasicInfoConfig { diff --git a/matter/src/core.rs b/matter/src/core.rs index 97143d1..35b8648 100644 --- a/matter/src/core.rs +++ b/matter/src/core.rs @@ -31,14 +31,10 @@ use crate::{ }; use std::sync::Arc; -#[derive(Default)] /// Device Commissioning Data pub struct CommissioningData { - /// The commissioning salt - pub salt: [u8; 16], - /// The password for commissioning the device - // TODO: We should replace this with verifier instead of password - pub passwd: u32, + /// The data like password or verifier that is required to authenticate + pub verifier: VerifierData, /// The 12-bit discriminator used to differentiate between multiple devices pub discriminator: u16, } @@ -74,8 +70,10 @@ impl Matter { let fabric_mgr = Arc::new(FabricMgr::new()?); let acl_mgr = Arc::new(AclMgr::new()?); + let mut pase = PaseMgr::new(); let open_comm_window = fabric_mgr.is_empty(); - let data_model = DataModel::new(dev_det, dev_att, fabric_mgr.clone(), acl_mgr)?; + let data_model = + DataModel::new(dev_det, dev_att, fabric_mgr.clone(), acl_mgr, pase.clone())?; let mut matter = Box::new(Matter { transport_mgr: transport::mgr::Mgr::new()?, data_model, @@ -84,11 +82,12 @@ impl Matter { let interaction_model = Box::new(InteractionModel::new(Box::new(matter.data_model.clone()))); matter.transport_mgr.register_protocol(interaction_model)?; - let mut secure_channel = Box::new(SecureChannel::new(matter.fabric_mgr.clone())); + if open_comm_window { - secure_channel.open_comm_window(&dev_comm.salt, dev_comm.passwd)?; + pase.enable_pase_session(dev_comm.verifier, dev_comm.discriminator)?; } + let secure_channel = Box::new(SecureChannel::new(pase, matter.fabric_mgr.clone())); matter.transport_mgr.register_protocol(secure_channel)?; Ok(matter) } diff --git a/matter/src/data_model/core.rs b/matter/src/data_model/core.rs index f269cdb..11bf060 100644 --- a/matter/src/data_model/core.rs +++ b/matter/src/data_model/core.rs @@ -36,6 +36,7 @@ use crate::{ }, InteractionConsumer, Transaction, }, + secure_channel::pake::PaseMgr, tlv::{TLVArray, TLVWriter, TagType, ToTLV}, transport::session::{Session, SessionMode}, }; @@ -54,6 +55,7 @@ impl DataModel { dev_att: Box, fabric_mgr: Arc, acl_mgr: Arc, + pase_mgr: PaseMgr, ) -> Result { let dm = DataModel { node: Arc::new(RwLock::new(Node::new()?)), @@ -62,7 +64,14 @@ impl DataModel { { let mut node = dm.node.write()?; node.set_changes_cb(Box::new(dm.clone())); - device_type_add_root_node(&mut node, dev_details, dev_att, fabric_mgr, acl_mgr)?; + device_type_add_root_node( + &mut node, + dev_details, + dev_att, + fabric_mgr, + acl_mgr, + pase_mgr, + )?; } Ok(dm) } diff --git a/matter/src/data_model/device_types.rs b/matter/src/data_model/device_types.rs index 47fc022..a0c9c28 100644 --- a/matter/src/data_model/device_types.rs +++ b/matter/src/data_model/device_types.rs @@ -19,6 +19,7 @@ use super::cluster_basic_information::BasicInfoCluster; use super::cluster_basic_information::BasicInfoConfig; use super::cluster_on_off::OnOffCluster; use super::objects::*; +use super::sdm::admin_commissioning::AdminCommCluster; use super::sdm::dev_att::DevAttDataFetcher; use super::sdm::general_commissioning::GenCommCluster; use super::sdm::noc::NocCluster; @@ -27,6 +28,7 @@ use super::system_model::access_control::AccessControlCluster; use crate::acl::AclMgr; use crate::error::*; use crate::fabric::FabricMgr; +use crate::secure_channel::pake::PaseMgr; use std::sync::Arc; use std::sync::RwLockWriteGuard; @@ -38,6 +40,7 @@ pub fn device_type_add_root_node( dev_att: Box, fabric_mgr: Arc, acl_mgr: Arc, + pase_mgr: PaseMgr, ) -> Result { // Add the root endpoint let endpoint = node.add_endpoint()?; @@ -51,6 +54,7 @@ pub fn device_type_add_root_node( let failsafe = general_commissioning.failsafe(); node.add_cluster(0, general_commissioning)?; node.add_cluster(0, NwCommCluster::new()?)?; + node.add_cluster(0, AdminCommCluster::new(pase_mgr)?)?; node.add_cluster( 0, NocCluster::new(dev_att, fabric_mgr, acl_mgr.clone(), failsafe)?, diff --git a/matter/src/data_model/objects/endpoint.rs b/matter/src/data_model/objects/endpoint.rs index 220119b..a87f887 100644 --- a/matter/src/data_model/objects/endpoint.rs +++ b/matter/src/data_model/objects/endpoint.rs @@ -19,7 +19,7 @@ use crate::{data_model::objects::ClusterType, error::*, interaction_model::core: use std::fmt; -pub const CLUSTERS_PER_ENDPT: usize = 7; +pub const CLUSTERS_PER_ENDPT: usize = 9; pub struct Endpoint { clusters: Vec>, diff --git a/matter/src/data_model/sdm/admin_commissioning.rs b/matter/src/data_model/sdm/admin_commissioning.rs new file mode 100644 index 0000000..af8149b --- /dev/null +++ b/matter/src/data_model/sdm/admin_commissioning.rs @@ -0,0 +1,160 @@ +/* + * + * Copyright (c) 2023 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +use crate::cmd_enter; +use crate::data_model::objects::*; +use crate::interaction_model::core::IMStatusCode; +use crate::secure_channel::pake::PaseMgr; +use crate::secure_channel::spake2p::VerifierData; +use crate::tlv::{FromTLV, Nullable, OctetStr, TLVElement}; +use crate::{error::*, interaction_model::command::CommandReq}; +use log::{error, info}; +use num_derive::FromPrimitive; + +pub const ID: u32 = 0x003C; + +#[derive(FromPrimitive, Debug, Copy, Clone, PartialEq)] +pub enum WindowStatus { + WindowNotOpen = 0, + EnhancedWindowOpen = 1, + BasicWindowOpen = 2, +} + +#[derive(FromPrimitive)] +pub enum Attributes { + WindowStatus = 0, + AdminFabricIndex = 1, + AdminVendorId = 2, +} + +#[derive(FromPrimitive)] +pub enum Commands { + OpenCommWindow = 0x00, + OpenBasicCommWindow = 0x01, + RevokeComm = 0x02, +} + +fn attr_window_status_new() -> Result { + Attribute::new( + Attributes::WindowStatus as u16, + AttrValue::Custom, + Access::RV, + Quality::NONE, + ) +} + +fn attr_admin_fabid_new() -> Result { + Attribute::new( + Attributes::AdminFabricIndex as u16, + AttrValue::Custom, + Access::RV, + Quality::NULLABLE, + ) +} + +fn attr_admin_vid_new() -> Result { + Attribute::new( + Attributes::AdminVendorId as u16, + AttrValue::Custom, + Access::RV, + Quality::NULLABLE, + ) +} + +pub struct AdminCommCluster { + pase_mgr: PaseMgr, + base: Cluster, +} + +impl ClusterType for AdminCommCluster { + fn base(&self) -> &Cluster { + &self.base + } + fn base_mut(&mut self) -> &mut Cluster { + &mut self.base + } + + fn read_custom_attribute(&self, encoder: &mut dyn Encoder, attr: &AttrDetails) { + match num::FromPrimitive::from_u16(attr.attr_id) { + Some(Attributes::WindowStatus) => { + let status = 1_u8; + encoder.encode(EncodeValue::Value(&status)) + } + Some(Attributes::AdminVendorId) => { + let vid = Nullable::NotNull(1_u8); + + encoder.encode(EncodeValue::Value(&vid)) + } + Some(Attributes::AdminFabricIndex) => { + let vid = Nullable::NotNull(1_u8); + encoder.encode(EncodeValue::Value(&vid)) + } + _ => { + error!("Unsupported Attribute: this shouldn't happen"); + } + } + } + fn handle_command(&mut self, cmd_req: &mut CommandReq) -> Result<(), IMStatusCode> { + let cmd = cmd_req + .cmd + .path + .leaf + .map(num::FromPrimitive::from_u32) + .ok_or(IMStatusCode::UnsupportedCommand)? + .ok_or(IMStatusCode::UnsupportedCommand)?; + match cmd { + Commands::OpenCommWindow => self.handle_command_opencomm_win(cmd_req), + _ => Err(IMStatusCode::UnsupportedCommand), + } + } +} + +impl AdminCommCluster { + pub fn new(pase_mgr: PaseMgr) -> Result, Error> { + let mut c = Box::new(AdminCommCluster { + pase_mgr, + base: Cluster::new(ID)?, + }); + c.base.add_attribute(attr_window_status_new()?)?; + c.base.add_attribute(attr_admin_fabid_new()?)?; + c.base.add_attribute(attr_admin_vid_new()?)?; + Ok(c) + } + + fn handle_command_opencomm_win( + &mut self, + cmd_req: &mut CommandReq, + ) -> Result<(), IMStatusCode> { + cmd_enter!("Open Commissioning Window"); + let req = + OpenCommWindowReq::from_tlv(&cmd_req.data).map_err(|_| IMStatusCode::InvalidCommand)?; + let verifier = VerifierData::new(req.verifier.0, req.iterations, req.salt.0); + self.pase_mgr + .enable_pase_session(verifier, req.discriminator)?; + Err(IMStatusCode::Sucess) + } +} + +#[derive(FromTLV)] +#[tlvargs(lifetime = "'a")] +pub struct OpenCommWindowReq<'a> { + _timeout: u16, + verifier: OctetStr<'a>, + discriminator: u16, + iterations: u32, + salt: OctetStr<'a>, +} diff --git a/matter/src/data_model/sdm/mod.rs b/matter/src/data_model/sdm/mod.rs index cec166d..1ce25ad 100644 --- a/matter/src/data_model/sdm/mod.rs +++ b/matter/src/data_model/sdm/mod.rs @@ -15,6 +15,7 @@ * limitations under the License. */ +pub mod admin_commissioning; pub mod dev_att; pub mod failsafe; pub mod general_commissioning; diff --git a/matter/src/data_model/sdm/noc.rs b/matter/src/data_model/sdm/noc.rs index 82bf61b..c15f607 100644 --- a/matter/src/data_model/sdm/noc.rs +++ b/matter/src/data_model/sdm/noc.rs @@ -27,7 +27,7 @@ use crate::fabric::{Fabric, FabricMgr}; use crate::interaction_model::command::CommandReq; use crate::interaction_model::core::IMStatusCode; use crate::interaction_model::messages::ib; -use crate::tlv::{FromTLV, OctetStr, TLVElement, TLVWriter, TagType, ToTLV}; +use crate::tlv::{FromTLV, OctetStr, TLVElement, TLVWriter, TagType, ToTLV, UtfStr}; use crate::transport::session::SessionMode; use crate::utils::writebuf::WriteBuf; use crate::{cmd_enter, error::*}; @@ -75,6 +75,8 @@ pub enum Commands { CSRResp = 0x05, AddNOC = 0x06, NOCResp = 0x08, + UpdateFabricLabel = 0x09, + RemoveFabric = 0x0a, AddTrustedRootCert = 0x0b, } @@ -183,19 +185,55 @@ impl NocCluster { if self.failsafe.record_add_noc(fab_idx).is_err() { error!("Failed to record NoC in the FailSafe, what to do?"); } + NocCluster::create_nocresponse(cmd_req.resp, NocStatus::Ok, fab_idx, "".to_owned()); + cmd_req.trans.complete(); + Ok(()) + } + fn create_nocresponse( + tw: &mut TLVWriter, + status_code: NocStatus, + fab_idx: u8, + debug_txt: String, + ) { let cmd_data = NocResp { - status_code: NocStatus::Ok as u8, + status_code: status_code as u8, fab_idx, - debug_txt: "".to_owned(), + debug_txt, }; - let resp = ib::InvResp::cmd_new( + let invoke_resp = ib::InvResp::cmd_new( 0, ID, Commands::NOCResp as u16, EncodeValue::Value(&cmd_data), ); - let _ = resp.to_tlv(cmd_req.resp, TagType::Anonymous); + let _ = invoke_resp.to_tlv(tw, TagType::Anonymous); + } + + fn handle_command_updatefablabel( + &mut self, + cmd_req: &mut CommandReq, + ) -> Result<(), IMStatusCode> { + cmd_enter!("Update Fabric Label"); + let req = UpdateFabricLabelReq::from_tlv(&cmd_req.data) + .map_err(|_| IMStatusCode::InvalidDataType)?; + let label = req + .label + .to_string() + .map_err(|_| IMStatusCode::InvalidDataType)?; + + let (result, fab_idx) = + if let SessionMode::Case(fab_idx) = cmd_req.trans.session.get_session_mode() { + if self.fabric_mgr.set_label(fab_idx, label).is_err() { + (NocStatus::LabelConflict, fab_idx) + } else { + (NocStatus::Ok, fab_idx) + } + } else { + // Update Fabric Label not allowed + (NocStatus::InvalidFabricIndex, 0) + }; + NocCluster::create_nocresponse(cmd_req.resp, result, fab_idx, "".to_string()); cmd_req.trans.complete(); Ok(()) } @@ -203,18 +241,8 @@ impl NocCluster { fn handle_command_addnoc(&mut self, cmd_req: &mut CommandReq) -> Result<(), IMStatusCode> { cmd_enter!("AddNOC"); if let Err(e) = self._handle_command_addnoc(cmd_req) { - let cmd_data = NocResp { - status_code: e as u8, - fab_idx: 0, - debug_txt: "".to_owned(), - }; - let invoke_resp = ib::InvResp::cmd_new( - 0, - ID, - Commands::NOCResp as u16, - EncodeValue::Value(&cmd_data), - ); - let _ = invoke_resp.to_tlv(cmd_req.resp, TagType::Anonymous); + //TODO: Fab-idx 0? + NocCluster::create_nocresponse(cmd_req.resp, e, 0, "".to_owned()); cmd_req.trans.complete(); } Ok(()) @@ -401,6 +429,7 @@ impl ClusterType for NocCluster { Commands::AddTrustedRootCert => self.handle_command_addtrustedrootcert(cmd_req), Commands::AttReq => self.handle_command_attrequest(cmd_req), Commands::CertChainReq => self.handle_command_certchainrequest(cmd_req), + Commands::UpdateFabricLabel => self.handle_command_updatefablabel(cmd_req), _ => Err(IMStatusCode::UnsupportedCommand), } } @@ -515,6 +544,12 @@ struct CommonReq<'a> { str: OctetStr<'a>, } +#[derive(FromTLV)] +#[tlvargs(lifetime = "'a")] +struct UpdateFabricLabelReq<'a> { + label: UtfStr<'a>, +} + #[derive(FromTLV)] struct CertChainReq { cert_type: u8, diff --git a/matter/src/fabric.rs b/matter/src/fabric.rs index 97e07bd..1e91617 100644 --- a/matter/src/fabric.rs +++ b/matter/src/fabric.rs @@ -18,7 +18,7 @@ use std::sync::{Arc, Mutex, MutexGuard, RwLock}; use byteorder::{BigEndian, ByteOrder, LittleEndian}; -use log::info; +use log::{error, info}; use owning_ref::RwLockReadGuardRef; use crate::{ @@ -45,6 +45,7 @@ const ST_RCA: &str = "rca"; const ST_ICA: &str = "ica"; const ST_NOC: &str = "noc"; const ST_IPK: &str = "ipk"; +const ST_LBL: &str = "label"; const ST_PBKEY: &str = "pubkey"; const ST_PRKEY: &str = "privkey"; @@ -58,6 +59,7 @@ pub struct Fabric { pub icac: Option, pub noc: Cert, pub ipk: KeySet, + label: String, compressed_id: [u8; COMPRESSED_FABRIC_ID_LEN], mdns_service: Option, } @@ -97,6 +99,7 @@ impl Fabric { noc, ipk: KeySet::default(), compressed_id: [0; COMPRESSED_FABRIC_ID_LEN], + label: "".into(), mdns_service: None, }; Fabric::get_compressed_id(f.root_ca.get_pubkey(), fabric_id, &mut f.compressed_id)?; @@ -129,6 +132,7 @@ impl Fabric { icac: Some(Cert::default()), noc: Cert::default(), ipk: KeySet::default(), + label: "".into(), compressed_id: [0; COMPRESSED_FABRIC_ID_LEN], mdns_service: None, }) @@ -186,7 +190,7 @@ impl Fabric { vendor_id: self.vendor_id, fabric_id: self.fabric_id, node_id: self.node_id, - label: UtfStr::new(b""), + label: UtfStr(self.label.as_bytes()), fab_idx: Some(fab_idx), } } @@ -206,6 +210,7 @@ impl Fabric { let len = self.noc.as_tlv(&mut key)?; psm.set_kv_slice(fb_key!(index, ST_NOC), &key[..len])?; psm.set_kv_slice(fb_key!(index, ST_IPK), self.ipk.epoch_key())?; + psm.set_kv_slice(fb_key!(index, ST_LBL), self.label.as_bytes())?; let mut key = [0_u8; crypto::EC_POINT_LEN_BYTES]; let len = self.key_pair.get_public_key(&mut key)?; @@ -217,7 +222,7 @@ impl Fabric { let key = &key[..len]; psm.set_kv_slice(fb_key!(index, ST_PRKEY), key)?; - psm.set_kv_u64(ST_VID, self.vendor_id.into())?; + psm.set_kv_u64(fb_key!(index, ST_VID), self.vendor_id.into())?; Ok(()) } @@ -241,6 +246,13 @@ impl Fabric { let mut ipk = Vec::new(); psm.get_kv_slice(fb_key!(index, ST_IPK), &mut ipk)?; + let mut label = Vec::new(); + psm.get_kv_slice(fb_key!(index, ST_LBL), &mut label)?; + let label = String::from_utf8(label).map_err(|_| { + error!("Couldn't read label"); + Error::Invalid + })?; + let mut pub_key = Vec::new(); psm.get_kv_slice(fb_key!(index, ST_PBKEY), &mut pub_key)?; let mut priv_key = Vec::new(); @@ -248,16 +260,20 @@ impl Fabric { let keypair = KeyPair::new_from_components(pub_key.as_slice(), priv_key.as_slice())?; let mut vendor_id = 0; - psm.get_kv_u64(ST_VID, &mut vendor_id)?; + psm.get_kv_u64(fb_key!(index, ST_VID), &mut vendor_id)?; - Fabric::new( + let f = Fabric::new( keypair, root_ca, icac, noc, ipk.as_slice(), vendor_id as u16, - ) + ); + f.map(|mut f| { + f.label = label; + f + }) } } @@ -361,4 +377,28 @@ impl FabricMgr { } Ok(()) } + + pub fn set_label(&self, index: u8, label: String) -> Result<(), Error> { + let index = index as usize; + let mut mgr = self.inner.write()?; + if label != "" { + for i in 1..MAX_SUPPORTED_FABRICS { + if let Some(fabric) = &mgr.fabrics[i] { + if fabric.label == label { + return Err(Error::Invalid); + } + } + } + } + if let Some(fabric) = &mut mgr.fabrics[index] { + let old = fabric.label.clone(); + fabric.label = label; + let psm = self.psm.lock().unwrap(); + if fabric.store(index, &psm).is_err() { + fabric.label = old; + return Err(Error::StdIoError); + } + } + Ok(()) + } } diff --git a/matter/src/lib.rs b/matter/src/lib.rs index d96d9d7..cd7aafc 100644 --- a/matter/src/lib.rs +++ b/matter/src/lib.rs @@ -27,7 +27,7 @@ //! use matter::{Matter, CommissioningData}; //! use matter::data_model::device_types::device_type_add_on_off_light; //! use matter::data_model::cluster_basic_information::BasicInfoConfig; -//! use rand::prelude::*; +//! use matter::secure_channel::spake2p::VerifierData; //! //! # use matter::data_model::sdm::dev_att::{DataType, DevAttDataFetcher}; //! # use matter::error::Error; @@ -38,12 +38,11 @@ //! # let dev_att = Box::new(DevAtt{}); //! //! /// The commissioning data for this device -//! let mut comm_data = CommissioningData { -//! passwd: 123456, +//! let comm_data = CommissioningData { +//! verifier: VerifierData::new_with_pw(123456), //! discriminator: 250, -//! ..Default::default() +//! //! }; -//! rand::thread_rng().fill_bytes(&mut comm_data.salt); //! //! /// The basic information about this device //! let dev_info = BasicInfoConfig { diff --git a/matter/src/mdns.rs b/matter/src/mdns.rs index 194dcc2..6f66db5 100644 --- a/matter/src/mdns.rs +++ b/matter/src/mdns.rs @@ -47,8 +47,10 @@ static mut G_MDNS: Option> = None; static INIT: Once = Once::new(); pub enum ServiceMode { + /// The commissioned state Commissioned, - Commissionable, + /// The commissionable state with the discriminator that should be used + Commissionable(u16), } impl Mdns { diff --git a/matter/src/secure_channel/case.rs b/matter/src/secure_channel/case.rs index 75d1fc9..33c8e47 100644 --- a/matter/src/secure_channel/case.rs +++ b/matter/src/secure_channel/case.rs @@ -26,12 +26,12 @@ use crate::{ crypto::{self, CryptoKeyPair, KeyPair, Sha256}, error::Error, fabric::{Fabric, FabricMgr, FabricMgrInner}, - secure_channel::common, secure_channel::common::SCStatusCodes, + secure_channel::common::{self, OpCode}, tlv::{get_root_node_struct, FromTLV, OctetStr, TLVElement, TLVWriter, TagType}, transport::{ network::Address, - proto_demux::ProtoCtx, + proto_demux::{ProtoCtx, ResponseRequired}, queue::{Msg, WorkQ}, session::{CloneData, SessionMode}, }, @@ -78,7 +78,7 @@ impl Case { Self { fabric_mgr } } - pub fn handle_casesigma3(&mut self, ctx: &mut ProtoCtx) -> Result<(), Error> { + pub fn casesigma3_handler(&mut self, ctx: &mut ProtoCtx) -> Result { let mut case_session = ctx .exch_ctx .exch @@ -97,7 +97,7 @@ impl Case { None, )?; ctx.exch_ctx.exch.close(); - return Ok(()); + return Ok(ResponseRequired::Yes); } // Safe to unwrap here let fabric = fabric.as_ref().as_ref().unwrap(); @@ -132,7 +132,7 @@ impl Case { None, )?; ctx.exch_ctx.exch.close(); - return Ok(()); + return Ok(ResponseRequired::Yes); } if Case::validate_sigma3_sign( @@ -151,7 +151,7 @@ impl Case { None, )?; ctx.exch_ctx.exch.close(); - return Ok(()); + return Ok(ResponseRequired::Yes); } // Only now do we add this message to the TT Hash @@ -174,10 +174,12 @@ impl Case { ctx.exch_ctx.exch.clear_data_boxed(); ctx.exch_ctx.exch.close(); - Ok(()) + Ok(ResponseRequired::Yes) } - pub fn handle_casesigma1(&mut self, ctx: &mut ProtoCtx) -> Result<(), Error> { + pub fn casesigma1_handler(&mut self, ctx: &mut ProtoCtx) -> Result { + ctx.tx.set_proto_opcode(OpCode::CASESigma2 as u8); + let rx_buf = ctx.rx.as_borrow_slice(); let root = get_root_node_struct(rx_buf)?; let r = Sigma1Req::from_tlv(&root)?; @@ -193,7 +195,7 @@ impl Case { None, )?; ctx.exch_ctx.exch.close(); - return Ok(()); + return Ok(ResponseRequired::Yes); } let local_sessid = ctx.exch_ctx.sess.reserve_new_sess_id(); @@ -239,7 +241,7 @@ impl Case { None, )?; ctx.exch_ctx.exch.close(); - return Ok(()); + return Ok(ResponseRequired::Yes); } let sign_len = Case::get_sigma2_sign( @@ -270,7 +272,7 @@ impl Case { tw.end_container()?; case_session.tt_hash.update(ctx.tx.as_borrow_slice())?; ctx.exch_ctx.exch.set_data_boxed(case_session); - Ok(()) + Ok(ResponseRequired::Yes) } fn get_session_clone_data( diff --git a/matter/src/secure_channel/core.rs b/matter/src/secure_channel/core.rs index 679b1f5..9f7d16b 100644 --- a/matter/src/secure_channel/core.rs +++ b/matter/src/secure_channel/core.rs @@ -20,100 +20,30 @@ use std::sync::Arc; use crate::{ error::*, fabric::FabricMgr, - mdns::{self, Mdns}, - secure_channel::{common::*, pake::PAKE}, - sys::SysMdnsService, + secure_channel::common::*, tlv, transport::proto_demux::{self, ProtoCtx, ResponseRequired}, }; use log::{error, info}; use num; -use rand::prelude::*; -use super::case::Case; +use super::{case::Case, pake::PaseMgr}; /* Handle messages related to the Secure Channel */ pub struct SecureChannel { case: Case, - pake: Option<(PAKE, SysMdnsService)>, + pase: PaseMgr, } impl SecureChannel { - pub fn new(fabric_mgr: Arc) -> SecureChannel { + pub fn new(pase: PaseMgr, fabric_mgr: Arc) -> SecureChannel { SecureChannel { - pake: None, + pase, case: Case::new(fabric_mgr), } } - - pub fn open_comm_window(&mut self, salt: &[u8; 16], passwd: u32) -> Result<(), Error> { - let name: u64 = rand::thread_rng().gen_range(0..0xFFFFFFFFFFFFFFFF); - let name = format!("{:016X}", name); - let mdns = Mdns::get()?.publish_service(&name, mdns::ServiceMode::Commissionable)?; - self.pake = Some((PAKE::new(salt, passwd), mdns)); - Ok(()) - } - - pub fn close_comm_window(&mut self) { - self.pake = None; - } - - fn mrpstandaloneack_handler(&mut self, _ctx: &mut ProtoCtx) -> Result { - info!("In MRP StandAlone ACK Handler"); - Ok(ResponseRequired::No) - } - - fn pbkdfparamreq_handler(&mut self, ctx: &mut ProtoCtx) -> Result { - info!("In PBKDF Param Request Handler"); - ctx.tx.set_proto_opcode(OpCode::PBKDFParamResponse as u8); - if let Some((pake, _)) = &mut self.pake { - pake.handle_pbkdfparamrequest(ctx)?; - } else { - error!("PASE Not enabled"); - create_sc_status_report(&mut ctx.tx, SCStatusCodes::InvalidParameter, None)?; - } - Ok(ResponseRequired::Yes) - } - - fn pasepake1_handler(&mut self, ctx: &mut ProtoCtx) -> Result { - info!("In PASE Pake1 Handler"); - ctx.tx.set_proto_opcode(OpCode::PASEPake2 as u8); - if let Some((pake, _)) = &mut self.pake { - pake.handle_pasepake1(ctx)?; - } else { - error!("PASE Not enabled"); - create_sc_status_report(&mut ctx.tx, SCStatusCodes::InvalidParameter, None)?; - } - Ok(ResponseRequired::Yes) - } - - fn pasepake3_handler(&mut self, ctx: &mut ProtoCtx) -> Result { - info!("In PASE Pake3 Handler"); - if let Some((pake, _)) = &mut self.pake { - pake.handle_pasepake3(ctx)?; - // TODO: Currently we assume that PAKE is not successful and reset the PAKE object - self.pake = None; - } else { - error!("PASE Not enabled"); - create_sc_status_report(&mut ctx.tx, SCStatusCodes::InvalidParameter, None)?; - } - Ok(ResponseRequired::Yes) - } - - fn casesigma1_handler(&mut self, ctx: &mut ProtoCtx) -> Result { - info!("In CASE Sigma1 Handler"); - ctx.tx.set_proto_opcode(OpCode::CASESigma2 as u8); - self.case.handle_casesigma1(ctx)?; - Ok(ResponseRequired::Yes) - } - - fn casesigma3_handler(&mut self, ctx: &mut ProtoCtx) -> Result { - info!("In CASE Sigma3 Handler"); - self.case.handle_casesigma3(ctx)?; - Ok(ResponseRequired::Yes) - } } impl proto_demux::HandleProto for SecureChannel { @@ -121,15 +51,16 @@ impl proto_demux::HandleProto for SecureChannel { let proto_opcode: OpCode = num::FromPrimitive::from_u8(ctx.rx.get_proto_opcode()).ok_or(Error::Invalid)?; ctx.tx.set_proto_id(PROTO_ID_SECURE_CHANNEL as u16); - info!("Received Data"); + info!("Received Opcode: {:?}", proto_opcode); + info!("Received Data:"); tlv::print_tlv_list(ctx.rx.as_borrow_slice()); let result = match proto_opcode { - OpCode::MRPStandAloneAck => self.mrpstandaloneack_handler(ctx), - OpCode::PBKDFParamRequest => self.pbkdfparamreq_handler(ctx), - OpCode::PASEPake1 => self.pasepake1_handler(ctx), - OpCode::PASEPake3 => self.pasepake3_handler(ctx), - OpCode::CASESigma1 => self.casesigma1_handler(ctx), - OpCode::CASESigma3 => self.casesigma3_handler(ctx), + OpCode::MRPStandAloneAck => Ok(ResponseRequired::No), + OpCode::PBKDFParamRequest => self.pase.pbkdfparamreq_handler(ctx), + OpCode::PASEPake1 => self.pase.pasepake1_handler(ctx), + OpCode::PASEPake3 => self.pase.pasepake3_handler(ctx), + OpCode::CASESigma1 => self.case.casesigma1_handler(ctx), + OpCode::CASESigma3 => self.case.casesigma3_handler(ctx), _ => { error!("OpCode Not Handled: {:?}", proto_opcode); Err(Error::InvalidOpcode) diff --git a/matter/src/secure_channel/crypto.rs b/matter/src/secure_channel/crypto.rs index 68dac8d..f9481ba 100644 --- a/matter/src/secure_channel/crypto.rs +++ b/matter/src/secure_channel/crypto.rs @@ -38,7 +38,9 @@ pub trait CryptoSpake2 { fn set_w1(&mut self, w1: &[u8]) -> Result<(), Error>; #[allow(non_snake_case)] - fn set_L(&mut self, w1s: &[u8]) -> Result<(), Error>; + fn set_L(&mut self, l: &[u8]) -> Result<(), Error>; + #[allow(non_snake_case)] + fn set_L_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error>; #[allow(non_snake_case)] fn get_pB(&mut self, pB: &mut [u8]) -> Result<(), Error>; #[allow(non_snake_case)] diff --git a/matter/src/secure_channel/crypto_mbedtls.rs b/matter/src/secure_channel/crypto_mbedtls.rs index 57eeff2..7231f63 100644 --- a/matter/src/secure_channel/crypto_mbedtls.rs +++ b/matter/src/secure_channel/crypto_mbedtls.rs @@ -114,9 +114,14 @@ impl CryptoSpake2 for CryptoMbedTLS { Ok(()) } + fn set_L(&mut self, l: &[u8]) -> Result<(), Error> { + self.L = EcPoint::from_binary(&mut self.group, l)?; + Ok(()) + } + #[allow(non_snake_case)] #[allow(dead_code)] - fn set_L(&mut self, w1s: &[u8]) -> Result<(), Error> { + fn set_L_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error> { // From the Matter spec, // L = w1 * P // where P is the generator of the underlying elliptic curve diff --git a/matter/src/secure_channel/crypto_openssl.rs b/matter/src/secure_channel/crypto_openssl.rs index 0f80f4c..84d6793 100644 --- a/matter/src/secure_channel/crypto_openssl.rs +++ b/matter/src/secure_channel/crypto_openssl.rs @@ -117,9 +117,14 @@ impl CryptoSpake2 for CryptoOpenSSL { Ok(()) } + fn set_L(&mut self, l: &[u8]) -> Result<(), Error> { + self.L = EcPoint::from_bytes(&self.group, l, &mut self.bn_ctx)?; + Ok(()) + } + #[allow(non_snake_case)] #[allow(dead_code)] - fn set_L(&mut self, w1s: &[u8]) -> Result<(), Error> { + fn set_L_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error> { // From the Matter spec, // L = w1 * P // where P is the generator of the underlying elliptic curve diff --git a/matter/src/secure_channel/pake.rs b/matter/src/secure_channel/pake.rs index bb6503d..b096cf6 100644 --- a/matter/src/secure_channel/pake.rs +++ b/matter/src/secure_channel/pake.rs @@ -15,21 +15,26 @@ * limitations under the License. */ -use std::time::{Duration, SystemTime}; +use std::{ + sync::{Arc, Mutex}, + time::{Duration, SystemTime}, +}; use super::{ common::{create_sc_status_report, SCStatusCodes}, - spake2p::Spake2P, + spake2p::{Spake2P, VerifierData}, }; use crate::{ crypto, error::Error, - sys::SPAKE2_ITERATION_COUNT, + mdns::{self, Mdns}, + secure_channel::common::OpCode, + sys::SysMdnsService, tlv::{self, get_root_node_struct, FromTLV, OctetStr, TLVElement, TLVWriter, TagType, ToTLV}, transport::{ exchange::ExchangeCtx, network::Address, - proto_demux::ProtoCtx, + proto_demux::{ProtoCtx, ResponseRequired}, queue::{Msg, WorkQ}, session::{CloneData, SessionMode}, }, @@ -37,6 +42,79 @@ use crate::{ use log::{error, info}; use rand::prelude::*; +enum PaseMgrState { + Enabled(PAKE, SysMdnsService), + Disabled, +} + +pub struct PaseMgrInternal { + state: PaseMgrState, +} + +#[derive(Clone)] +// Could this lock be avoided? +pub struct PaseMgr(Arc>); + +impl PaseMgr { + pub fn new() -> Self { + Self(Arc::new(Mutex::new(PaseMgrInternal { + state: PaseMgrState::Disabled, + }))) + } + + pub fn enable_pase_session( + &mut self, + verifier: VerifierData, + discriminator: u16, + ) -> Result<(), Error> { + let mut s = self.0.lock().unwrap(); + let name: u64 = rand::thread_rng().gen_range(0..0xFFFFFFFFFFFFFFFF); + let name = format!("{:016X}", name); + let mdns = Mdns::get()? + .publish_service(&name, mdns::ServiceMode::Commissionable(discriminator))?; + s.state = PaseMgrState::Enabled(PAKE::new(verifier), mdns); + Ok(()) + } + + pub fn disable_pase_session(&mut self) { + let mut s = self.0.lock().unwrap(); + s.state = PaseMgrState::Disabled; + } + + /// If the PASE Session is enabled, execute the closure, + /// if not enabled, generate SC Status Report + fn if_enabled(&mut self, ctx: &mut ProtoCtx, f: F) -> Result<(), Error> + where + F: FnOnce(&mut PAKE, &mut ProtoCtx) -> Result<(), Error>, + { + let mut s = self.0.lock().unwrap(); + if let PaseMgrState::Enabled(pake, _) = &mut s.state { + f(pake, ctx) + } else { + error!("PASE Not enabled"); + create_sc_status_report(&mut ctx.tx, SCStatusCodes::InvalidParameter, None) + } + } + + pub fn pbkdfparamreq_handler(&mut self, ctx: &mut ProtoCtx) -> Result { + ctx.tx.set_proto_opcode(OpCode::PBKDFParamResponse as u8); + self.if_enabled(ctx, |pake, ctx| pake.handle_pbkdfparamrequest(ctx))?; + Ok(ResponseRequired::Yes) + } + + pub fn pasepake1_handler(&mut self, ctx: &mut ProtoCtx) -> Result { + ctx.tx.set_proto_opcode(OpCode::PASEPake2 as u8); + self.if_enabled(ctx, |pake, ctx| pake.handle_pasepake1(ctx))?; + Ok(ResponseRequired::Yes) + } + + pub fn pasepake3_handler(&mut self, ctx: &mut ProtoCtx) -> Result { + self.if_enabled(ctx, |pake, ctx| pake.handle_pasepake3(ctx))?; + self.disable_pase_session(); + Ok(ResponseRequired::Yes) + } +} + // This file basically deals with the handlers for the PASE secure channel protocol // TLV extraction and encoding is done in this file. // We create a Spake2p object and set it up in the exchange-data. This object then @@ -111,20 +189,17 @@ impl Default for PakeState { } } -#[derive(Default)] pub struct PAKE { - salt: [u8; 16], - passwd: u32, + pub verifier: VerifierData, state: PakeState, } impl PAKE { - pub fn new(salt: &[u8; 16], passwd: u32) -> Self { + pub fn new(verifier: VerifierData) -> Self { // TODO: Can any PBKDF2 calculation be pre-computed here PAKE { - passwd, - salt: *salt, - ..Default::default() + verifier, + state: Default::default(), } } @@ -176,8 +251,7 @@ impl PAKE { let pA = extract_pasepake_1_or_3_params(ctx.rx.as_borrow_slice())?; let mut pB: [u8; 65] = [0; 65]; let mut cB: [u8; 32] = [0; 32]; - sd.spake2p - .start_verifier(self.passwd, SPAKE2_ITERATION_COUNT, &self.salt)?; + sd.spake2p.start_verifier(&self.verifier)?; sd.spake2p.handle_pA(pA, &mut pB, &mut cB)?; let mut tw = TLVWriter::new(ctx.tx.get_writebuf()?); @@ -231,8 +305,8 @@ impl PAKE { }; if !a.has_params { let params_resp = PBKDFParamRespParams { - count: SPAKE2_ITERATION_COUNT, - salt: OctetStr(&self.salt), + count: self.verifier.count, + salt: OctetStr(&self.verifier.salt), }; resp.params = Some(params_resp); } diff --git a/matter/src/secure_channel/spake2p.rs b/matter/src/secure_channel/spake2p.rs index 3870bcd..ad0cac5 100644 --- a/matter/src/secure_channel/spake2p.rs +++ b/matter/src/secure_channel/spake2p.rs @@ -15,8 +15,13 @@ * limitations under the License. */ -use crate::crypto::{self, HmacSha256}; +use crate::{ + crypto::{self, HmacSha256}, + sys, +}; use byteorder::{ByteOrder, LittleEndian}; +use log::error; +use rand::prelude::*; use subtle::ConstantTimeEq; use crate::{ @@ -74,6 +79,10 @@ const SPAKE2P_KEY_CONFIRM_INFO: [u8; 16] = *b"ConfirmationKeys"; const SPAKE2P_CONTEXT_PREFIX: [u8; 26] = *b"CHIP PAKE V1 Commissioning"; const CRYPTO_GROUP_SIZE_BYTES: usize = 32; const CRYPTO_W_SIZE_BYTES: usize = CRYPTO_GROUP_SIZE_BYTES + 8; +const CRYPTO_PUBLIC_KEY_SIZE_BYTES: usize = (2 * CRYPTO_GROUP_SIZE_BYTES) + 1; + +const MAX_SALT_SIZE_BYTES: usize = 32; +const VERIFIER_SIZE_BYTES: usize = CRYPTO_GROUP_SIZE_BYTES + CRYPTO_PUBLIC_KEY_SIZE_BYTES; #[cfg(feature = "crypto_openssl")] fn crypto_spake2_new() -> Result, Error> { @@ -96,6 +105,50 @@ impl Default for Spake2P { } } +pub struct VerifierData { + pub data: VerifierOption, + // For the VerifierOption::Verifier, the following fields only serve + // information purposes + pub salt: [u8; MAX_SALT_SIZE_BYTES], + pub count: u32, +} + +pub enum VerifierOption { + /// With Password + Password(u32), + /// With Verifier + Verifier([u8; VERIFIER_SIZE_BYTES]), +} + +impl VerifierData { + pub fn new_with_pw(pw: u32) -> Self { + let mut s = Self { + salt: [0; MAX_SALT_SIZE_BYTES], + count: sys::SPAKE2_ITERATION_COUNT, + data: VerifierOption::Password(pw), + }; + rand::thread_rng().fill_bytes(&mut s.salt); + s + } + + pub fn new(verifier: &[u8], count: u32, salt: &[u8]) -> Self { + let mut v = [0_u8; VERIFIER_SIZE_BYTES]; + let mut s = [0_u8; MAX_SALT_SIZE_BYTES]; + + let slice = &mut v[..verifier.len()]; + slice.copy_from_slice(verifier); + + let slice = &mut s[..salt.len()]; + slice.copy_from_slice(salt); + + Self { + data: VerifierOption::Verifier(v), + count, + salt: s, + } + } +} + impl Spake2P { pub fn new() -> Self { Spake2P { @@ -132,17 +185,31 @@ impl Spake2P { let _ = pbkdf2_hmac(&pw_str, iter as usize, salt, w0w1s); } - pub fn start_verifier(&mut self, pw: u32, iter: u32, salt: &[u8]) -> Result<(), Error> { - let mut w0w1s: [u8; (2 * CRYPTO_W_SIZE_BYTES)] = [0; (2 * CRYPTO_W_SIZE_BYTES)]; - Spake2P::get_w0w1s(pw, iter, salt, &mut w0w1s); + pub fn start_verifier(&mut self, verifier: &VerifierData) -> Result<(), Error> { self.crypto_spake2 = Some(crypto_spake2_new()?); + match verifier.data { + VerifierOption::Password(pw) => { + // Derive w0 and L from the password + let mut w0w1s: [u8; (2 * CRYPTO_W_SIZE_BYTES)] = [0; (2 * CRYPTO_W_SIZE_BYTES)]; + Spake2P::get_w0w1s(pw, verifier.count, &verifier.salt, &mut w0w1s); - let w0s_len = w0w1s.len() / 2; - if let Some(crypto_spake2) = &mut self.crypto_spake2 { - crypto_spake2.set_w0_from_w0s(&w0w1s[0..w0s_len])?; - crypto_spake2.set_L(&w0w1s[w0s_len..])?; + let w0s_len = w0w1s.len() / 2; + if let Some(crypto_spake2) = &mut self.crypto_spake2 { + crypto_spake2.set_w0_from_w0s(&w0w1s[0..w0s_len])?; + crypto_spake2.set_L_from_w1s(&w0w1s[w0s_len..])?; + } + } + VerifierOption::Verifier(v) => { + // Extract w0 and L from the verifier + if v.len() != CRYPTO_GROUP_SIZE_BYTES + CRYPTO_PUBLIC_KEY_SIZE_BYTES { + error!("Verifier of invalid length"); + } + if let Some(crypto_spake2) = &mut self.crypto_spake2 { + crypto_spake2.set_w0(&v[0..CRYPTO_GROUP_SIZE_BYTES])?; + crypto_spake2.set_L(&v[CRYPTO_GROUP_SIZE_BYTES..])?; + } + } } - self.mode = Spake2Mode::Verifier(Spake2VerifierState::Init); Ok(()) } @@ -164,6 +231,7 @@ impl Spake2P { Spake2P::get_Ke_and_cAcB(&TT, pA, pB, &mut self.Ke, &mut self.cA, cB)?; } } + // We are finished with using the crypto_spake2 now self.crypto_spake2 = None; self.mode = Spake2Mode::Verifier(Spake2VerifierState::PendingConfirmation); diff --git a/matter/src/tlv/traits.rs b/matter/src/tlv/traits.rs index d9b7d10..60bf73a 100644 --- a/matter/src/tlv/traits.rs +++ b/matter/src/tlv/traits.rs @@ -119,6 +119,10 @@ impl<'a> UtfStr<'a> { pub fn new(str: &'a [u8]) -> Self { Self(str) } + + pub fn to_string(self) -> Result { + String::from_utf8(self.0.to_vec()).map_err(|_| Error::Invalid) + } } impl<'a> ToTLV for UtfStr<'a> { @@ -127,6 +131,12 @@ impl<'a> ToTLV for UtfStr<'a> { } } +impl<'a> FromTLV<'a> for UtfStr<'a> { + fn from_tlv(t: &TLVElement<'a>) -> Result, Error> { + t.slice().map(UtfStr) + } +} + /// Implements OctetString from the spec #[derive(Debug, Copy, Clone, PartialEq)] pub struct OctetStr<'a>(pub &'a [u8]); diff --git a/matter/tests/common/im_engine.rs b/matter/tests/common/im_engine.rs index 19cb6af..923d499 100644 --- a/matter/tests/common/im_engine.rs +++ b/matter/tests/common/im_engine.rs @@ -29,6 +29,7 @@ use matter::{ error::Error, fabric::FabricMgr, interaction_model::{core::OpCode, InteractionModel}, + secure_channel::pake::PaseMgr, tlv::{TLVWriter, TagType, ToTLV}, transport::packet::Packet, transport::proto_demux::HandleProto, @@ -99,12 +100,20 @@ impl ImEngine { let dev_att = Box::new(DummyDevAtt {}); let fabric_mgr = Arc::new(FabricMgr::new().unwrap()); let acl_mgr = Arc::new(AclMgr::new_with(false).unwrap()); + let pase_mgr = PaseMgr::new(); acl_mgr.erase_all(); let mut default_acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); // Only allow the standard peer node id of the IM Engine default_acl.add_subject(IM_ENGINE_PEER_ID).unwrap(); acl_mgr.add(default_acl).unwrap(); - let dm = DataModel::new(&dev_det, dev_att, fabric_mgr, acl_mgr.clone()).unwrap(); + let dm = DataModel::new( + dev_det, + dev_att, + fabric_mgr.clone(), + acl_mgr.clone(), + pase_mgr, + ) + .unwrap(); { let mut d = dm.node.write().unwrap();