From e2277a17a49835caaf7067d7dfc281e8bbdb686e Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Sat, 10 Jun 2023 14:01:35 +0000 Subject: [PATCH] Make Matter covariant over its lifetime --- examples/onoff_light/src/main.rs | 88 +++---- matter/src/core.rs | 52 +++-- matter/src/data_model/root_endpoint.rs | 21 +- .../src/data_model/sdm/admin_commissioning.rs | 10 +- matter/src/data_model/sdm/noc.rs | 8 +- matter/src/fabric.rs | 6 +- matter/src/interaction_model/core.rs | 21 +- matter/src/mdns.rs | 218 +++++++++++------- matter/src/secure_channel/common.rs | 5 +- matter/src/secure_channel/core.rs | 16 +- matter/src/secure_channel/pake.rs | 6 +- matter/src/secure_channel/status_report.rs | 1 + matter/src/transport/core.rs | 39 ++-- matter/src/transport/exchange.rs | 2 +- matter/src/transport/mod.rs | 1 + matter/src/transport/packet.rs | 107 ++++++++- matter/src/transport/pipe.rs | 94 ++++++++ matter/src/transport/session.rs | 75 +----- matter/src/utils/select.rs | 3 + matter/src/utils/writebuf.rs | 4 + matter/tests/common/echo_cluster.rs | 6 +- matter/tests/common/im_engine.rs | 4 +- 22 files changed, 508 insertions(+), 279 deletions(-) create mode 100644 matter/src/transport/pipe.rs diff --git a/examples/onoff_light/src/main.rs b/examples/onoff_light/src/main.rs index baf2022..b2e5091 100644 --- a/examples/onoff_light/src/main.rs +++ b/examples/onoff_light/src/main.rs @@ -30,7 +30,7 @@ use matter::data_model::root_endpoint; use matter::data_model::system_model::descriptor; use matter::error::Error; use matter::interaction_model::core::InteractionModel; -use matter::mdns::builtin::Mdns; +use matter::mdns::builtin::{Mdns, MdnsRxBuf, MdnsTxBuf}; use matter::persist; use matter::secure_channel::spake2p::VerifierData; use matter::transport::network::{Address, IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; @@ -46,7 +46,7 @@ mod dev_att; fn main() -> Result<(), Error> { let thread = std::thread::Builder::new() .stack_size(120 * 1024) - .spawn(move || run()) + .spawn(run) .unwrap(); thread.join().unwrap() @@ -63,10 +63,10 @@ fn run() -> Result<(), Error> { initialize_logger(); info!( - "Matter memory: mDNS={}, Transport={} (of which Matter={})", + "Matter memory: mDNS={}, Matter={}, Transport={}", core::mem::size_of::(), - core::mem::size_of::(), core::mem::size_of::(), + core::mem::size_of::(), ); let (ipv4_addr, ipv6_addr) = initialize_network()?; @@ -78,7 +78,7 @@ fn run() -> Result<(), Error> { Some(ipv6_addr.octets()), ); - let (mut mdns, mut mdns_runner) = mdns.split(); + let (mdns, mut mdns_runner) = mdns.split(); //let (mut mdns, mdns_runner) = (matter::mdns::astro::AstroMdns::new()?, core::future::pending::pending()); //let (mut mdns, mdns_runner) = (matter::mdns::DummyMdns {}, core::future::pending::pending()); @@ -86,7 +86,7 @@ fn run() -> Result<(), Error> { let matter = Matter::new_default( // vid/pid should match those in the DAC - BasicInfoConfig { + &BasicInfoConfig { vid: 0xFFF1, pid: 0x8000, hw_ver: 2, @@ -96,7 +96,7 @@ fn run() -> Result<(), Error> { device_name: "OnOff Light", }, &dev_att, - &mut mdns, + &mdns, matter::MATTER_PORT, ); @@ -106,12 +106,13 @@ fn run() -> Result<(), Error> { let psm = persist::FilePsm::new(psm_path)?; let mut buf = [0; 4096]; + let buf = &mut buf; - if let Some(data) = psm.load("acls", &mut buf)? { + if let Some(data) = psm.load("acls", buf)? { matter.load_acls(data)?; } - if let Some(data) = psm.load("fabrics", &mut buf)? { + if let Some(data) = psm.load("fabrics", buf)? { matter.load_fabrics(data)?; } @@ -123,12 +124,33 @@ fn run() -> Result<(), Error> { verifier: VerifierData::new_with_pw(123456, *matter.borrow()), discriminator: 250, }, - &mut buf, + buf, )?; - let matter = &matter; + let node = Node { + id: 0, + endpoints: &[ + root_endpoint::endpoint(0), + Endpoint { + id: 1, + device_type: DEV_TYPE_ON_OFF_LIGHT, + clusters: &[descriptor::CLUSTER, cluster_on_off::CLUSTER], + }, + ], + }; + + let mut handler = handler(&matter); + + let mut im = InteractionModel(DataModel::new(matter.borrow(), &node, &mut handler)); + + let mut rx_buf = [0; MAX_RX_BUF_SIZE]; + let mut tx_buf = [0; MAX_TX_BUF_SIZE]; + + let im = &mut im; let mdns_runner = &mut mdns_runner; let transport = &mut transport; + let rx_buf = &mut rx_buf; + let tx_buf = &mut tx_buf; let mut io_fut = pin!(async move { let udp = UdpListener::new(SocketAddr::new( @@ -138,13 +160,9 @@ fn run() -> Result<(), Error> { .await?; loop { - let mut rx_buf = [0; MAX_RX_BUF_SIZE]; - let mut tx_buf = [0; MAX_TX_BUF_SIZE]; + let (len, addr) = udp.recv(rx_buf).await?; - let (len, addr) = udp.recv(&mut rx_buf).await?; - - let mut completion = - transport.recv(Address::Udp(addr), &mut rx_buf[..len], &mut tx_buf); + let mut completion = transport.recv(Address::Udp(addr), &mut rx_buf[..len], tx_buf); while let Some(action) = completion.next_action()? { match action { @@ -152,38 +170,19 @@ fn run() -> Result<(), Error> { udp.send(addr.unwrap_udp(), buf).await?; } RecvAction::Interact(mut ctx) => { - let node = Node { - id: 0, - endpoints: &[ - root_endpoint::endpoint(0), - Endpoint { - id: 1, - device_type: DEV_TYPE_ON_OFF_LIGHT, - clusters: &[descriptor::CLUSTER, cluster_on_off::CLUSTER], - }, - ], - }; - - let mut handler = handler(matter); - - let mut im = - InteractionModel(DataModel::new(matter.borrow(), &node, &mut handler)); - - if im.handle(&mut ctx)? { - if ctx.send()? { - udp.send(ctx.tx.peer.unwrap_udp(), ctx.tx.as_slice()) - .await?; - } + if im.handle(&mut ctx)? && ctx.send()? { + udp.send(ctx.tx.peer.unwrap_udp(), ctx.tx.as_slice()) + .await?; } } } } - if let Some(data) = transport.matter().store_fabrics(&mut buf)? { + if let Some(data) = transport.matter().store_fabrics(buf)? { psm.store("fabrics", data)?; } - if let Some(data) = transport.matter().store_acls(&mut buf)? { + if let Some(data) = transport.matter().store_acls(buf)? { psm.store("acls", data)?; } } @@ -192,7 +191,12 @@ fn run() -> Result<(), Error> { Ok::<_, matter::error::Error>(()) }); - let mut mdns_fut = pin!(async move { mdns_runner.run().await }); + let mut tx_buf = MdnsTxBuf::uninit(); + let mut rx_buf = MdnsRxBuf::uninit(); + let tx_buf = &mut tx_buf; + let rx_buf = &mut rx_buf; + + let mut mdns_fut = pin!(async move { mdns_runner.run_udp(tx_buf, rx_buf).await }); let mut fut = pin!(async move { select(&mut io_fut, &mut mdns_fut,).await.unwrap() }); diff --git a/matter/src/core.rs b/matter/src/core.rs index dacddbd..c6c5dc1 100644 --- a/matter/src/core.rs +++ b/matter/src/core.rs @@ -48,10 +48,10 @@ pub struct Matter<'a> { pub acl_mgr: RefCell, pub pase_mgr: RefCell, pub failsafe: RefCell, - pub mdns_mgr: RefCell>, + pub mdns_mgr: MdnsMgr<'a>, pub epoch: Epoch, pub rand: Rand, - pub dev_det: BasicInfoConfig<'a>, + pub dev_det: &'a BasicInfoConfig<'a>, pub dev_att: &'a dyn DevAttDataFetcher, pub port: u16, } @@ -60,9 +60,9 @@ impl<'a> Matter<'a> { #[cfg(feature = "std")] #[inline(always)] pub fn new_default( - dev_det: BasicInfoConfig<'a>, + dev_det: &'a BasicInfoConfig<'a>, dev_att: &'a dyn DevAttDataFetcher, - mdns: &'a mut dyn Mdns, + mdns: &'a dyn Mdns, port: u16, ) -> Self { use crate::utils::epoch::sys_epoch; @@ -79,9 +79,9 @@ impl<'a> Matter<'a> { /// this object to return the device attestation details when queried upon. #[inline(always)] pub fn new( - dev_det: BasicInfoConfig<'a>, + dev_det: &'a BasicInfoConfig<'a>, dev_att: &'a dyn DevAttDataFetcher, - mdns: &'a mut dyn Mdns, + mdns: &'a dyn Mdns, epoch: Epoch, rand: Rand, port: u16, @@ -91,13 +91,7 @@ impl<'a> Matter<'a> { acl_mgr: RefCell::new(AclMgr::new()), pase_mgr: RefCell::new(PaseMgr::new(epoch, rand)), failsafe: RefCell::new(FailSafe::new()), - mdns_mgr: RefCell::new(MdnsMgr::new( - dev_det.vid, - dev_det.pid, - dev_det.device_name, - port, - mdns, - )), + mdns_mgr: MdnsMgr::new(dev_det.vid, dev_det.pid, dev_det.device_name, port, mdns), epoch, rand, dev_det, @@ -107,7 +101,7 @@ impl<'a> Matter<'a> { } pub fn dev_det(&self) -> &BasicInfoConfig<'_> { - &self.dev_det + self.dev_det } pub fn dev_att(&self) -> &dyn DevAttDataFetcher { @@ -119,9 +113,7 @@ impl<'a> Matter<'a> { } pub fn load_fabrics(&self, data: &[u8]) -> Result<(), Error> { - self.fabric_mgr - .borrow_mut() - .load(data, &mut self.mdns_mgr.borrow_mut()) + self.fabric_mgr.borrow_mut().load(data, &self.mdns_mgr) } pub fn load_acls(&self, data: &[u8]) -> Result<(), Error> { @@ -148,7 +140,7 @@ impl<'a> Matter<'a> { if !self.pase_mgr.borrow().is_pase_session_enabled() && self.fabric_mgr.borrow().is_empty() { print_pairing_code_and_qr( - &self.dev_det, + self.dev_det, &dev_comm, DiscoveryCapabilities::default(), buf, @@ -157,7 +149,7 @@ impl<'a> Matter<'a> { self.pase_mgr.borrow_mut().enable_pase_session( dev_comm.verifier, dev_comm.discriminator, - &mut self.mdns_mgr.borrow_mut(), + &self.mdns_mgr, )?; Ok(true) @@ -191,12 +183,30 @@ impl<'a> Borrow> for Matter<'a> { } } -impl<'a> Borrow>> for Matter<'a> { - fn borrow(&self) -> &RefCell> { +impl<'a> Borrow> for Matter<'a> { + fn borrow(&self) -> &MdnsMgr<'a> { &self.mdns_mgr } } +impl<'a> Borrow> for Matter<'a> { + fn borrow(&self) -> &BasicInfoConfig<'a> { + self.dev_det + } +} + +impl<'a> Borrow for Matter<'a> { + fn borrow(&self) -> &(dyn DevAttDataFetcher + 'a) { + self.dev_att + } +} + +impl<'a> Borrow for Matter<'a> { + fn borrow(&self) -> &(dyn Mdns + 'a) { + self.mdns_mgr.mdns + } +} + impl<'a> Borrow for Matter<'a> { fn borrow(&self) -> &Epoch { &self.epoch diff --git a/matter/src/data_model/root_endpoint.rs b/matter/src/data_model/root_endpoint.rs index 859d2bc..1bc22fe 100644 --- a/matter/src/data_model/root_endpoint.rs +++ b/matter/src/data_model/root_endpoint.rs @@ -7,7 +7,6 @@ use crate::{ mdns::MdnsMgr, secure_channel::pake::PaseMgr, utils::{epoch::Epoch, rand::Rand}, - Matter, }; use super::{ @@ -55,11 +54,23 @@ pub fn endpoint(id: EndptId) -> Endpoint<'static> { } } -pub fn handler<'a>(endpoint_id: u16, matter: &'a Matter<'a>) -> RootEndpointHandler<'a> { +pub fn handler<'a, T>(endpoint_id: u16, matter: &'a T) -> RootEndpointHandler<'a> +where + T: Borrow> + + Borrow + + Borrow> + + Borrow> + + Borrow> + + Borrow> + + Borrow> + + Borrow + + Borrow + + 'a, +{ wrap( endpoint_id, - matter.dev_det(), - matter.dev_att(), + matter.borrow(), + matter.borrow(), matter.borrow(), matter.borrow(), matter.borrow(), @@ -79,7 +90,7 @@ pub fn wrap<'a>( fabric: &'a RefCell, acl: &'a RefCell, failsafe: &'a RefCell, - mdns_mgr: &'a RefCell>, + mdns_mgr: &'a MdnsMgr<'a>, epoch: Epoch, rand: Rand, ) -> RootEndpointHandler<'a> { diff --git a/matter/src/data_model/sdm/admin_commissioning.rs b/matter/src/data_model/sdm/admin_commissioning.rs index b63aa2e..9364311 100644 --- a/matter/src/data_model/sdm/admin_commissioning.rs +++ b/matter/src/data_model/sdm/admin_commissioning.rs @@ -102,15 +102,11 @@ pub struct OpenCommWindowReq<'a> { pub struct AdminCommCluster<'a> { data_ver: Dataver, pase_mgr: &'a RefCell, - mdns_mgr: &'a RefCell>, + mdns_mgr: &'a MdnsMgr<'a>, } impl<'a> AdminCommCluster<'a> { - pub fn new( - pase_mgr: &'a RefCell, - mdns_mgr: &'a RefCell>, - rand: Rand, - ) -> Self { + pub fn new(pase_mgr: &'a RefCell, mdns_mgr: &'a MdnsMgr<'a>, rand: Rand) -> Self { Self { data_ver: Dataver::new(rand), pase_mgr, @@ -159,7 +155,7 @@ impl<'a> AdminCommCluster<'a> { self.pase_mgr.borrow_mut().enable_pase_session( verifier, req.discriminator, - &mut self.mdns_mgr.borrow_mut(), + self.mdns_mgr, )?; Ok(()) diff --git a/matter/src/data_model/sdm/noc.rs b/matter/src/data_model/sdm/noc.rs index b8dda3c..f347b13 100644 --- a/matter/src/data_model/sdm/noc.rs +++ b/matter/src/data_model/sdm/noc.rs @@ -222,7 +222,7 @@ pub struct NocCluster<'a> { fabric_mgr: &'a RefCell, acl_mgr: &'a RefCell, failsafe: &'a RefCell, - mdns_mgr: &'a RefCell>, + mdns_mgr: &'a MdnsMgr<'a>, } impl<'a> NocCluster<'a> { @@ -231,7 +231,7 @@ impl<'a> NocCluster<'a> { fabric_mgr: &'a RefCell, acl_mgr: &'a RefCell, failsafe: &'a RefCell, - mdns_mgr: &'a RefCell>, + mdns_mgr: &'a MdnsMgr<'a>, epoch: Epoch, rand: Rand, ) -> Self { @@ -383,7 +383,7 @@ impl<'a> NocCluster<'a> { let fab_idx = self .fabric_mgr .borrow_mut() - .add(fabric, &mut self.mdns_mgr.borrow_mut()) + .add(fabric, self.mdns_mgr) .map_err(|_| NocStatus::TableFull)?; self.add_acl(fab_idx, r.case_admin_subject)?; @@ -455,7 +455,7 @@ impl<'a> NocCluster<'a> { if self .fabric_mgr .borrow_mut() - .remove(req.fab_idx, &mut self.mdns_mgr.borrow_mut()) + .remove(req.fab_idx, self.mdns_mgr) .is_ok() { let _ = self.acl_mgr.borrow_mut().delete_for_fabric(req.fab_idx); diff --git a/matter/src/fabric.rs b/matter/src/fabric.rs index f6f64ef..04369ca 100644 --- a/matter/src/fabric.rs +++ b/matter/src/fabric.rs @@ -200,7 +200,7 @@ impl FabricMgr { } } - pub fn load(&mut self, data: &[u8], mdns_mgr: &mut MdnsMgr) -> Result<(), Error> { + pub fn load(&mut self, data: &[u8], mdns_mgr: &MdnsMgr) -> Result<(), Error> { for fabric in self.fabrics.iter().flatten() { mdns_mgr.unpublish_service(&fabric.mdns_service_name, ServiceMode::Commissioned)?; } @@ -241,7 +241,7 @@ impl FabricMgr { self.changed } - pub fn add(&mut self, f: Fabric, mdns_mgr: &mut MdnsMgr) -> Result { + pub fn add(&mut self, f: Fabric, mdns_mgr: &MdnsMgr) -> Result { let slot = self.fabrics.iter().position(|x| x.is_none()); if slot.is_some() || self.fabrics.len() < MAX_SUPPORTED_FABRICS { @@ -265,7 +265,7 @@ impl FabricMgr { } } - pub fn remove(&mut self, fab_idx: u8, mdns_mgr: &mut MdnsMgr) -> Result<(), Error> { + pub fn remove(&mut self, fab_idx: u8, mdns_mgr: &MdnsMgr) -> Result<(), Error> { if fab_idx > 0 && fab_idx as usize <= self.fabrics.len() { if let Some(f) = self.fabrics[(fab_idx - 1) as usize].take() { mdns_mgr.unpublish_service(&f.mdns_service_name, ServiceMode::Commissioned)?; diff --git a/matter/src/interaction_model/core.rs b/matter/src/interaction_model/core.rs index 0686061..cc763a8 100644 --- a/matter/src/interaction_model/core.rs +++ b/matter/src/interaction_model/core.rs @@ -204,7 +204,7 @@ impl<'a, 'b> Transaction<'a, 'b> { } /* Interaction Model ID as per the Matter Spec */ -const PROTO_ID_INTERACTION_MODEL: usize = 0x01; +pub const PROTO_ID_INTERACTION_MODEL: u16 = 0x01; const MAX_RESUME_PATHS: usize = 32; const MAX_RESUME_DATAVER_FILTERS: usize = 32; @@ -228,8 +228,7 @@ pub enum Interaction<'a> { impl<'a> Interaction<'a> { fn new(rx: &'a Packet, transaction: &mut Transaction) -> Result, Error> { - let opcode: OpCode = - num::FromPrimitive::from_u8(rx.get_proto_opcode()).ok_or(ErrorCode::Invalid)?; + let opcode: OpCode = rx.get_proto_opcode()?; let rx_data = rx.as_slice(); @@ -303,7 +302,7 @@ impl<'a> Interaction<'a> { } fn create_status_response(tx: &mut Packet, status: IMStatusCode) -> Result<(), Error> { - tx.set_proto_id(PROTO_ID_INTERACTION_MODEL as u16); + tx.set_proto_id(PROTO_ID_INTERACTION_MODEL); tx.set_proto_opcode(OpCode::StatusResponse as u8); let mut tw = TLVWriter::new(tx.get_writebuf()?); @@ -332,7 +331,7 @@ impl<'a> ReadReq<'a> { } fn initiate(&self, tx: &mut Packet, _transaction: &mut Transaction) -> Result { - tx.set_proto_id(PROTO_ID_INTERACTION_MODEL as u16); + tx.set_proto_id(PROTO_ID_INTERACTION_MODEL); tx.set_proto_opcode(OpCode::ReportData as u8); let mut tw = Self::reserve_long_read_space(tx)?; @@ -410,7 +409,7 @@ impl<'a> WriteReq<'a> { Ok(false) } else { - tx.set_proto_id(PROTO_ID_INTERACTION_MODEL as u16); + tx.set_proto_id(PROTO_ID_INTERACTION_MODEL); tx.set_proto_opcode(OpCode::WriteResponse as u8); let mut tw = TLVWriter::new(tx.get_writebuf()?); @@ -459,7 +458,7 @@ impl<'a> InvReq<'a> { Ok(false) } else { - tx.set_proto_id(PROTO_ID_INTERACTION_MODEL as u16); + tx.set_proto_id(PROTO_ID_INTERACTION_MODEL); tx.set_proto_opcode(OpCode::InvokeResponse as u8); let mut tw = TLVWriter::new(tx.get_writebuf()?); @@ -503,7 +502,7 @@ impl<'a> InvReq<'a> { impl TimedReq { pub fn process(&self, tx: &mut Packet, transaction: &mut Transaction) -> Result<(), Error> { - tx.set_proto_id(PROTO_ID_INTERACTION_MODEL as u16); + tx.set_proto_id(PROTO_ID_INTERACTION_MODEL); tx.set_proto_opcode(OpCode::StatusResponse as u8); let mut tw = TLVWriter::new(tx.get_writebuf()?); @@ -547,7 +546,7 @@ impl<'a> SubscribeReq<'a> { } fn initiate(&self, tx: &mut Packet, transaction: &mut Transaction) -> Result { - tx.set_proto_id(PROTO_ID_INTERACTION_MODEL as u16); + tx.set_proto_id(PROTO_ID_INTERACTION_MODEL); tx.set_proto_opcode(OpCode::ReportData as u8); let mut tw = ReadReq::reserve_long_read_space(tx)?; @@ -615,7 +614,7 @@ pub struct ResumeReadReq { impl ResumeReadReq { fn initiate(&self, tx: &mut Packet, _transaction: &mut Transaction) -> Result { - tx.set_proto_id(PROTO_ID_INTERACTION_MODEL as u16); + tx.set_proto_id(PROTO_ID_INTERACTION_MODEL); tx.set_proto_opcode(OpCode::ReportData as u8); let mut tw = ReadReq::reserve_long_read_space(tx)?; @@ -679,7 +678,7 @@ pub struct ResumeSubscribeReq { impl ResumeSubscribeReq { fn initiate(&self, tx: &mut Packet, _transaction: &mut Transaction) -> Result { - tx.set_proto_id(PROTO_ID_INTERACTION_MODEL as u16); + tx.set_proto_id(PROTO_ID_INTERACTION_MODEL); if self.resume_path.is_some() { tx.set_proto_opcode(OpCode::ReportData as u8); diff --git a/matter/src/mdns.rs b/matter/src/mdns.rs index eaba7ee..3bc4a1d 100644 --- a/matter/src/mdns.rs +++ b/matter/src/mdns.rs @@ -21,7 +21,7 @@ use crate::error::Error; pub trait Mdns { fn add( - &mut self, + &self, name: &str, service: &str, protocol: &str, @@ -30,8 +30,7 @@ pub trait Mdns { txt_kvs: &[(&str, &str)], ) -> Result<(), Error>; - fn remove(&mut self, name: &str, service: &str, protocol: &str, port: u16) - -> Result<(), Error>; + fn remove(&self, name: &str, service: &str, protocol: &str, port: u16) -> Result<(), Error>; } impl Mdns for &mut T @@ -39,7 +38,7 @@ where T: Mdns, { fn add( - &mut self, + &self, name: &str, service: &str, protocol: &str, @@ -50,13 +49,7 @@ where (**self).add(name, service, protocol, port, service_subtypes, txt_kvs) } - fn remove( - &mut self, - name: &str, - service: &str, - protocol: &str, - port: u16, - ) -> Result<(), Error> { + fn remove(&self, name: &str, service: &str, protocol: &str, port: u16) -> Result<(), Error> { (**self).remove(name, service, protocol, port) } } @@ -65,7 +58,7 @@ pub struct DummyMdns; impl Mdns for DummyMdns { fn add( - &mut self, + &self, _name: &str, _service: &str, _protocol: &str, @@ -77,7 +70,7 @@ impl Mdns for DummyMdns { } fn remove( - &mut self, + &self, _name: &str, _service: &str, _protocol: &str, @@ -101,11 +94,11 @@ pub struct MdnsMgr<'a> { /// Product ID pid: u16, /// Device name - device_name: heapless::String<32>, + device_name: &'a str, /// Matter port matter_port: u16, /// mDns service - mdns: &'a mut dyn Mdns, + pub(crate) mdns: &'a dyn Mdns, } impl<'a> MdnsMgr<'a> { @@ -113,14 +106,14 @@ impl<'a> MdnsMgr<'a> { pub fn new( vid: u16, pid: u16, - device_name: &str, + device_name: &'a str, matter_port: u16, - mdns: &'a mut dyn Mdns, + mdns: &'a dyn Mdns, ) -> Self { Self { vid, pid, - device_name: device_name.chars().take(32).collect(), + device_name, matter_port, mdns, } @@ -130,7 +123,7 @@ impl<'a> MdnsMgr<'a> { /// name - is the service name (comma separated subtypes may follow) /// mode - the current service mode #[allow(clippy::needless_pass_by_value)] - pub fn publish_service(&mut self, name: &str, mode: ServiceMode) -> Result<(), Error> { + pub fn publish_service(&self, name: &str, mode: ServiceMode) -> Result<(), Error> { match mode { ServiceMode::Commissioned => { self.mdns @@ -143,7 +136,7 @@ impl<'a> MdnsMgr<'a> { let txt_kvs = [ ("D", discriminator_str.as_str()), ("CM", "1"), - ("DN", self.device_name.as_str()), + ("DN", self.device_name), ("VP", &vp), ("SII", "5000"), /* Sleepy Idle Interval */ ("SAI", "300"), /* Sleepy Active Interval */ @@ -166,7 +159,7 @@ impl<'a> MdnsMgr<'a> { } } - pub fn unpublish_service(&mut self, name: &str, mode: ServiceMode) -> Result<(), Error> { + pub fn unpublish_service(&self, name: &str, mode: ServiceMode) -> Result<(), Error> { match mode { ServiceMode::Commissioned => { self.mdns.remove(name, "_matter", "_tcp", self.matter_port) @@ -216,6 +209,7 @@ impl<'a> MdnsMgr<'a> { pub mod builtin { use core::cell::RefCell; use core::fmt::Write; + use core::mem::MaybeUninit; use core::pin::pin; use core::str::FromStr; @@ -224,15 +218,16 @@ pub mod builtin { use domain::base::octets::{Octets256, Octets64, OctetsBuilder}; use domain::base::{Dname, MessageBuilder, Record, ShortBuf}; use domain::rdata::{Aaaa, Ptr, Srv, Txt, A}; - use embassy_futures::select::select; - use embassy_sync::blocking_mutex::raw::NoopRawMutex; + use embassy_futures::select::{select, select3}; use embassy_time::{Duration, Timer}; use log::info; use crate::error::{Error, ErrorCode}; - use crate::transport::network::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; + use crate::transport::network::{Address, IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; + use crate::transport::packet::{MAX_RX_BUF_SIZE, MAX_TX_BUF_SIZE}; + use crate::transport::pipe::{Chunk, Pipe}; use crate::transport::udp::UdpListener; - use crate::utils::select::EitherUnwrap; + use crate::utils::select::{EitherUnwrap, Notification}; const IP_BROADCAST_ADDRS: [(IpAddr, u16); 2] = [ (IpAddr::V4(Ipv4Addr::new(224, 0, 0, 251)), 5353), @@ -244,6 +239,9 @@ pub mod builtin { const IP_BIND_ADDR: (IpAddr, u16) = (IpAddr::V6(Ipv6Addr::UNSPECIFIED), 5353); + pub type MdnsTxBuf = MaybeUninit<[u8; MAX_TX_BUF_SIZE]>; + pub type MdnsRxBuf = MaybeUninit<[u8; MAX_RX_BUF_SIZE]>; + #[allow(clippy::too_many_arguments)] pub fn create_record( id: u16, @@ -382,8 +380,6 @@ pub mod builtin { Ok(target.len()) } - pub type Notification = embassy_sync::signal::Signal; - #[derive(Debug, Clone)] struct MdnsEntry { key: heapless::String<64>, @@ -407,7 +403,6 @@ pub mod builtin { ipv6: Option<[u8; 16]>, entries: RefCell>, notification: Notification, - udp: RefCell>, } impl<'a> Mdns<'a> { @@ -420,7 +415,6 @@ pub mod builtin { ipv6, entries: RefCell::new(heapless::Vec::new()), notification: Notification::new(), - udp: RefCell::new(None), } } @@ -428,19 +422,6 @@ pub mod builtin { (MdnsApi(&*self), MdnsRunner(&*self)) } - async fn bind(&self) -> Result<(), Error> { - if self.udp.borrow().is_none() { - *self.udp.borrow_mut() = - Some(UdpListener::new(SocketAddr::new(IP_BIND_ADDR.0, IP_BIND_ADDR.1)).await?); - } - - Ok(()) - } - - pub fn close(&mut self) { - *self.udp.borrow_mut() = None; - } - fn key( &self, name: &str, @@ -546,15 +527,72 @@ pub mod builtin { pub struct MdnsRunner<'a, 'b>(&'a Mdns<'b>); impl<'a, 'b> MdnsRunner<'a, 'b> { - pub async fn run(&mut self) -> Result<(), Error> { - let mut broadcast = pin!(self.broadcast()); - let mut respond = pin!(self.respond()); + pub async fn run_udp( + &mut self, + tx_buf: &mut MdnsTxBuf, + rx_buf: &mut MdnsRxBuf, + ) -> Result<(), Error> { + let udp = UdpListener::new(SocketAddr::new(IP_BIND_ADDR.0, IP_BIND_ADDR.1)).await?; + + let tx_pipe = Pipe::new(unsafe { tx_buf.assume_init_mut() }); + let rx_pipe = Pipe::new(unsafe { rx_buf.assume_init_mut() }); + + let tx_pipe = &tx_pipe; + let rx_pipe = &rx_pipe; + let udp = &udp; + + let mut tx = pin!(async move { + loop { + { + let mut data = tx_pipe.data.lock().await; + + if let Some(chunk) = data.chunk { + udp.send(chunk.addr.unwrap_udp(), &data.buf[chunk.start..chunk.end]) + .await?; + data.chunk = None; + tx_pipe.data_consumed_notification.signal(()); + } + } + + tx_pipe.data_supplied_notification.wait().await; + } + }); + + let mut rx = pin!(async move { + loop { + { + let mut data = rx_pipe.data.lock().await; + + if data.chunk.is_none() { + let (len, addr) = udp.recv(data.buf).await?; + + data.chunk = Some(Chunk { + start: 0, + end: len, + addr: Address::Udp(addr), + }); + rx_pipe.data_supplied_notification.signal(()); + } + } + + rx_pipe.data_consumed_notification.wait().await; + } + }); + + let mut run = pin!(async move { self.run(tx_pipe, rx_pipe).await }); + + select3(&mut tx, &mut rx, &mut run).await.unwrap() + } + + pub async fn run(&mut self, tx_pipe: &Pipe<'_>, rx_pipe: &Pipe<'_>) -> Result<(), Error> { + let mut broadcast = pin!(self.broadcast(tx_pipe)); + let mut respond = pin!(self.respond(rx_pipe, tx_pipe)); select(&mut broadcast, &mut respond).await.unwrap() } #[allow(clippy::await_holding_refcell_ref)] - async fn broadcast(&self) -> Result<(), Error> { + async fn broadcast(&self, tx_pipe: &Pipe<'_>) -> Result<(), Error> { loop { select( self.0.notification.wait(), @@ -564,51 +602,74 @@ pub mod builtin { let mut index = 0; - loop { - let entry = self.0.entries.borrow().get(index).cloned(); + 'outer: loop { + for (addr, port) in IP_BROADCAST_ADDRS { + loop { + { + let mut data = tx_pipe.data.lock().await; - if let Some(entry) = entry { - info!("Broadasting mDNS entry {}", &entry.key); + if data.chunk.is_none() { + let entries = self.0.entries.borrow(); + let entry = entries.get(index); - self.0.bind().await?; + if let Some(entry) = entry { + info!( + "Broadasting mDNS entry {} on {}:{}", + &entry.key, addr, port + ); - let udp = self.0.udp.borrow(); - let udp = udp.as_ref().unwrap(); + let len = entry.record.len(); + data.buf[..len].copy_from_slice(&entry.record); + drop(entries); - for (addr, port) in IP_BROADCAST_ADDRS { - udp.send(SocketAddr::new(addr, port), &entry.record).await?; + data.chunk = Some(Chunk { + start: 0, + end: len, + addr: Address::Udp(SocketAddr::new(addr, port)), + }); + + tx_pipe.data_supplied_notification.signal(()); + } else { + break 'outer; + } + + break; + } + } + + tx_pipe.data_consumed_notification.wait().await; } - - index += 1; - } else { - break; } + + index += 1; } } } #[allow(clippy::await_holding_refcell_ref)] - async fn respond(&self) -> Result<(), Error> { + async fn respond(&self, rx_pipe: &Pipe<'_>, _tx_pipe: &Pipe<'_>) -> Result<(), Error> { loop { - let mut buf = [0; 1580]; + { + let mut data = rx_pipe.data.lock().await; - let udp = self.0.udp.borrow(); - let udp = udp.as_ref().unwrap(); + if let Some(_chunk) = data.chunk { + // TODO: Process the incoming packed and only answer what we are being queried about - let (_len, _addr) = udp.recv(&mut buf).await?; + data.chunk = None; + rx_pipe.data_consumed_notification.signal(()); - info!("Received UDP packet"); + self.0.notification.signal(()); + } + } - // TODO: Process the incoming packed and only answer what we are being queried about - - self.0.notification.signal(()); + rx_pipe.data_supplied_notification.wait().await; } } } impl<'a, 'b> super::Mdns for MdnsApi<'a, 'b> { fn add( - &mut self, + &self, name: &str, service: &str, protocol: &str, @@ -628,7 +689,7 @@ pub mod builtin { } fn remove( - &mut self, + &self, name: &str, service: &str, protocol: &str, @@ -641,6 +702,7 @@ pub mod builtin { #[cfg(all(feature = "std", feature = "astro-dnssd"))] pub mod astro { + use core::cell::RefCell; use std::collections::HashMap; use super::Mdns; @@ -657,18 +719,18 @@ pub mod astro { } pub struct AstroMdns { - services: HashMap, + services: RefCell>, } impl AstroMdns { pub fn new() -> Result { Ok(Self { - services: HashMap::new(), + services: RefCell::new(HashMap::new()), }) } pub fn add( - &mut self, + &self, name: &str, service: &str, protocol: &str, @@ -698,7 +760,7 @@ pub mod astro { let svc = builder.register().map_err(|_| ErrorCode::MdnsError)?; - self.services.insert( + self.services.borrow_mut().insert( ServiceId { name: name.into(), service: service.into(), @@ -712,7 +774,7 @@ pub mod astro { } pub fn remove( - &mut self, + &self, name: &str, service: &str, protocol: &str, @@ -725,7 +787,7 @@ pub mod astro { port, }; - if self.services.remove(&id).is_some() { + if self.services.borrow_mut().remove(&id).is_some() { info!( "Deregistering mDNS service {}/{}.{}/{}", name, service, protocol, port @@ -738,7 +800,7 @@ pub mod astro { impl Mdns for AstroMdns { fn add( - &mut self, + &self, name: &str, service: &str, protocol: &str, @@ -758,7 +820,7 @@ pub mod astro { } fn remove( - &mut self, + &self, name: &str, service: &str, protocol: &str, diff --git a/matter/src/secure_channel/common.rs b/matter/src/secure_channel/common.rs index c007ee5..80fb7b5 100644 --- a/matter/src/secure_channel/common.rs +++ b/matter/src/secure_channel/common.rs @@ -24,7 +24,7 @@ use super::status_report::{create_status_report, GeneralCode}; /* Interaction Model ID as per the Matter Spec */ pub const PROTO_ID_SECURE_CHANNEL: u16 = 0x00; -#[derive(FromPrimitive, Debug)] +#[derive(FromPrimitive, Debug, Copy, Clone, Eq, PartialEq)] pub enum OpCode { MsgCounterSyncReq = 0x00, MsgCounterSyncResp = 0x01, @@ -56,8 +56,6 @@ pub fn create_sc_status_report( status_code: SCStatusCodes, proto_data: Option<&[u8]>, ) -> Result<(), Error> { - proto_tx.reset(); - let general_code = match status_code { SCStatusCodes::SessionEstablishmentSuccess => GeneralCode::Success, SCStatusCodes::CloseSession => { @@ -71,6 +69,7 @@ pub fn create_sc_status_report( | SCStatusCodes::NoSharedTrustRoots | SCStatusCodes::SessionNotFound => GeneralCode::Failure, }; + create_status_report( proto_tx, general_code, diff --git a/matter/src/secure_channel/core.rs b/matter/src/secure_channel/core.rs index 2119691..523278e 100644 --- a/matter/src/secure_channel/core.rs +++ b/matter/src/secure_channel/core.rs @@ -27,7 +27,6 @@ use crate::{ utils::{epoch::Epoch, rand::Rand}, }; use log::{error, info}; -use num; use super::{case::Case, pake::PaseMgr}; @@ -37,7 +36,7 @@ use super::{case::Case, pake::PaseMgr}; pub struct SecureChannel<'a> { case: Case<'a>, pase: &'a RefCell, - mdns: &'a RefCell>, + mdns: &'a MdnsMgr<'a>, } impl<'a> SecureChannel<'a> { @@ -45,7 +44,7 @@ impl<'a> SecureChannel<'a> { pub fn new< T: Borrow> + Borrow> - + Borrow>> + + Borrow> + Borrow + Borrow, >( @@ -63,7 +62,7 @@ impl<'a> SecureChannel<'a> { pub fn wrap( pase: &'a RefCell, fabric: &'a RefCell, - mdns: &'a RefCell>, + mdns: &'a MdnsMgr<'a>, rand: Rand, ) -> Self { Self { @@ -74,8 +73,8 @@ impl<'a> SecureChannel<'a> { } pub fn handle(&mut self, ctx: &mut ProtoCtx) -> Result<(bool, Option), Error> { - let proto_opcode: OpCode = - num::FromPrimitive::from_u8(ctx.rx.get_proto_opcode()).ok_or(ErrorCode::Invalid)?; + let proto_opcode: OpCode = ctx.rx.get_proto_opcode()?; + ctx.tx.set_proto_id(PROTO_ID_SECURE_CHANNEL); info!("Received Opcode: {:?}", proto_opcode); info!("Received Data:"); @@ -92,10 +91,7 @@ impl<'a> SecureChannel<'a> { .borrow_mut() .pasepake1_handler(ctx) .map(|reply| (reply, None)), - OpCode::PASEPake3 => self - .pase - .borrow_mut() - .pasepake3_handler(ctx, &mut self.mdns.borrow_mut()), + OpCode::PASEPake3 => self.pase.borrow_mut().pasepake3_handler(ctx, self.mdns), OpCode::CASESigma1 => self.case.casesigma1_handler(ctx).map(|reply| (reply, None)), OpCode::CASESigma3 => self.case.casesigma3_handler(ctx), _ => { diff --git a/matter/src/secure_channel/pake.rs b/matter/src/secure_channel/pake.rs index b5a29a2..60920d0 100644 --- a/matter/src/secure_channel/pake.rs +++ b/matter/src/secure_channel/pake.rs @@ -67,7 +67,7 @@ impl PaseMgr { &mut self, verifier: VerifierData, discriminator: u16, - mdns: &mut MdnsMgr, + mdns: &MdnsMgr, ) -> Result<(), Error> { let mut buf = [0; 8]; (self.rand)(&mut buf); @@ -89,7 +89,7 @@ impl PaseMgr { Ok(()) } - pub fn disable_pase_session(&mut self, mdns: &mut MdnsMgr) -> Result<(), Error> { + pub fn disable_pase_session(&mut self, mdns: &MdnsMgr) -> Result<(), Error> { if let PaseMgrState::Enabled(_, mdns_service_name, discriminator) = &self.state { mdns.unpublish_service( mdns_service_name, @@ -134,7 +134,7 @@ impl PaseMgr { pub fn pasepake3_handler( &mut self, ctx: &mut ProtoCtx, - mdns: &mut MdnsMgr, + mdns: &MdnsMgr, ) -> Result<(bool, Option), Error> { let clone_data = self.if_enabled(ctx, |pake, ctx| pake.handle_pasepake3(ctx))?; self.disable_pase_session(mdns)?; diff --git a/matter/src/secure_channel/status_report.rs b/matter/src/secure_channel/status_report.rs index 2f6aed1..e837874 100644 --- a/matter/src/secure_channel/status_report.rs +++ b/matter/src/secure_channel/status_report.rs @@ -39,6 +39,7 @@ pub enum GeneralCode { PermissionDenied = 15, DataLoss = 16, } + pub fn create_status_report( proto_tx: &mut Packet, general_code: GeneralCode, diff --git a/matter/src/transport/core.rs b/matter/src/transport/core.rs index 1d02bc0..1b169ee 100644 --- a/matter/src/transport/core.rs +++ b/matter/src/transport/core.rs @@ -41,15 +41,15 @@ pub enum RecvAction<'r, 'p> { Interact(ProtoCtx<'r, 'p>), } -pub struct RecvCompletion<'r, 'a, 'p> { +pub struct RecvCompletion<'r, 'a> { transport: &'r mut Transport<'a>, - rx: Packet<'p>, - tx: Packet<'p>, + rx: Packet<'r>, + tx: Packet<'r>, state: RecvState, } -impl<'r, 'a, 'p> RecvCompletion<'r, 'a, 'p> { - pub fn next_action(&mut self) -> Result>, Error> { +impl<'r, 'a> RecvCompletion<'r, 'a> { + pub fn next_action(&mut self) -> Result>, Error> { loop { // Polonius will remove the need for unsafe one day let this = unsafe { (self as *mut RecvCompletion).as_mut().unwrap() }; @@ -60,16 +60,13 @@ impl<'r, 'a, 'p> RecvCompletion<'r, 'a, 'p> { } } - fn maybe_next_action(&mut self) -> Result>>, Error> { + fn maybe_next_action(&mut self) -> Result>>, Error> { self.transport.exch_mgr.purge(); self.tx.reset(); let (state, next) = match core::mem::replace(&mut self.state, RecvState::New) { RecvState::New => { - self.transport - .exch_mgr - .get_sess_mgr() - .decode(&mut self.rx)?; + self.rx.plain_hdr_decode()?; (RecvState::OpenExchange, None) } RecvState::OpenExchange => match self.transport.exch_mgr.recv(&mut self.rx) { @@ -173,16 +170,16 @@ pub enum NotifyAction<'r, 'p> { Notify(ProtoCtx<'r, 'p>), } -pub struct NotifyCompletion<'r, 'a, 'p> { +pub struct NotifyCompletion<'r, 'a> { // TODO _transport: &'r mut Transport<'a>, - _rx: &'r mut Packet<'p>, - _tx: &'r mut Packet<'p>, + _rx: Packet<'r>, + _tx: Packet<'r>, _state: NotifyState, } -impl<'r, 'a, 'p> NotifyCompletion<'r, 'a, 'p> { - pub fn next_action(&mut self) -> Result>, Error> { +impl<'r, 'a> NotifyCompletion<'r, 'a> { + pub fn next_action(&mut self) -> Result>, Error> { loop { // Polonius will remove the need for unsafe one day let this = unsafe { (self as *mut NotifyCompletion).as_mut().unwrap() }; @@ -193,7 +190,7 @@ impl<'r, 'a, 'p> NotifyCompletion<'r, 'a, 'p> { } } - fn maybe_next_action(&mut self) -> Result>>, Error> { + fn maybe_next_action(&mut self) -> Result>>, Error> { Ok(Some(None)) // TODO: Future } } @@ -216,7 +213,7 @@ impl<'a> Transport<'a> { } pub fn matter(&self) -> &Matter<'a> { - &self.matter + self.matter } pub fn start(&mut self, dev_comm: CommissioningData, buf: &mut [u8]) -> Result<(), Error> { @@ -229,12 +226,12 @@ impl<'a> Transport<'a> { Ok(()) } - pub fn recv<'r, 'p>( + pub fn recv<'r>( &'r mut self, addr: Address, - rx_buf: &'p mut [u8], - tx_buf: &'p mut [u8], - ) -> RecvCompletion<'r, 'a, 'p> { + rx_buf: &'r mut [u8], + tx_buf: &'r mut [u8], + ) -> RecvCompletion<'r, 'a> { let mut rx = Packet::new_rx(rx_buf); let tx = Packet::new_tx(tx_buf); diff --git a/matter/src/transport/exchange.rs b/matter/src/transport/exchange.rs index 4910dbc..5dbb1bb 100644 --- a/matter/src/transport/exchange.rs +++ b/matter/src/transport/exchange.rs @@ -223,7 +223,7 @@ impl Exchange { "{} with proto id: {} opcode: {}, tlv:\n", "Sending".blue(), tx.get_proto_id(), - tx.get_proto_opcode(), + tx.get_proto_raw_opcode(), ); //print_tlv_list(tx.as_slice()); diff --git a/matter/src/transport/mod.rs b/matter/src/transport/mod.rs index 18957be..a219f16 100644 --- a/matter/src/transport/mod.rs +++ b/matter/src/transport/mod.rs @@ -21,6 +21,7 @@ pub mod exchange; pub mod mrp; pub mod network; pub mod packet; +pub mod pipe; pub mod plain_hdr; pub mod proto_ctx; pub mod proto_hdr; diff --git a/matter/src/transport/packet.rs b/matter/src/transport/packet.rs index 72368cb..5e0cf98 100644 --- a/matter/src/transport/packet.rs +++ b/matter/src/transport/packet.rs @@ -15,10 +15,14 @@ * limitations under the License. */ -use log::error; +use log::{error, info, trace}; +use owo_colors::OwoColorize; use crate::{ error::{Error, ErrorCode}, + interaction_model::core::PROTO_ID_INTERACTION_MODEL, + secure_channel::common::PROTO_ID_SECURE_CHANNEL, + tlv, utils::{parsebuf::ParseBuf, writebuf::WriteBuf}, }; @@ -29,6 +33,7 @@ use super::{ }; pub const MAX_RX_BUF_SIZE: usize = 1583; +pub const MAX_RX_STATUS_BUF_SIZE: usize = 100; pub const MAX_TX_BUF_SIZE: usize = 1280 - 40/*IPV6 header size*/ - 8/*UDP header size*/; #[derive(Debug, PartialEq, Eq, Copy, Clone)] @@ -160,10 +165,22 @@ impl<'a> Packet<'a> { self.proto.proto_id = proto_id; } - pub fn get_proto_opcode(&self) -> u8 { + pub fn get_proto_opcode(&self) -> Result { + num::FromPrimitive::from_u8(self.proto.proto_opcode).ok_or(ErrorCode::Invalid.into()) + } + + pub fn get_proto_raw_opcode(&self) -> u8 { self.proto.proto_opcode } + pub fn check_proto_opcode(&self, opcode: u8) -> Result<(), Error> { + if self.proto.proto_opcode == opcode { + Ok(()) + } else { + Err(ErrorCode::Invalid.into()) + } + } + pub fn set_proto_opcode(&mut self, proto_opcode: u8) { self.proto.proto_opcode = proto_opcode; } @@ -196,6 +213,52 @@ impl<'a> Packet<'a> { } } + pub fn proto_encode( + &mut self, + peer: Address, + peer_nodeid: Option, + local_nodeid: u64, + plain_text: bool, + enc_key: Option<&[u8]>, + ) -> Result<(), Error> { + self.peer = peer; + + // Generate encrypted header + let mut tmp_buf = [0_u8; proto_hdr::max_proto_hdr_len()]; + let mut write_buf = WriteBuf::new(&mut tmp_buf); + self.proto.encode(&mut write_buf)?; + self.get_writebuf()?.prepend(write_buf.as_slice())?; + + // Generate plain-text header + if plain_text { + if let Some(d) = peer_nodeid { + self.plain.set_dest_u64(d); + } + } + + let mut tmp_buf = [0_u8; plain_hdr::max_plain_hdr_len()]; + let mut write_buf = WriteBuf::new(&mut tmp_buf); + self.plain.encode(&mut write_buf)?; + let plain_hdr_bytes = write_buf.as_slice(); + + trace!("unencrypted packet: {:x?}", self.as_mut_slice()); + let ctr = self.plain.ctr; + if let Some(e) = enc_key { + proto_hdr::encrypt_in_place( + ctr, + local_nodeid, + plain_hdr_bytes, + self.get_writebuf()?, + e, + )?; + } + + self.get_writebuf()?.prepend(plain_hdr_bytes)?; + trace!("Full encrypted packet: {:x?}", self.as_mut_slice()); + + Ok(()) + } + pub fn is_plain_hdr_decoded(&self) -> Result { match &self.data { Direction::Rx(_, state) => match state { @@ -220,4 +283,44 @@ impl<'a> Packet<'a> { _ => Err(ErrorCode::InvalidState.into()), } } + + pub fn log(&self, operation: &str) { + match self.get_proto_id() { + PROTO_ID_SECURE_CHANNEL => { + if let Ok(opcode) = self.get_proto_opcode::() + { + info!("{} SC:{:?}: ", operation.cyan(), opcode); + } else { + info!( + "{} SC:{}??: ", + operation.cyan(), + self.get_proto_raw_opcode() + ); + } + + tlv::print_tlv_list(self.as_slice()); + } + PROTO_ID_INTERACTION_MODEL => { + if let Ok(opcode) = + self.get_proto_opcode::() + { + info!("{} IM:{:?}: ", operation.cyan(), opcode); + } else { + info!( + "{} IM:{}??: ", + operation.cyan(), + self.get_proto_raw_opcode() + ); + } + + tlv::print_tlv_list(self.as_slice()); + } + other => info!( + "{} {}??:{}??: ", + operation.cyan(), + other, + self.get_proto_raw_opcode() + ), + } + } } diff --git a/matter/src/transport/pipe.rs b/matter/src/transport/pipe.rs new file mode 100644 index 0000000..46259cc --- /dev/null +++ b/matter/src/transport/pipe.rs @@ -0,0 +1,94 @@ +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +use embassy_sync::{blocking_mutex::raw::NoopRawMutex, mutex::Mutex}; + +use crate::utils::select::Notification; + +use super::network::Address; + +#[derive(Copy, Clone, Eq, PartialEq, Debug)] +pub struct Chunk { + pub start: usize, + pub end: usize, + pub addr: Address, +} + +pub struct PipeData<'a> { + pub buf: &'a mut [u8], + pub chunk: Option, +} + +pub struct Pipe<'a> { + pub data: Mutex>, + pub data_supplied_notification: Notification, + pub data_consumed_notification: Notification, +} + +impl<'a> Pipe<'a> { + #[inline(always)] + pub fn new(buf: &'a mut [u8]) -> Self { + Self { + data: Mutex::new(PipeData { buf, chunk: None }), + data_supplied_notification: Notification::new(), + data_consumed_notification: Notification::new(), + } + } + + pub async fn recv(&self, buf: &mut [u8]) -> (usize, Address) { + loop { + { + let mut data = self.data.lock().await; + + if let Some(chunk) = data.chunk { + buf[..chunk.end - chunk.start] + .copy_from_slice(&data.buf[chunk.start..chunk.end]); + data.chunk = None; + + self.data_consumed_notification.signal(()); + + return (chunk.end - chunk.start, chunk.addr); + } + } + + self.data_supplied_notification.wait().await + } + } + + pub async fn send(&self, addr: Address, buf: &[u8]) { + loop { + { + let mut data = self.data.lock().await; + + if data.chunk.is_none() { + data.buf[..buf.len()].copy_from_slice(buf); + data.chunk = Some(Chunk { + start: 0, + end: buf.len(), + addr, + }); + + self.data_supplied_notification.signal(()); + + break; + } + } + + self.data_consumed_notification.wait().await + } + } +} diff --git a/matter/src/transport/session.rs b/matter/src/transport/session.rs index 1c2e936..c421244 100644 --- a/matter/src/transport/session.rs +++ b/matter/src/transport/session.rs @@ -22,12 +22,8 @@ use core::fmt; use core::ops::{Deref, DerefMut}; use core::time::Duration; -use crate::{ - error::*, - transport::{plain_hdr, proto_hdr}, - utils::writebuf::WriteBuf, -}; -use log::{info, trace}; +use crate::{error::*, transport::plain_hdr}; +use log::info; use super::dedup::RxCtrState; use super::{network::Address, packet::Packet}; @@ -255,44 +251,16 @@ impl Session { Ok(()) } - // TODO: Most of this can now be moved into the 'Packet' module - fn do_send(&mut self, epoch: Epoch, tx: &mut Packet) -> Result<(), Error> { + fn send(&mut self, epoch: Epoch, tx: &mut Packet) -> Result<(), Error> { self.last_use = epoch(); - tx.peer = self.peer_addr; - // Generate encrypted header - let mut tmp_buf = [0_u8; proto_hdr::max_proto_hdr_len()]; - let mut write_buf = WriteBuf::new(&mut tmp_buf); - tx.proto.encode(&mut write_buf)?; - tx.get_writebuf()?.prepend(write_buf.as_slice())?; - - // Generate plain-text header - if self.mode == SessionMode::PlainText { - if let Some(d) = self.peer_nodeid { - tx.plain.set_dest_u64(d); - } - } - let mut tmp_buf = [0_u8; plain_hdr::max_plain_hdr_len()]; - let mut write_buf = WriteBuf::new(&mut tmp_buf); - tx.plain.encode(&mut write_buf)?; - let plain_hdr_bytes = write_buf.as_slice(); - - trace!("unencrypted packet: {:x?}", tx.as_mut_slice()); - let ctr = tx.plain.ctr; - let enc_key = self.get_enc_key(); - if let Some(e) = enc_key { - proto_hdr::encrypt_in_place( - ctr, - self.local_nodeid, - plain_hdr_bytes, - tx.get_writebuf()?, - e, - )?; - } - - tx.get_writebuf()?.prepend(plain_hdr_bytes)?; - trace!("Full encrypted packet: {:x?}", tx.as_mut_slice()); - Ok(()) + tx.proto_encode( + self.peer_addr, + self.peer_nodeid, + self.local_nodeid, + self.mode == SessionMode::PlainText, + self.get_enc_key(), + ) } fn rand_msg_ctr(rand: Rand) -> u32 { @@ -493,32 +461,11 @@ impl SessionMgr { } } - pub fn decode(&mut self, rx: &mut Packet) -> Result<(), Error> { - // let network = self.network.as_ref().ok_or(ErrorCode::NoNetworkInterface)?; - - // let (len, src) = network.recv(rx.as_borrow_slice()).await?; - // rx.get_parsebuf()?.set_len(len); - // rx.peer = src; - - // info!("{} from src: {}", "Received".blue(), src); - // trace!("payload: {:x?}", rx.as_borrow_slice()); - - // Read unencrypted packet header - rx.plain_hdr_decode() - } - pub fn send(&mut self, sess_idx: usize, tx: &mut Packet) -> Result<(), Error> { self.sessions[sess_idx] .as_mut() .ok_or(ErrorCode::NoSession)? - .do_send(self.epoch, tx)?; - - // let network = self.network.as_ref().ok_or(Error::NoNetworkInterface)?; - // let peer = proto_tx.peer; - // network.send(proto_tx.as_borrow_slice(), peer).await?; - // info!("Message Sent to {}", peer); - - Ok(()) + .send(self.epoch, tx) } pub fn get_session_handle(&mut self, sess_idx: usize) -> SessionHandle { diff --git a/matter/src/utils/select.rs b/matter/src/utils/select.rs index 2b5d21e..a63c10b 100644 --- a/matter/src/utils/select.rs +++ b/matter/src/utils/select.rs @@ -1,4 +1,7 @@ use embassy_futures::select::{Either, Either3, Either4}; +use embassy_sync::blocking_mutex::raw::NoopRawMutex; + +pub type Notification = embassy_sync::signal::Signal; pub trait EitherUnwrap { fn unwrap(self) -> T; diff --git a/matter/src/utils/writebuf.rs b/matter/src/utils/writebuf.rs index 2f24c97..d091dfb 100644 --- a/matter/src/utils/writebuf.rs +++ b/matter/src/utils/writebuf.rs @@ -38,6 +38,10 @@ impl<'a> WriteBuf<'a> { } } + pub fn get_start(&self) -> usize { + self.start + } + pub fn get_tail(&self) -> usize { self.end } diff --git a/matter/tests/common/echo_cluster.rs b/matter/tests/common/echo_cluster.rs index e5caca7..5e43e18 100644 --- a/matter/tests/common/echo_cluster.rs +++ b/matter/tests/common/echo_cluster.rs @@ -24,8 +24,8 @@ use matter::{ attribute_enum, command_enum, data_model::objects::{ Access, AttrData, AttrDataEncoder, AttrDataWriter, AttrDetails, AttrType, Attribute, - Cluster, CmdDataEncoder, CmdDataWriter, CmdDetails, Dataver, Handler, Quality, - ATTRIBUTE_LIST, FEATURE_MAP, + Cluster, CmdDataEncoder, CmdDataWriter, CmdDetails, Dataver, Handler, NonBlockingHandler, + Quality, ATTRIBUTE_LIST, FEATURE_MAP, }, error::{Error, ErrorCode}, interaction_model::{ @@ -286,3 +286,5 @@ impl Handler for EchoCluster { EchoCluster::invoke(self, transaction, cmd, data, encoder) } } + +impl NonBlockingHandler for EchoCluster {} diff --git a/matter/tests/common/im_engine.rs b/matter/tests/common/im_engine.rs index 70b2aca..13da8cd 100644 --- a/matter/tests/common/im_engine.rs +++ b/matter/tests/common/im_engine.rs @@ -110,7 +110,7 @@ pub fn matter(mdns: &mut dyn Mdns) -> Matter<'_> { #[cfg(not(feature = "std"))] use matter::utils::epoch::dummy_epoch as epoch; - Matter::new(BASIC_INFO, &DummyDevAtt, mdns, epoch, dummy_rand, 5540) + Matter::new(&BASIC_INFO, &DummyDevAtt, mdns, epoch, dummy_rand, 5540) } /// An Interaction Model Engine to facilitate easy testing @@ -236,7 +236,7 @@ impl<'a> ImEngine<'a> { self.im.handle(&mut ctx).unwrap(); let out_data_len = ctx.tx.as_slice().len(); data_out[..out_data_len].copy_from_slice(ctx.tx.as_slice()); - let response = ctx.tx.get_proto_opcode(); + let response = ctx.tx.get_proto_raw_opcode(); (response, &data_out[..out_data_len]) } }