diff --git a/matter/src/core.rs b/matter/src/core.rs index 71eb6c3..39a903f 100644 --- a/matter/src/core.rs +++ b/matter/src/core.rs @@ -62,8 +62,10 @@ impl Matter { let fabric_mgr = Arc::new(FabricMgr::new()?); let acl_mgr = Arc::new(AclMgr::new()?); + let mut pase = PaseMgr::new(); let open_comm_window = fabric_mgr.is_empty(); - let data_model = DataModel::new(dev_det, dev_att, fabric_mgr.clone(), acl_mgr)?; + let data_model = + DataModel::new(dev_det, dev_att, fabric_mgr.clone(), acl_mgr, pase.clone())?; let mut matter = Box::new(Matter { transport_mgr: transport::mgr::Mgr::new()?, data_model, @@ -73,12 +75,11 @@ impl Matter { Box::new(InteractionModel::new(Box::new(matter.data_model.clone()))); matter.transport_mgr.register_protocol(interaction_model)?; - let mut pase = PaseMgr::new(); if open_comm_window { pase.enable_pase_session(dev_comm.verifier, dev_comm.discriminator)?; } - let secure_channel = Box::new(SecureChannel::new(pase.clone(), matter.fabric_mgr.clone())); + let secure_channel = Box::new(SecureChannel::new(pase, matter.fabric_mgr.clone())); matter.transport_mgr.register_protocol(secure_channel)?; Ok(matter) } diff --git a/matter/src/data_model/core.rs b/matter/src/data_model/core.rs index 4b0db23..19516fb 100644 --- a/matter/src/data_model/core.rs +++ b/matter/src/data_model/core.rs @@ -36,6 +36,7 @@ use crate::{ }, InteractionConsumer, Transaction, }, + secure_channel::pake::PaseMgr, tlv::{TLVArray, TLVWriter, TagType, ToTLV}, transport::session::{Session, SessionMode}, }; @@ -54,6 +55,7 @@ impl DataModel { dev_att: Box, fabric_mgr: Arc, acl_mgr: Arc, + pase_mgr: PaseMgr, ) -> Result { let dm = DataModel { node: Arc::new(RwLock::new(Node::new()?)), @@ -62,7 +64,14 @@ impl DataModel { { let mut node = dm.node.write()?; node.set_changes_cb(Box::new(dm.clone())); - device_type_add_root_node(&mut node, dev_details, dev_att, fabric_mgr, acl_mgr)?; + device_type_add_root_node( + &mut node, + dev_details, + dev_att, + fabric_mgr, + acl_mgr, + pase_mgr, + )?; } Ok(dm) } diff --git a/matter/src/data_model/device_types.rs b/matter/src/data_model/device_types.rs index 71a2cf4..4d68f9c 100644 --- a/matter/src/data_model/device_types.rs +++ b/matter/src/data_model/device_types.rs @@ -28,6 +28,7 @@ use super::system_model::access_control::AccessControlCluster; use crate::acl::AclMgr; use crate::error::*; use crate::fabric::FabricMgr; +use crate::secure_channel::pake::PaseMgr; use std::sync::Arc; use std::sync::RwLockWriteGuard; @@ -39,6 +40,7 @@ pub fn device_type_add_root_node( dev_att: Box, fabric_mgr: Arc, acl_mgr: Arc, + pase_mgr: PaseMgr, ) -> Result { // Add the root endpoint let endpoint = node.add_endpoint()?; @@ -52,7 +54,7 @@ pub fn device_type_add_root_node( let failsafe = general_commissioning.failsafe(); node.add_cluster(0, general_commissioning)?; node.add_cluster(0, NwCommCluster::new()?)?; - node.add_cluster(0, AdminCommCluster::new()?)?; + node.add_cluster(0, AdminCommCluster::new(pase_mgr)?)?; node.add_cluster( 0, NocCluster::new(dev_att, fabric_mgr, acl_mgr.clone(), failsafe)?, diff --git a/matter/src/data_model/sdm/admin_commissioning.rs b/matter/src/data_model/sdm/admin_commissioning.rs index 6e7d7b7..af8149b 100644 --- a/matter/src/data_model/sdm/admin_commissioning.rs +++ b/matter/src/data_model/sdm/admin_commissioning.rs @@ -18,6 +18,8 @@ use crate::cmd_enter; use crate::data_model::objects::*; use crate::interaction_model::core::IMStatusCode; +use crate::secure_channel::pake::PaseMgr; +use crate::secure_channel::spake2p::VerifierData; use crate::tlv::{FromTLV, Nullable, OctetStr, TLVElement}; use crate::{error::*, interaction_model::command::CommandReq}; use log::{error, info}; @@ -74,7 +76,7 @@ fn attr_admin_vid_new() -> Result { } pub struct AdminCommCluster { - window_status: WindowStatus, + pase_mgr: PaseMgr, base: Cluster, } @@ -89,23 +91,16 @@ impl ClusterType for AdminCommCluster { fn read_custom_attribute(&self, encoder: &mut dyn Encoder, attr: &AttrDetails) { match num::FromPrimitive::from_u16(attr.attr_id) { Some(Attributes::WindowStatus) => { - let status = self.window_status as u8; + let status = 1_u8; encoder.encode(EncodeValue::Value(&status)) } Some(Attributes::AdminVendorId) => { - let vid = if self.window_status == WindowStatus::WindowNotOpen { - Nullable::Null - } else { - Nullable::NotNull(1_u8) - }; + let vid = Nullable::NotNull(1_u8); + encoder.encode(EncodeValue::Value(&vid)) } Some(Attributes::AdminFabricIndex) => { - let vid = if self.window_status == WindowStatus::WindowNotOpen { - Nullable::Null - } else { - Nullable::NotNull(1_u8) - }; + let vid = Nullable::NotNull(1_u8); encoder.encode(EncodeValue::Value(&vid)) } _ => { @@ -129,9 +124,9 @@ impl ClusterType for AdminCommCluster { } impl AdminCommCluster { - pub fn new() -> Result, Error> { + pub fn new(pase_mgr: PaseMgr) -> Result, Error> { let mut c = Box::new(AdminCommCluster { - window_status: WindowStatus::WindowNotOpen, + pase_mgr, base: Cluster::new(ID)?, }); c.base.add_attribute(attr_window_status_new()?)?; @@ -145,9 +140,11 @@ impl AdminCommCluster { cmd_req: &mut CommandReq, ) -> Result<(), IMStatusCode> { cmd_enter!("Open Commissioning Window"); - let _req = + let req = OpenCommWindowReq::from_tlv(&cmd_req.data).map_err(|_| IMStatusCode::InvalidCommand)?; - self.window_status = WindowStatus::EnhancedWindowOpen; + let verifier = VerifierData::new(req.verifier.0, req.iterations, req.salt.0); + self.pase_mgr + .enable_pase_session(verifier, req.discriminator)?; Err(IMStatusCode::Sucess) } } @@ -156,8 +153,8 @@ impl AdminCommCluster { #[tlvargs(lifetime = "'a")] pub struct OpenCommWindowReq<'a> { _timeout: u16, - _verifier: OctetStr<'a>, - _discriminator: u16, - _iterations: u32, - _salt: OctetStr<'a>, + verifier: OctetStr<'a>, + discriminator: u16, + iterations: u32, + salt: OctetStr<'a>, } diff --git a/matter/src/secure_channel/pake.rs b/matter/src/secure_channel/pake.rs index 2bf7acd..b006d30 100644 --- a/matter/src/secure_channel/pake.rs +++ b/matter/src/secure_channel/pake.rs @@ -42,13 +42,13 @@ use crate::{ use log::{error, info}; use rand::prelude::*; -enum PaseSessionState { +enum PaseMgrState { Enabled(PAKE, SysMdnsService), Disabled, } pub struct PaseMgrInternal { - state: PaseSessionState, + state: PaseMgrState, } #[derive(Clone)] @@ -58,7 +58,7 @@ pub struct PaseMgr(Arc>); impl PaseMgr { pub fn new() -> Self { Self(Arc::new(Mutex::new(PaseMgrInternal { - state: PaseSessionState::Disabled, + state: PaseMgrState::Disabled, }))) } @@ -72,13 +72,13 @@ impl PaseMgr { let name = format!("{:016X}", name); let mdns = Mdns::get()? .publish_service(&name, mdns::ServiceMode::Commissionable(discriminator))?; - s.state = PaseSessionState::Enabled(PAKE::new(verifier), mdns); + s.state = PaseMgrState::Enabled(PAKE::new(verifier), mdns); Ok(()) } pub fn disable_pase_session(&mut self) { let mut s = self.0.lock().unwrap(); - s.state = PaseSessionState::Disabled; + s.state = PaseMgrState::Disabled; } /// If the PASE Session is enabled, execute the closure, @@ -88,7 +88,7 @@ impl PaseMgr { F: FnOnce(&mut PAKE, &mut ProtoCtx) -> Result<(), Error>, { let mut s = self.0.lock().unwrap(); - if let PaseSessionState::Enabled(pake, _) = &mut s.state { + if let PaseMgrState::Enabled(pake, _) = &mut s.state { f(pake, ctx) } else { error!("PASE Not enabled"); @@ -190,7 +190,7 @@ impl Default for PakeState { } pub struct PAKE { - verifier: VerifierData, + pub verifier: VerifierData, state: PakeState, } diff --git a/matter/tests/common/im_engine.rs b/matter/tests/common/im_engine.rs index ddd8b30..e9bedf0 100644 --- a/matter/tests/common/im_engine.rs +++ b/matter/tests/common/im_engine.rs @@ -29,6 +29,7 @@ use matter::{ error::Error, fabric::FabricMgr, interaction_model::{core::OpCode, InteractionModel}, + secure_channel::pake::PaseMgr, tlv::{TLVWriter, TagType, ToTLV}, transport::packet::Packet, transport::proto_demux::HandleProto, @@ -97,12 +98,20 @@ impl ImEngine { let dev_att = Box::new(DummyDevAtt {}); let fabric_mgr = Arc::new(FabricMgr::new().unwrap()); let acl_mgr = Arc::new(AclMgr::new_with(false).unwrap()); + let pase_mgr = PaseMgr::new(); acl_mgr.erase_all(); let mut default_acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); // Only allow the standard peer node id of the IM Engine default_acl.add_subject(IM_ENGINE_PEER_ID).unwrap(); acl_mgr.add(default_acl).unwrap(); - let dm = DataModel::new(dev_det, dev_att, fabric_mgr.clone(), acl_mgr.clone()).unwrap(); + let dm = DataModel::new( + dev_det, + dev_att, + fabric_mgr.clone(), + acl_mgr.clone(), + pase_mgr, + ) + .unwrap(); { let mut d = dm.node.write().unwrap();